# Copyright 2025 - EWAAST Project
# Adapted from Google's appoint-ready architecture
# Licensed under the Apache License, Version 2.0
"""
MedGemma Client

Wrapper for MedGemma inference, supporting both:
1. Vertex AI endpoint (production)
2. Local Transformers model (development)
"""

import os
import json
import requests

from cache import cache

# ===== CONFIGURATION =====
USE_LOCAL_MODEL = os.environ.get("USE_LOCAL_MODEL", "true").lower() == "true"
LOCAL_MODEL_PATH = os.environ.get("LOCAL_MODEL_PATH", "models/ewaast-medgemma-finetuned")

# Vertex AI configuration
VERTEX_ENDPOINT = os.environ.get("GCP_MEDGEMMA_ENDPOINT")
SERVICE_ACCOUNT_KEY = os.environ.get("GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY")

# ===== LOCAL MODEL (Development) =====
_local_model = None
_local_processor = None


def _load_local_model():
    """Lazy-load the local MedGemma model."""
    global _local_model, _local_processor

    if _local_model is None:
        try:
            from transformers import AutoModelForImageTextToText, AutoProcessor
            import torch

            print(f"Loading local MedGemma model from {LOCAL_MODEL_PATH}...")

            _local_processor = AutoProcessor.from_pretrained(
                LOCAL_MODEL_PATH,
                trust_remote_code=True
            )

            # Load with 4-bit quantization if bitsandbytes is available
            try:
                from transformers import BitsAndBytesConfig
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16
                )
                print("Loading with 4-bit quantization...")
                _local_model = AutoModelForImageTextToText.from_pretrained(
                    LOCAL_MODEL_PATH,
                    quantization_config=quantization_config,
                    device_map="auto",
                    trust_remote_code=True
                )
            except ImportError:
                print("BitsAndBytes not found, loading full precision (may OOM)...")
                _local_model = AutoModelForImageTextToText.from_pretrained(
                    LOCAL_MODEL_PATH,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    trust_remote_code=True
                )

            print("Local MedGemma model loaded successfully!")
        except Exception as e:
            print(f"Failed to load local model: {e}")
            raise

    return _local_model, _local_processor


def _local_inference(messages: list, max_tokens: int = 2048) -> str:
    """Run inference using the local model (Gemma 3 / MedGemma 1.5 format)."""
    model, processor = _load_local_model()
    import torch
    from PIL import Image
    import io
    import base64

    # Convert OpenAI-style messages to Gemma 3 chat format.
    # Gemma 3 requires strict alternation: user/assistant/user/assistant.
    # For our use case (single inference), we merge all content into ONE user message.
    all_images = []
    all_text_parts = []

    for msg in messages:
        content = msg.get("content", [])
        if isinstance(content, str):
            all_text_parts.append(content)
        elif isinstance(content, list):
            for part in content:
                if part.get("type") == "text":
                    all_text_parts.append(part["text"])
                elif part.get("type") == "image_url":
                    # Handle base64 image
                    img_url = part["image_url"]["url"]
                    if img_url.startswith("data:image"):
                        base64_str = img_url.split(",")[1]
                        img_data = base64.b64decode(base64_str)
                        image = Image.open(io.BytesIO(img_data)).convert("RGB")
                        all_images.append(image)

    # Build a single user message with images first, then text
    gemma_content = []
    for img in all_images:
        gemma_content.append({"type": "image", "image": img})
    gemma_content.append({"type": "text", "text": "\n\n".join(all_text_parts)})

    gemma_messages = [{"role": "user", "content": gemma_content}]

    # Use the processor's apply_chat_template to properly format the prompt.
    # This automatically handles image tokens for Gemma 3.
    inputs = processor.apply_chat_template(
        gemma_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    # Remember the prompt length so only the newly generated tokens are decoded
    input_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False  # Greedy decoding - more stable with quantized models
        )

    # Decode only the new tokens (skip input)
    response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
    return response.strip()
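
# --- Illustrative sketch (not part of the original client) ---
# The converter above expects OpenAI-style messages whose image parts are
# base64 data URIs under "image_url". The helper below is a hypothetical
# example of how a caller might assemble such a message; the function name
# and the PNG mime type are assumptions, not part of the EWAAST codebase.
def example_build_image_message(prompt: str, image_path: str) -> dict:
    """Hypothetical example: build one user message with an inline image."""
    import base64

    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")

    return {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
            {"type": "text", "text": prompt},
        ],
    }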

# ===== VERTEX AI (Production) =====
_vertex_credentials = None


def _get_vertex_credentials():
    """Get or refresh Vertex AI credentials."""
    global _vertex_credentials

    if _vertex_credentials is None and SERVICE_ACCOUNT_KEY:
        from google.oauth2 import service_account

        service_account_info = json.loads(SERVICE_ACCOUNT_KEY)
        _vertex_credentials = service_account.Credentials.from_service_account_info(
            service_account_info,
            scopes=['https://www.googleapis.com/auth/cloud-platform']
        )

    if _vertex_credentials:
        # Refresh if needed
        import datetime
        import google.auth.transport.requests

        if _vertex_credentials.expiry:
            time_remaining = (
                _vertex_credentials.expiry.replace(tzinfo=datetime.timezone.utc)
                - datetime.datetime.now(datetime.timezone.utc)
            )
            if time_remaining < datetime.timedelta(minutes=5):
                request = google.auth.transport.requests.Request()
                _vertex_credentials.refresh(request)
        else:
            request = google.auth.transport.requests.Request()
            _vertex_credentials.refresh(request)

    return _vertex_credentials


def _vertex_inference(messages: list, max_tokens: int = 2048, temperature: float = 0.1) -> str:
    """Run inference using Vertex AI endpoint."""
    credentials = _get_vertex_credentials()

    headers = {
        "Authorization": f"Bearer {credentials.token}",
        "Content-Type": "application/json",
    }
    payload = {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature
    }

    response = requests.post(VERTEX_ENDPOINT, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


# ===== PUBLIC API =====

@cache.memoize()
def medgemma_get_text_response(
    messages: list,
    temperature: float = 0.1,
    max_tokens: int = 2048,
) -> str:
    """
    Get a text response from MedGemma.

    Automatically selects between local model and Vertex AI based on configuration.

    Args:
        messages: List of message dictionaries in OpenAI format
        temperature: Sampling temperature
        max_tokens: Maximum tokens to generate

    Returns:
        Generated text response
    """
    if USE_LOCAL_MODEL:
        return _local_inference(messages, max_tokens)
    elif VERTEX_ENDPOINT and SERVICE_ACCOUNT_KEY:
        return _vertex_inference(messages, max_tokens, temperature)
    else:
        raise RuntimeError(
            "No MedGemma inference method configured. "
            "Set USE_LOCAL_MODEL=true or provide GCP_MEDGEMMA_ENDPOINT."
        )
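
# --- Illustrative usage sketch (assumption: run as a script; not part of the
# original module). Demonstrates the OpenAI-style message list that
# medgemma_get_text_response expects; the prompt text is made up.
if __name__ == "__main__":
    example_messages = [
        {"role": "system", "content": "You are a concise clinical assistant."},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "List three red-flag symptoms of acute appendicitis."}
            ],
        },
    ]
    print(medgemma_get_text_response(example_messages, temperature=0.1, max_tokens=256))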