import os
from typing import Optional

from dotenv import load_dotenv
from litellm import completion

load_dotenv()

class LLMClient:
    """LLM client using LiteLLM."""

    def __init__(
        self,
        model: str = "groq/llama-3.3-70b-versatile",
        api_key: Optional[str] = None,
        temperature: float = 0.1,
    ):
        """
        Initialize the LLM client.

        Args:
            model: Model identifier (e.g., "groq/llama-3.3-70b-versatile")
            api_key: API key (if None, uses the GROQ_API_KEY env var)
            temperature: Sampling temperature
        """
        self.model = model
        self.temperature = temperature

        if api_key:
            os.environ["GROQ_API_KEY"] = api_key
        elif "GROQ_API_KEY" not in os.environ:
            # No key was passed in and none is set in the environment.
            raise ValueError(
                "GROQ_API_KEY not found. Please set it as an environment variable "
                "or pass it as the api_key parameter. Get a free key from "
                "https://console.groq.com/"
            )
    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Generate text using the LLM.

        Args:
            prompt: User prompt
            max_tokens: Maximum number of tokens to generate
            system_prompt: Optional system prompt

        Returns:
            The generated text
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            response = completion(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content
        except Exception as e:
            # Re-raise with context; chaining via "from e" keeps the
            # original traceback instead of swallowing it.
            raise RuntimeError(f"Error calling LLM: {e}") from e
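
# A minimal usage sketch, assuming a valid GROQ_API_KEY is available via the
# environment or a .env file; the prompt and system prompt below are
# illustrative only, not part of the original client.
if __name__ == "__main__":
    client = LLMClient(temperature=0.2)
    answer = client.generate(
        "Explain what a context window is in one sentence.",
        max_tokens=128,
        system_prompt="You are a concise technical assistant.",
    )
    print(answer)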