import json
import os

import numpy as np
import torch
from omegaconf import OmegaConf

from kani_tts import KaniTTS
from kani_tts import SpeakerEmbedder


def load_config(config_path: str):
    """Load configuration from a YAML file using OmegaConf.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        Any: The loaded OmegaConf DictConfig.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
    """
    resolved_path = os.path.abspath(config_path)
    if not os.path.exists(resolved_path):
        raise FileNotFoundError(f"Config file not found: {resolved_path}")
    config = OmegaConf.load(resolved_path)
    return config


class InitModels:
    """
    Lazy initializer that constructs a map of model name -> KaniTTS.

    Parameters
    ----------
    models_configs : OmegaConf | DictConfig
        The `models` section from `model_config.yaml` describing one or more
        HF model checkpoints and their options (device_map, use_bematts, etc.).

    Returns
    -------
    dict
        When called, returns a dictionary `{model_name: KaniTTS}`.

    Notes
    -----
    - All models are loaded immediately in `__call__` so the UI can list them
      and switch between them without extra latency.
    - Each KaniTTS instance is initialized with its config directly.
    """

    def __init__(self, models_configs: OmegaConf):
        self.models_configs = models_configs

    def __call__(self):
        models = {}
        for model_name, config in self.models_configs.items():
            print(f"Loading {model_name}...")
            # Convert OmegaConf to dict to access parameters
            cfg_dict = dict(config)
            models[model_name] = KaniTTS(
                model_name=cfg_dict.get('model_name'),
                device_map=cfg_dict.get('device_map'),
            )
            print(f"{model_name} loaded!")
        print("All models loaded!")
        return models


class SpeakerManager:
    """
    Manages speaker embeddings for the TTS application.

    Supports three modes:
    1. Select speaker: Load pre-saved speaker embeddings from speaker_map.json
    2. Generate embedding: Generate speaker embedding from uploaded audio
       using SpeakerEmbedder
    3. JSON embedding: Parse speaker embedding from JSON string
       (list of 128 floats)

    Parameters
    ----------
    speaker_map_path : str
        Path to speaker_map.json file

    Methods
    -------
    get_speaker_emb(mode, speaker_name=None, json_emb=None) -> str | torch.Tensor | None
        Returns speaker embedding based on mode:
        - "select": Returns path to .pt file from speaker_map
        - "generate": Returns cached generated embedding tensor or None
        - "json": Returns embedding tensor parsed from JSON string or None

    generate_embedding(audio_data) -> torch.Tensor
        Generates speaker embedding from audio using SpeakerEmbedder.
        Accepts a Gradio `(sample_rate, audio_array)` tuple or a bare array
        assumed to be at 16kHz. Caches the result internally.

    parse_json_embedding(json_str, speaker_emb_dim=128) -> torch.Tensor | None
        Parses speaker embedding from JSON string.
        Returns [1, dim] tensor or None if parsing fails.

    clean()
        Clears cached generated embedding.

    get_speaker_names() -> list[str]
        Returns list of available speaker names from speaker_map.json.
    """

    def __init__(self, speaker_map_path: str = "./speakers/speaker_map.json"):
        self.speaker_map_path = speaker_map_path
        self.speaker_map = self._load_speaker_map()
        # Most recent embedding produced by generate_embedding(); None until then.
        self.cached_embedding = None
        # SpeakerEmbedder is constructed lazily on first generate_embedding() call.
        self.embedder = None

    def _load_speaker_map(self):
        """Load speaker map from JSON file; empty dict if the file is missing."""
        if not os.path.exists(self.speaker_map_path):
            return {}
        with open(self.speaker_map_path, "r") as f:
            return json.load(f)

    def get_speaker_names(self):
        """Get list of available speaker names."""
        return list(self.speaker_map.keys())

    def get_speaker_emb(self, mode: str, speaker_name: str = None, json_emb: str = None):
        """
        Get speaker embedding based on mode.

        Parameters
        ----------
        mode : str
            Either "select", "generate", or "json"
        speaker_name : str, optional
            Name of speaker from speaker_map (only used in "select" mode)
        json_emb : str, optional
            JSON string containing embedding list (only used in "json" mode)

        Returns
        -------
        str | torch.Tensor | None
            Path to .pt file (select mode) or embedding tensor (generate/json mode)
        """
        if mode == "select":
            if speaker_name and speaker_name in self.speaker_map:
                return self.speaker_map[speaker_name]
            return None
        elif mode == "generate":
            return self.cached_embedding
        elif mode == "json":
            return self.parse_json_embedding(json_emb)
        return None

    def generate_embedding(self, audio_data):
        """
        Generate speaker embedding from audio data.

        Parameters
        ----------
        audio_data : tuple
            Tuple of (sample_rate, audio_array) from Gradio Audio component

        Returns
        -------
        torch.Tensor
            Generated speaker embedding [1, 128]
        """
        # Initialize embedder lazily
        if self.embedder is None:
            self.embedder = SpeakerEmbedder()

        # Handle Gradio audio format (sr, audio) tuple
        if isinstance(audio_data, tuple):
            sample_rate, audio_array = audio_data
        else:
            # Fallback: assume it's just audio array at 16kHz
            audio_array = audio_data
            sample_rate = 16000

        # gradio.Audio() returns int16 PCM; normalize to float32 in [-1, 1].
        # Guard on dtype so already-float audio is not attenuated by 32768x.
        if np.issubdtype(audio_array.dtype, np.integer):
            audio_array = audio_array.astype(np.float32) / 32768.0
        else:
            audio_array = audio_array.astype(np.float32)

        # Make mono if stereo (Gradio-specific): Gradio returns a waveform with
        # shape (num_samples, num_channels), which is not the typical torch layout.
        if audio_array.ndim == 2:
            print('Make MONO from STEREO')
            audio_array = audio_array.mean(axis=1)

        # Generate embedding (SpeakerEmbedder will handle resampling if needed)
        embedding = self.embedder.embed_audio(audio_array, sample_rate=sample_rate)

        # Cache the result
        self.cached_embedding = embedding
        return embedding

    def parse_json_embedding(self, json_str: str, speaker_emb_dim: int = 128):
        """
        Parse speaker embedding from JSON string.

        Parameters
        ----------
        json_str : str
            JSON string containing list of floats [0.123, -0.456, ...]
        speaker_emb_dim : int
            Expected embedding dimension (default: 128)

        Returns
        -------
        torch.Tensor | None
            Speaker embedding tensor [1, dim] or None if parsing fails
        """
        if not json_str or not json_str.strip():
            print("No JSON embedding provided")
            return None
        try:
            # Parse JSON array
            emb_list = json.loads(json_str.strip())

            # Validate it's a list
            if not isinstance(emb_list, list):
                print(f"Error: Speaker embedding must be a JSON array, got {type(emb_list)}")
                return None

            # Validate length
            if len(emb_list) != speaker_emb_dim:
                print(f"Error: Speaker embedding must have {speaker_emb_dim} dimensions, got {len(emb_list)}")
                return None

            # Convert to torch tensor [1, dim]
            speaker_emb = torch.tensor([emb_list], dtype=torch.float32)
            print(f"Using speaker embedding from JSON: shape {speaker_emb.shape}")
            return speaker_emb
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON: {e}")
            return None
        except Exception as e:
            print(f"Error processing speaker embedding: {e}")
            return None

    def clean(self):
        """Clear cached generated embedding."""
        self.cached_embedding = None
        return "Embedding cleared"

    def get_status(self):
        """Get current status of generated embedding."""
        if self.cached_embedding is not None:
            return "✅ Embedding ready"
        return "No embedding generated"


class Examples:
    """
    Adapter that converts YAML examples into Gradio `gr.Examples` rows.

    Parameters
    ----------
    exam_cfg : OmegaConf | DictConfig
        Parsed contents of `examples.yaml`. Expected structure:
        `examples: [ {text, model, speaker?, temperature?, top_p?, repetition_penalty?}, ... ]`.

    Behavior
    --------
    - Produces a list-of-lists whose order must match the `inputs` order used
      when constructing `gr.Examples` in `app.py`.
    - Current order: `[text, model_dropdown, speaker_mode, speaker_dropdown,
      embedding_state, json_input, temp, top_p, rp]`.

    Why this exists
    ---------------
    - Keeps format and defaults centralized, so changing the UI inputs order
      only requires a single change here and in `app.py`.
    """

    def __init__(self, exam_cfg: OmegaConf):
        self.exam_cfg = exam_cfg

    def __call__(self) -> list[list]:
        rows = []
        for e in self.exam_cfg.examples:
            text = e.get("text")
            model = e.get("model")
            speaker_mode = e.get("speaker_mode", "select")  # Default to "select" mode
            speaker = e.get("speaker", "Kore (en)")
            embedding_state = None  # Examples always use select mode, so no embedding state needed
            json_input = e.get("json_input", "")  # Empty JSON input for examples
            temperature = e.get("temperature", 1.0)
            top_p = e.get("top_p", 0.95)
            repetition_penalty = e.get("repetition_penalty", 1.1)
            # Order must match gr.Examples inputs:
            # [text, model_dropdown, speaker_mode, speaker_dropdown, embedding_state, json_input, temp, top_p, rp]
            rows.append([text, model, speaker_mode, speaker, embedding_state,
                         json_input, temperature, top_p, repetition_penalty])
        return rows