# Copyright 2025 - EWAAST Project
# Adapted from Google's appoint-ready architecture
# Licensed under the Apache License, Version 2.0
"""
MedGemma Client
Wrapper for MedGemma inference, supporting both:
1. Vertex AI endpoint (production)
2. Local Transformers model (development)
"""
import os
import json
import requests
from cache import cache
# ===== CONFIGURATION =====
USE_LOCAL_MODEL = os.environ.get("USE_LOCAL_MODEL", "true").lower() == "true"
LOCAL_MODEL_PATH = os.environ.get("LOCAL_MODEL_PATH", "models/ewaast-medgemma-finetuned")
# Vertex AI configuration
VERTEX_ENDPOINT = os.environ.get("GCP_MEDGEMMA_ENDPOINT")
SERVICE_ACCOUNT_KEY = os.environ.get("GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY")
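# Illustrative environment setup (placeholder values, adapt to your deployment):
#   USE_LOCAL_MODEL=true
#   LOCAL_MODEL_PATH=models/ewaast-medgemma-finetuned
# or, for Vertex AI:
#   USE_LOCAL_MODEL=false
#   GCP_MEDGEMMA_ENDPOINT=<full prediction URL of the deployed MedGemma endpoint>
#   GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY=<service-account key JSON as a single string>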
# ===== LOCAL MODEL (Development) =====
_local_model = None
_local_processor = None
def _load_local_model():
"""Lazy-load the local MedGemma model."""
global _local_model, _local_processor
if _local_model is None:
try:
from transformers import AutoModelForImageTextToText, AutoProcessor
import torch
print(f"Loading local MedGemma model from {LOCAL_MODEL_PATH}...")
_local_processor = AutoProcessor.from_pretrained(
LOCAL_MODEL_PATH,
trust_remote_code=True
)
# Load with 4-bit quantization if bitsandbytes is available
try:
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
print("Loading with 4-bit quantization...")
_local_model = AutoModelForImageTextToText.from_pretrained(
LOCAL_MODEL_PATH,
quantization_config=quantization_config,
device_map="auto",
trust_remote_code=True
)
except ImportError:
print("BitsAndBytes not found, loading full precision (may OOM)...")
_local_model = AutoModelForImageTextToText.from_pretrained(
LOCAL_MODEL_PATH,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
print("Local MedGemma model loaded successfully!")
except Exception as e:
print(f"Failed to load local model: {e}")
raise
return _local_model, _local_processor
def _local_inference(messages: list, max_tokens: int = 2048) -> str:
"""Run inference using the local model (Gemma 3 / MedGemma 1.5 format)."""
model, processor = _load_local_model()
import torch
from PIL import Image
import io
import base64
# Convert OpenAI-style messages to Gemma 3 chat format
# Gemma 3 requires strict alternation: user/assistant/user/assistant
# For our use case (single inference), we merge all content into ONE user message
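    # Illustrative example of the conversion (structure only, not real data):
    #   OpenAI-style input:
    #     [{"role": "system", "content": "You are a clinical assistant."},
    #      {"role": "user", "content": [
    #          {"type": "text", "text": "Describe the finding in this image."},
    #          {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}]
    #   resulting single Gemma 3 user turn:
    #     [{"role": "user", "content": [
    #          {"type": "image", "image": <PIL.Image>},
    #          {"type": "text", "text": "You are a clinical assistant.\n\nDescribe the finding in this image."}]}]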
all_images = []
all_text_parts = []
for msg in messages:
content = msg.get("content", [])
if isinstance(content, str):
all_text_parts.append(content)
elif isinstance(content, list):
for part in content:
if part.get("type") == "text":
all_text_parts.append(part["text"])
elif part.get("type") == "image_url":
# Handle base64 image
img_url = part["image_url"]["url"]
if img_url.startswith("data:image"):
base64_str = img_url.split(",")[1]
img_data = base64.b64decode(base64_str)
image = Image.open(io.BytesIO(img_data)).convert("RGB")
all_images.append(image)
# Build a single user message with images first, then text
gemma_content = []
for img in all_images:
gemma_content.append({"type": "image", "image": img})
gemma_content.append({"type": "text", "text": "\n\n".join(all_text_parts)})
gemma_messages = [{"role": "user", "content": gemma_content}]
# Use the processor's apply_chat_template to properly format the prompt
# This automatically handles image tokens for Gemma 3
inputs = processor.apply_chat_template(
gemma_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to(model.device)
    # Record the prompt length so we can decode only the newly generated tokens
    input_len = inputs["input_ids"].shape[-1]
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
do_sample=False # Greedy decoding - more stable with quantized models
)
# Decode only the new tokens (skip input)
response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
return response.strip()
# ===== VERTEX AI (Production) =====
_vertex_credentials = None
def _get_vertex_credentials():
"""Get or refresh Vertex AI credentials."""
global _vertex_credentials
if _vertex_credentials is None and SERVICE_ACCOUNT_KEY:
from google.oauth2 import service_account
import google.auth.transport.requests
service_account_info = json.loads(SERVICE_ACCOUNT_KEY)
_vertex_credentials = service_account.Credentials.from_service_account_info(
service_account_info,
scopes=['https://www.googleapis.com/auth/cloud-platform']
)
    if _vertex_credentials:
        # Refresh when the token has never been fetched or expires within 5 minutes
        import datetime
        import google.auth.transport.requests
        needs_refresh = True
        if _vertex_credentials.expiry:
            time_remaining = (
                _vertex_credentials.expiry.replace(tzinfo=datetime.timezone.utc)
                - datetime.datetime.now(datetime.timezone.utc)
            )
            needs_refresh = time_remaining < datetime.timedelta(minutes=5)
        if needs_refresh:
            request = google.auth.transport.requests.Request()
            _vertex_credentials.refresh(request)
return _vertex_credentials
def _vertex_inference(messages: list, max_tokens: int = 2048, temperature: float = 0.1) -> str:
"""Run inference using Vertex AI endpoint."""
credentials = _get_vertex_credentials()
headers = {
"Authorization": f"Bearer {credentials.token}",
"Content-Type": "application/json",
}
payload = {
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature
}
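    # The endpoint is assumed to expose an OpenAI-compatible chat-completions
    # interface; a successful response is expected to look roughly like:
    #   {"choices": [{"message": {"role": "assistant", "content": "..."}}]}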
response = requests.post(VERTEX_ENDPOINT, headers=headers, json=payload, timeout=60)
response.raise_for_status()
return response.json()["choices"][0]["message"]["content"]
# ===== PUBLIC API =====
@cache.memoize()
def medgemma_get_text_response(
messages: list,
temperature: float = 0.1,
max_tokens: int = 2048,
) -> str:
"""
Get a text response from MedGemma.
Automatically selects between local model and Vertex AI
based on configuration.
Args:
messages: List of message dictionaries in OpenAI format
temperature: Sampling temperature
max_tokens: Maximum tokens to generate
Returns:
Generated text response
"""
if USE_LOCAL_MODEL:
return _local_inference(messages, max_tokens)
elif VERTEX_ENDPOINT and SERVICE_ACCOUNT_KEY:
return _vertex_inference(messages, max_tokens, temperature)
else:
raise RuntimeError(
"No MedGemma inference method configured. "
"Set USE_LOCAL_MODEL=true or provide GCP_MEDGEMMA_ENDPOINT."
)
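# ===== EXAMPLE USAGE (illustrative only; the prompt text is a placeholder) =====
if __name__ == "__main__":
    # Minimal text-only smoke test. Image inputs would use the OpenAI-style
    # "image_url" parts handled in _local_inference above.
    demo_messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "In one paragraph, explain what MedGemma is."}
            ],
        }
    ]
    print(medgemma_get_text_response(demo_messages, max_tokens=256))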