""" small engine for krave Usage: from engine import KraveEngine engine = KraveEngine( ckpt_path="/path/to/Krave-2.5", config="inference/configs/config_671B.json" ) response = engine.generate("Explain quantum computing.") print(response) engine.chat() """ import json import os import sys from typing import List, Optional INFERENCE_DIR = os.path.join(os.path.dirname(__file__), "inference") if INFERENCE_DIR not in sys.path: sys.path.insert(0, INFERENCE_DIR) from generate import generate as generate_tokens from model import ModelArgs, Transformer from safetensors.torch import load_model from transformers import AutoTokenizer class KraveEngine: def __init__( self, ckpt_path: str, config: str = "inference/configs/config_671B.json", max_new_tokens: int = 200, temperature: float = 0.7, device: str = "cuda", weights_file: str | None = None, ): self.ckpt_path = ckpt_path self.config_path = config self.max_new_tokens = max_new_tokens self.temperature = temperature self.device = device self.weights_file = weights_file self._model = None self._tokenizer = None self._args = None self._loaded = False def _weight_file(self) -> str: if self.weights_file: return self.weights_file world_size = int(os.getenv("WORLD_SIZE", "1")) rank = int(os.getenv("RANK", "0")) return os.path.join(self.ckpt_path, f"model{rank}-mp{world_size}.safetensors") def _load(self): if self._loaded: return try: import torch except ImportError as exc: raise RuntimeError("PyTorch is not installed. Krave Engine requires torch==2.4.1.") from exc with open(self.config_path) as f: self._args = ModelArgs(**json.load(f)) torch.set_default_dtype(torch.bfloat16) torch.manual_seed(42) with torch.device(self.device): self._model = Transformer(self._args) self._tokenizer = AutoTokenizer.from_pretrained(self.ckpt_path) weight_file = self._weight_file() if not os.path.exists(weight_file): raise FileNotFoundError( f"Missing weight file: {weight_file}. Put your custom weights in the checkpoint directory." 
    def _weight_file(self) -> str:
        if self.weights_file:
            return self.weights_file
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        rank = int(os.getenv("RANK", "0"))
        return os.path.join(self.ckpt_path, f"model{rank}-mp{world_size}.safetensors")

    def _load(self):
        if self._loaded:
            return
        try:
            import torch
        except ImportError as exc:
            raise RuntimeError(
                "PyTorch is not installed. Krave Engine requires torch==2.4.1."
            ) from exc

        with open(self.config_path) as f:
            self._args = ModelArgs(**json.load(f))

        torch.set_default_dtype(torch.bfloat16)
        torch.manual_seed(42)
        with torch.device(self.device):
            self._model = Transformer(self._args)

        self._tokenizer = AutoTokenizer.from_pretrained(self.ckpt_path)

        weight_file = self._weight_file()
        if not os.path.exists(weight_file):
            raise FileNotFoundError(
                f"Missing weight file: {weight_file}. "
                "Put your custom weights in the checkpoint directory."
            )
        load_model(self._model, weight_file)
        self._loaded = True

    def generate(self, prompt: str) -> str:
        self._load()
        tokens = self._tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
        )
        output_tokens = generate_tokens(
            self._model,
            [tokens],
            self.max_new_tokens,
            self._tokenizer.eos_token_id,
            self.temperature,
        )
        return self._tokenizer.decode(output_tokens[0], skip_special_tokens=True)

    def generate_batch(self, prompts: List[str]) -> List[str]:
        self._load()
        assert len(prompts) <= self._args.max_batch_size, (
            f"Batch of {len(prompts)} exceeds max_batch_size={self._args.max_batch_size}"
        )
        all_tokens = [
            self._tokenizer.apply_chat_template(
                [{"role": "user", "content": p}],
                add_generation_prompt=True,
            )
            for p in prompts
        ]
        output_tokens = generate_tokens(
            self._model,
            all_tokens,
            self.max_new_tokens,
            self._tokenizer.eos_token_id,
            self.temperature,
        )
        return self._tokenizer.batch_decode(output_tokens, skip_special_tokens=True)

    def chat(self, first_message: Optional[str] = None):
        """Interactive REPL; /exit quits, /clear resets the conversation history."""
        self._load()
        messages = []
        if first_message:
            messages.append({"role": "user", "content": first_message})
            # Record the assistant's reply so the first exchange stays in history.
            reply = self._respond(messages)
            messages.append({"role": "assistant", "content": reply})
        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if user_input == "/exit":
                break
            if user_input == "/clear":
                messages.clear()
                continue
            if not user_input:
                continue
            messages.append({"role": "user", "content": user_input})
            reply = self._respond(messages)
            messages.append({"role": "assistant", "content": reply})

    def _respond(self, messages: List[dict]) -> str:
        tokens = self._tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        output_tokens = generate_tokens(
            self._model,
            [tokens],
            self.max_new_tokens,
            self._tokenizer.eos_token_id,
            self.temperature,
        )
        reply = self._tokenizer.decode(output_tokens[0], skip_special_tokens=True)
        print(reply)
        return reply


if __name__ == "__main__":
    print("Krave Prepared")
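    # Minimal, commented-out smoke test; the checkpoint path below is a
    # placeholder, not something shipped with this module:
    # engine = KraveEngine(ckpt_path="/path/to/Krave-2.5")
    # print(engine.generate("Say hello in one sentence."))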