# Copyright 2025 - EWAAST Project
# Adapted from Google's appoint-ready architecture
# Licensed under the Apache License, Version 2.0
"""
MedGemma Client

Wrapper for MedGemma inference, supporting both:
1. Vertex AI endpoint (production)
2. Local Transformers model (development)
"""
import os
import json

import requests

from cache import cache

# ===== CONFIGURATION =====
USE_LOCAL_MODEL = os.environ.get("USE_LOCAL_MODEL", "true").lower() == "true"
LOCAL_MODEL_PATH = os.environ.get("LOCAL_MODEL_PATH", "models/ewaast-medgemma-finetuned")

# Vertex AI configuration
VERTEX_ENDPOINT = os.environ.get("GCP_MEDGEMMA_ENDPOINT")
SERVICE_ACCOUNT_KEY = os.environ.get("GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY")
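
# Example environment configuration (illustrative sketch; the endpoint URL and
# key below are placeholders, not values shipped with the project):
#   USE_LOCAL_MODEL=true   LOCAL_MODEL_PATH=models/ewaast-medgemma-finetuned
#       -> run inference with the local fine-tuned model
#   USE_LOCAL_MODEL=false  GCP_MEDGEMMA_ENDPOINT=<HTTPS URL of the deployed endpoint>
#                          GCP_MEDGEMMA_SERVICE_ACCOUNT_KEY='<service-account JSON as a string>'
#       -> run inference against Vertex AI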
# ===== LOCAL MODEL (Development) =====
_local_model = None
_local_processor = None


def _load_local_model():
    """Lazy-load the local MedGemma model."""
    global _local_model, _local_processor
    if _local_model is None:
        try:
            from transformers import AutoModelForImageTextToText, AutoProcessor
            import torch

            print(f"Loading local MedGemma model from {LOCAL_MODEL_PATH}...")
            _local_processor = AutoProcessor.from_pretrained(
                LOCAL_MODEL_PATH,
                trust_remote_code=True
            )
            # Load with 4-bit quantization if bitsandbytes is available
            try:
                from transformers import BitsAndBytesConfig
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16
                )
                print("Loading with 4-bit quantization...")
                _local_model = AutoModelForImageTextToText.from_pretrained(
                    LOCAL_MODEL_PATH,
                    quantization_config=quantization_config,
                    device_map="auto",
                    trust_remote_code=True
                )
            except ImportError:
                print("BitsAndBytes not found, loading full precision (may OOM)...")
                _local_model = AutoModelForImageTextToText.from_pretrained(
                    LOCAL_MODEL_PATH,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    trust_remote_code=True
                )
            print("Local MedGemma model loaded successfully!")
        except Exception as e:
            print(f"Failed to load local model: {e}")
            raise
    return _local_model, _local_processor
def _local_inference(messages: list, max_tokens: int = 2048) -> str:
    """Run inference using the local model (Gemma 3 / MedGemma 1.5 format)."""
    model, processor = _load_local_model()

    import torch
    from PIL import Image
    import io
    import base64

    # Convert OpenAI-style messages to Gemma 3 chat format.
    # Gemma 3 requires strict alternation: user/assistant/user/assistant.
    # For our use case (single inference), we merge all content into ONE user message.
    all_images = []
    all_text_parts = []
    for msg in messages:
        content = msg.get("content", [])
        if isinstance(content, str):
            all_text_parts.append(content)
        elif isinstance(content, list):
            for part in content:
                if part.get("type") == "text":
                    all_text_parts.append(part["text"])
                elif part.get("type") == "image_url":
                    # Handle base64 data-URI images
                    img_url = part["image_url"]["url"]
                    if img_url.startswith("data:image"):
                        base64_str = img_url.split(",")[1]
                        img_data = base64.b64decode(base64_str)
                        image = Image.open(io.BytesIO(img_data)).convert("RGB")
                        all_images.append(image)

    # Build a single user message with images first, then text
    gemma_content = []
    for img in all_images:
        gemma_content.append({"type": "image", "image": img})
    gemma_content.append({"type": "text", "text": "\n\n".join(all_text_parts)})
    gemma_messages = [{"role": "user", "content": gemma_content}]

    # Use the processor's apply_chat_template to properly format the prompt.
    # This automatically handles image tokens for Gemma 3.
    inputs = processor.apply_chat_template(
        gemma_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)
    # Record the prompt length so only the newly generated tokens are decoded
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False  # Greedy decoding - more stable with quantized models
        )

    # Decode only the new tokens (skip the prompt)
    response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
    return response.strip()
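

# Illustrative sketch (not referenced elsewhere in this module): how a caller
# could package an image file as the base64 "image_url" part that
# _local_inference parses above. The helper name _example_image_part and the
# image/png MIME type are assumptions for illustration only.
def _example_image_part(image_path: str) -> dict:
    """Build an OpenAI-style image_url content part from a local file (example only)."""
    import base64
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{encoded}"},
    }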
# ===== VERTEX AI (Production) =====
_vertex_credentials = None


def _get_vertex_credentials():
    """Get or refresh Vertex AI credentials."""
    global _vertex_credentials

    if _vertex_credentials is None and SERVICE_ACCOUNT_KEY:
        from google.oauth2 import service_account

        service_account_info = json.loads(SERVICE_ACCOUNT_KEY)
        _vertex_credentials = service_account.Credentials.from_service_account_info(
            service_account_info,
            scopes=['https://www.googleapis.com/auth/cloud-platform']
        )

    if _vertex_credentials:
        # Refresh the token if it has no expiry yet or expires within 5 minutes
        import datetime
        import google.auth.transport.requests

        if _vertex_credentials.expiry:
            time_remaining = (
                _vertex_credentials.expiry.replace(tzinfo=datetime.timezone.utc)
                - datetime.datetime.now(datetime.timezone.utc)
            )
            if time_remaining < datetime.timedelta(minutes=5):
                request = google.auth.transport.requests.Request()
                _vertex_credentials.refresh(request)
        else:
            request = google.auth.transport.requests.Request()
            _vertex_credentials.refresh(request)

    return _vertex_credentials
def _vertex_inference(messages: list, max_tokens: int = 2048, temperature: float = 0.1) -> str:
    """Run inference using Vertex AI endpoint."""
    credentials = _get_vertex_credentials()

    headers = {
        "Authorization": f"Bearer {credentials.token}",
        "Content-Type": "application/json",
    }
    payload = {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature
    }

    response = requests.post(VERTEX_ENDPOINT, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
# ===== PUBLIC API =====
def medgemma_get_text_response(
    messages: list,
    temperature: float = 0.1,
    max_tokens: int = 2048,
) -> str:
    """
    Get a text response from MedGemma.

    Automatically selects between the local model and Vertex AI
    based on configuration.

    Args:
        messages: List of message dictionaries in OpenAI format
        temperature: Sampling temperature (the local path uses greedy decoding
            and ignores this value)
        max_tokens: Maximum tokens to generate

    Returns:
        Generated text response
    """
    if USE_LOCAL_MODEL:
        return _local_inference(messages, max_tokens)
    elif VERTEX_ENDPOINT and SERVICE_ACCOUNT_KEY:
        return _vertex_inference(messages, max_tokens, temperature)
    else:
        raise RuntimeError(
            "No MedGemma inference method configured. "
            "Set USE_LOCAL_MODEL=true or provide GCP_MEDGEMMA_ENDPOINT."
        )
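

# ===== USAGE EXAMPLE (illustrative) =====
# Minimal sketch of how a caller might invoke the public API with an
# OpenAI-format message list. It assumes either USE_LOCAL_MODEL=true with a
# model available at LOCAL_MODEL_PATH, or a configured Vertex AI endpoint;
# the prompt text is purely illustrative.
if __name__ == "__main__":
    demo_messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Summarize the key findings in this clinical note."},
            ],
        }
    ]
    print(medgemma_get_text_response(demo_messages, temperature=0.1, max_tokens=256))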