import importlib
import io
import re
import urllib.request
from itertools import groupby
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import joblib
import numpy as np
import requests
import torch
import torchaudio
from langcodes import Language, LanguageTagError
from transformers import AutoFeatureExtractor, AutoModel, GenerationConfig, Pipeline
from transformers.utils import logging

# Packages this pipeline needs at runtime that are not imported directly above
# (presumably used by the LLM tokenizer and when unpickling the k-means model).
# They are checked eagerly so users get a single, clear error message.
_EXTRA_DEPENDENCIES = [
    "sentencepiece",
    "google.protobuf",
    "sklearn",
]

missing_packages = []
for dependency in _EXTRA_DEPENDENCIES:
    try:
        importlib.import_module(dependency)
    except ImportError:
        missing_packages.append(dependency)

if missing_packages:
    raise ImportError(
        "This pipeline requires the following packages that were not found in your environment: "
        f"{', '.join(missing_packages)}."
    )

# NOTE: this changes the verbosity of the whole `transformers` library, not only this module.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")


class ChatHistory:
    """
    A list of chat messages, each a ``{"role": ..., "content": ...}`` dict.

    Supports ``len()``, indexing and item assignment, and can be built from a
    ChatML-formatted string with :meth:`parse_chat`.
    """

    # One "<|im_start|>role\ncontent" block. The content runs until the next
    # "<|im_start|>" or the end of the text, so a missing "<|im_end|>" (e.g. a
    # generation cut off by the length limit) is tolerated.
    _CHATML_BLOCK_RE = re.compile(
        r"<\|im_start\|>\s*(\w+)\s*\n(.*?)(?=<\|im_start\|>|\Z)", re.DOTALL
    )
    _IM_END = "<|im_end|>"

    def __init__(self, messages: List[Dict[str, str]]):
        # The list is stored as given (not copied).
        self.messages = messages

    def __len__(self) -> int:
        return len(self.messages)

    def __getitem__(self, index: int) -> Dict[str, str]:
        return self.messages[index]

    def __setitem__(self, index: int, value: Dict[str, str]) -> None:
        self.messages[index] = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.messages!r})"

    @classmethod
    def parse_chat(cls, formatted_chat: str) -> "ChatHistory":
        """
        Parse a ChatML-formatted string into a :class:`ChatHistory`.

        Text outside ``<|im_start|>`` blocks (e.g. a leading BOS token) is ignored.
        Inside a block, ``<|im_end|>`` and everything after it (for instance a
        trailing newline or an EOS/padding token) is dropped, and the remaining
        content is stripped of surrounding whitespace.

        Args:
            formatted_chat (str): ChatML text, e.g. a decoded prompt plus generation.

        Returns:
            ChatHistory: One message per ``<|im_start|>`` block, in order.
        """
        messages = []
        for role, content in cls._CHATML_BLOCK_RE.findall(formatted_chat):
            content = content.split(cls._IM_END, 1)[0]
            messages.append({"role": role.strip(), "content": content.strip()})
        return cls(messages)

    def get_messages(self, role: Optional[str] = None) -> List[str]:
        """
        Retrieve message contents, optionally filtered by role.

        Args:
            role (Optional[str]): If specified, only messages with this role are returned.

        Returns:
            List[str]: List of message contents.
        """
        return [m["content"] for m in self.messages if role is None or m["role"] == role]

    def get_assistant_messages(self) -> List[str]:
        """
        Retrieve all assistant message contents.

        Returns:
            List[str]: List of assistant message contents.
        """
        return self.get_messages("assistant")

    def get_user_messages(self) -> List[str]:
        """
        Retrieve all user message contents.

        Returns:
            List[str]: List of user message contents.
        """
        return self.get_messages("user")


class MultimodalMtPipeline(Pipeline):
    """
    A HuggingFace pipeline that integrates speech into an LLM, by discretizing the speech,
    and enables multimodal translation.

    This pipeline:
    1. Encodes speech inputs by extracting features from a specific layer of a HuBERT model.
    2. Applies k-means clustering to discretize the features and obtain discrete units.
    3. Integrates the discrete units into the LLM, by passing them as input to the model.

    Supported modes (``mode=`` call argument): "asr", "t2tt", "s2tt" and "lid".
    """

    _pipeline_calls_generate = True
    _modes = ["asr", "t2tt", "s2tt", "lid"]

    # Call-time keyword arguments consumed by `preprocess` / `postprocess`; every
    # other keyword argument is treated as a generation parameter.
    _PIPELINE_KWARGS = ("mode", "src_lang", "tgt_lang", "return_chat_history")
    # Timeout, in seconds, for downloading audio given as an HTTP(S) URL.
    _AUDIO_DOWNLOAD_TIMEOUT = 60
    # Discrete units are encoded as characters of Supplementary Private Use Area-A.
    _SPUA_A_START = 0xF0000
    _SPUA_A_END = 0xFFFFD

    def __init__(
        self,
        speech_encoder_name_or_path: str = "slprl/mhubert-base-25hz",
        speech_encoder_revision: str = "a319086e1d343190047d02b7f81133fb310c1b90",
        speech_encoder_layer: int = 11,
        kmeans_url_or_path: str = "https://dl.fbaipublicfiles.com/textless_nlp/twist/speech_tokenizer/mhubert_base_25hz_cp_mls_cv_sp_fisher_L11_km500.bin",
        deduplicate_dsu: bool = True,
        **kwargs
    ):
        """
        Args:
            speech_encoder_name_or_path: Hub id or local path of the speech encoder.
            speech_encoder_revision: Revision used for both the speech encoder and its
                feature extractor, so the two always match.
            speech_encoder_layer: Index into the encoder's ``hidden_states`` whose
                features are clustered.
            kmeans_url_or_path: URL or local path of the joblib-serialized k-means model.
            deduplicate_dsu: Whether to collapse consecutive repeated discrete units.
            **kwargs: Forwarded to :class:`transformers.Pipeline`.
        """
        super().__init__(**kwargs)

        logger.info(f"Loading speech encoder from: {speech_encoder_name_or_path}")
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            speech_encoder_name_or_path, revision=speech_encoder_revision
        )
        self.speech_encoder = (
            AutoModel.from_pretrained(speech_encoder_name_or_path, revision=speech_encoder_revision)
            .eval()
            .to(self.device)
        )
        self.speech_encoder_layer = speech_encoder_layer

        logger.info(f"Loading k-means model from: {kmeans_url_or_path}")
        self.kmeans = self._load_kmeans(kmeans_url_or_path)
        self.deduplicate_dsu = deduplicate_dsu

    @staticmethod
    def _sanitize_lang(lang: str) -> str:
        """
        Normalize a language tag or name to the English language name used in prompts.

        ``lang`` is first parsed as a language tag; if that fails it is looked up as a
        language name. Lookup failures from ``Language.find`` propagate unchanged
        (presumably ``LookupError`` — see langcodes).

        Raises:
            ValueError: If the resolved language is not valid.
        """
        try:
            language = Language.get(lang)
        except LanguageTagError:
            # Not a parseable tag: try to resolve it as a language name instead.
            language = Language.find(lang)
        language = language.prefer_macrolanguage().simplify_script()
        if not language.is_valid():
            raise ValueError(f"Invalid language: {language}")
        if language.language_name() != language.display_name():
            logger.warning(
                f"Using '{language.language_name()}' instead of '{language.display_name()}' in the generated TSVs."
            )
        return language.language_name()

    def _sanitize_parameters(self, **kwargs):
        """
        Split keyword arguments into preprocess / forward / postprocess parameters.

        - ``mode``, ``src_lang`` and ``tgt_lang`` go to :meth:`preprocess`. Languages are
          normalized with :meth:`_sanitize_lang`, and an explicit ``None`` is passed through.
        - ``return_chat_history`` goes to :meth:`postprocess`.
        - Every other keyword builds the ``GenerationConfig`` passed to :meth:`_forward`.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}

        if "mode" in kwargs:
            preprocess_kwargs["mode"] = kwargs["mode"]
        for lang_key in ("src_lang", "tgt_lang"):
            if lang_key in kwargs:
                value = kwargs[lang_key]
                preprocess_kwargs[lang_key] = None if value is None else self._sanitize_lang(value)
        if "return_chat_history" in kwargs:
            postprocess_kwargs["return_chat_history"] = kwargs["return_chat_history"]

        # Keep pipeline-only options out of the generation config.
        generate_kwargs = {k: v for k, v in kwargs.items() if k not in self._PIPELINE_KWARGS}
        forward_kwargs = {"generation_cfg": GenerationConfig(**generate_kwargs)}

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def _load_kmeans(self, kmeans_path: Union[str, Path]):
        """
        Load the k-means model from a local file or an HTTP(S) URL.

        URLs are downloaded once into ``~/.cache/huggingface/kmeans``. The cache file
        name is prefixed with an MD5 hash of the URL to avoid collisions.

        SECURITY: ``joblib.load`` unpickles the file, which can execute arbitrary code.
        Only load k-means models from trusted sources.

        Raises:
            FileNotFoundError: If a local path does not exist.
            RuntimeError: If the download or the deserialization fails.
        """
        from urllib.error import URLError
        import hashlib

        def _download_and_cache(url: str, cache_dir: str) -> Path:
            cache_dir = Path(cache_dir).expanduser()
            cache_dir.mkdir(parents=True, exist_ok=True)
            # The hash is only a cache key, not a security measure.
            url_hash = hashlib.md5(url.encode("utf-8")).hexdigest()
            cached_path = cache_dir / f"{url_hash}_{Path(url).name}"
            if not cached_path.exists():
                # Download under a temporary name and rename on success, so an
                # interrupted download never leaves a truncated file in the cache.
                tmp_path = cached_path.with_name(cached_path.name + ".part")
                try:
                    urllib.request.urlretrieve(url, tmp_path)
                    tmp_path.replace(cached_path)
                except URLError as e:
                    raise RuntimeError(f"Failed to download k-means model from {url}: {e}") from e
                finally:
                    tmp_path.unlink(missing_ok=True)
            return cached_path

        if isinstance(kmeans_path, str) and kmeans_path.startswith(("http://", "https://")):
            kmeans_path = _download_and_cache(kmeans_path, "~/.cache/huggingface/kmeans")
        else:
            kmeans_path = Path(kmeans_path).expanduser()
            if not kmeans_path.exists():
                raise FileNotFoundError(f"K-means model not found at: {kmeans_path}")

        try:
            return joblib.load(kmeans_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load k-means model from {kmeans_path}: {e}") from e

    def _maybe_load_and_preprocess_audio(
        self, audio_path_or_array: Union[str, Tuple[np.ndarray, int], Tuple[torch.Tensor, int]]
    ) -> np.ndarray:
        """
        Load an audio file or array and return it as mono audio at the model's sampling rate.

        Supports:
            - File path or HTTP(S) URL (str), decoded with ``torchaudio.load``.
            - Tuple of (np.ndarray or torch.Tensor, sampling_rate). 2-D audio is treated
              as (channels, samples). The sampling rate may be a Python or NumPy integer.

        Multi-channel audio is averaged to mono, and audio at a different rate is resampled.

        Returns:
            np.ndarray: 1D mono audio at ``self.feature_extractor.sampling_rate``.

        Raises:
            FileNotFoundError: If a local file does not exist.
            RuntimeError: If loading a file or URL fails.
            TypeError: For unsupported input, audio or sampling-rate types.
            ValueError: If the audio is not 1-D or 2-D.
        """
        target_sr = self.feature_extractor.sampling_rate

        if isinstance(audio_path_or_array, str):
            if audio_path_or_array.startswith(("http://", "https://")):
                try:
                    with requests.get(audio_path_or_array, timeout=self._AUDIO_DOWNLOAD_TIMEOUT) as response:
                        response.raise_for_status()
                        audio, sr = torchaudio.load(io.BytesIO(response.content))
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to load audio file from URL '{audio_path_or_array}': {e}"
                    ) from e
            else:
                try:
                    audio, sr = torchaudio.load(audio_path_or_array)
                except FileNotFoundError as e:
                    raise FileNotFoundError(f"Audio file not found: {audio_path_or_array}") from e
                except Exception as e:
                    raise RuntimeError(f"Failed to load audio file '{audio_path_or_array}': {e}") from e
        elif isinstance(audio_path_or_array, tuple):
            audio, sr = audio_path_or_array
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio)
            elif not isinstance(audio, torch.Tensor):
                raise TypeError("Audio must be a numpy array or torch tensor")
            # NumPy integers (e.g. np.int64 sampling rates from audio libraries) are accepted too.
            if not isinstance(sr, (int, np.integer)):
                raise TypeError("Sampling rate must be an integer")
            sr = int(sr)
        else:
            raise TypeError("Input must be a file path or a tuple of (audio, sampling_rate)")

        if audio.ndim == 1:
            audio = audio.unsqueeze(0)
        elif audio.ndim != 2:
            raise ValueError("Audio input must be 1D or 2D (channels, samples)")

        if audio.shape[0] > 1:
            logger.debug(f"Input audio has {audio.shape[0]} channels; converting to mono by averaging channels.")
            audio = audio.mean(dim=0, keepdim=True)

        if sr != target_sr:
            audio = torchaudio.functional.resample(audio, sr, target_sr)

        # squeeze(0) rather than squeeze(): a 1-sample clip must stay 1-D.
        return audio.squeeze(0).detach().cpu().numpy()

    def _remove_dsu_padding(
        self,
        dsu_sequences: List[np.ndarray],
        batch_input_lens: List[int]
    ) -> List[np.ndarray]:
        """
        Remove padding from the discrete units sequences.

        Only the valid (unpadded) portion of each sequence in the batch is returned.
        The padding side is read from ``self.feature_extractor.padding_side``.

        Args:
            dsu_sequences (List[np.ndarray]): List of discrete unit sequences (possibly padded).
            batch_input_lens (List[int]): List of original input lengths (in samples) for each sequence.

        Returns:
            List[np.ndarray]: List of discrete unit sequences with padding removed.
        """
        # Number of encoder frames produced for each unpadded input length.
        discrete_seq_lens = self.speech_encoder._get_feat_extract_output_lengths(
            torch.as_tensor(batch_input_lens)
        ).tolist()
        pad_left = self.feature_extractor.padding_side == "left"

        trimmed = []
        for dsu_seq, seq_len in zip(dsu_sequences, discrete_seq_lens):
            if seq_len <= 0:
                trimmed.append(np.array([], dtype=dsu_seq.dtype))
            elif pad_left:
                trimmed.append(dsu_seq[-seq_len:])
            else:
                trimmed.append(dsu_seq[:seq_len])
        return trimmed

    def _deduplicate_sequence(self, sequence: Union[str, np.ndarray]) -> Union[str, np.ndarray]:
        """
        Collapse runs of consecutive identical units into a single unit.

        A string is treated as space-separated units and a 1-D array element-wise.
        Arrays keep their dtype, including when they are empty.

        Raises:
            ValueError: If ``sequence`` is not a string or a 1-D array.
        """
        if isinstance(sequence, str):
            return " ".join(key for key, _ in groupby(sequence.strip().split(" ")))
        if isinstance(sequence, np.ndarray):
            if sequence.ndim != 1:
                raise ValueError("Sequence must be a 1D array")
            return np.array([key for key, _ in groupby(sequence)], dtype=sequence.dtype)
        raise ValueError(f"Invalid discrete units sequence type: {type(sequence)}")

    def _dsu_array_to_str(self, dsu_array: np.ndarray) -> str:
        """
        Encode discrete units as space-separated Private Use Area characters.

        Unit ``n`` maps to ``chr(0xF0000 + n)``. This is the inverse of :meth:`_dsu_str_to_array`.

        Raises:
            ValueError: If a unit does not fit in the SPUA-A range.
        """
        max_unit = self._SPUA_A_END - self._SPUA_A_START

        def int_to_pua(n: int) -> str:
            if not 0 <= n <= max_unit:
                raise ValueError(f"number out of SPUA-A range: {n}")
            return chr(self._SPUA_A_START + n)

        return " ".join(int_to_pua(unit) for unit in dsu_array)

    def _dsu_str_to_array(self, dsu_str: str) -> np.ndarray:
        """
        Decode a string of Private Use Area characters back into discrete units.

        Whitespace is ignored, so the space-separated output of
        :meth:`_dsu_array_to_str` round-trips.

        Raises:
            ValueError: If a non-whitespace character is outside the SPUA-A range.
        """
        def pua_to_int(pua: str) -> int:
            if not self._SPUA_A_START <= ord(pua) <= self._SPUA_A_END:
                raise ValueError(f"PUA not in SPUA-A range: {pua}")
            return ord(pua) - self._SPUA_A_START

        return np.array([pua_to_int(ch) for ch in dsu_str if not ch.isspace()], dtype=int)

    def discretize_speech(
        self,
        inputs: Union[str, Tuple[np.ndarray, int], List[Union[str, Tuple[np.ndarray, int]]]],
        deduplicate: bool = True,
        as_str: bool = False
    ) -> Union[List[np.ndarray], List[str]]:
        """
        Discretize one or more speech inputs into discrete units.

        The audio inputs are batched (with padding) through the speech encoder. Features
        from layer ``self.speech_encoder_layer`` are assigned to k-means clusters, and
        the padded frames are removed.

        Args:
            inputs: Audio input(s). Can be a file path or URL (str), a tuple of
                (audio array, sampling_rate), or a list of any of these types.
            deduplicate: If True, consecutive duplicate discrete units will be removed.
            as_str: If True, returns the discrete units as strings; otherwise, returns as numpy arrays.

        Returns:
            List of discrete speech unit arrays or strings (one per input), depending on `as_str`.

        Raises:
            TypeError: If ``inputs`` is not a string, tuple, array or list.
        """
        # A single input (path, (audio, sr) tuple, or bare array) is wrapped into a batch of one.
        if isinstance(inputs, (str, tuple, np.ndarray)):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            raise TypeError("Input must be a string, a (audio, sampling_rate) tuple, or a list of these")
        if not inputs:
            return []

        waveforms = [self._maybe_load_and_preprocess_audio(audio_in) for audio_in in inputs]
        batch_input_lens = [audio.shape[0] for audio in waveforms]
        batch_inputs = self.feature_extractor(
            waveforms,
            sampling_rate=self.feature_extractor.sampling_rate,
            return_tensors="pt",
            padding=True
        )
        batch_inputs = {k: v.to(self.device) for k, v in batch_inputs.items()}

        with torch.no_grad():
            speech_encoder_outs = self.speech_encoder(**batch_inputs, output_hidden_states=True)
            speech_encoder_hs = speech_encoder_outs.hidden_states[self.speech_encoder_layer]
        speech_encoder_hs = speech_encoder_hs.cpu().float().numpy()

        # Cluster every frame of every batch item in a single k-means call.
        B, T, C = speech_encoder_hs.shape
        dsu_sequences = self.kmeans.predict(speech_encoder_hs.reshape(B * T, C)).reshape(B, T)
        dsu_sequences = self._remove_dsu_padding(dsu_sequences, batch_input_lens)

        if deduplicate:
            dsu_sequences = [self._deduplicate_sequence(dsu_seq) for dsu_seq in dsu_sequences]
        if as_str:
            dsu_sequences = [self._dsu_array_to_str(dsu_seq) for dsu_seq in dsu_sequences]

        return dsu_sequences

    def preprocess(self,
                   input_: Union[str, np.ndarray, ChatHistory],
                   mode: Optional[str] = None,
                   src_lang: Optional[str] = None,
                   tgt_lang: Optional[str] = None,
                   ) -> Dict[str, torch.Tensor]:
        """
        Build and tokenize the chat prompt for one pipeline step.

        Args:
            input_: Audio/text input. Can be a file path or URL (str), a tuple of
                (numpy array, sampling_rate), or a sentence (str). It can also be a
                ChatHistory object for multi-turn conversations; its messages are then
                used as the history, and the new instruction refers to them.
            mode: One of "asr", "t2tt", "s2tt" or "lid". Speech modes ("asr", "s2tt")
                discretize ``input_`` unless it is a ChatHistory.
            src_lang: Source language. If None, it is omitted from the instruction.
            tgt_lang: Target language. Required in translation modes ("t2tt", "s2tt").

        Returns:
            Dict containing the tokenized prompt as needed by the model.

        Raises:
            ValueError: On an unknown mode, or a translation mode without ``tgt_lang``.
        """
        if mode not in self._modes:
            raise ValueError(f"Invalid mode: {mode!r}. Expected one of {self._modes}.")

        is_speech_mode = mode in ("asr", "s2tt")
        is_translation_mode = mode in ("t2tt", "s2tt")
        is_multiturn_step = isinstance(input_, ChatHistory)

        if is_translation_mode and not tgt_lang:
            raise ValueError("You must specify 'tgt_lang' when using a translation mode (t2tt or s2tt).")

        if is_speech_mode and not is_multiturn_step:
            # Leading space avoids tokenization issues at the unit-string boundary.
            input_ = " " + self.discretize_speech(input_, deduplicate=self.deduplicate_dsu, as_str=True)[0]

        input_prefix = "" if is_multiturn_step else "{inp}\n"
        if mode == "asr":
            template = input_prefix + "Transcribe" + ("" if src_lang is None else " in {src_lang}")
        elif is_translation_mode:
            template = input_prefix + "Translate" + (
                " to {tgt_lang}" if src_lang is None else " from {src_lang} to {tgt_lang}"
            )
        else:
            # NOTE(review): "lid" never includes `input_` in the new message; it appears
            # intended as a follow-up turn on a ChatHistory — confirm for single-turn use.
            template = "Identify the source language"

        # Only the new message is formatted, so braces in the chat history or in the
        # chat template can never break (or be altered by) str.format.
        new_message = template.format(inp=input_, src_lang=src_lang, tgt_lang=tgt_lang)

        history = list(input_.messages) if is_multiturn_step else []
        messages = history + [{"role": "user", "content": new_message}]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        return self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            add_special_tokens=False
        )

    def _forward(
        self,
        model_inputs: Dict[str, torch.Tensor],
        generation_cfg: Optional[GenerationConfig] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Run ``self.model.generate`` on the tokenized prompt.

        Args:
            model_inputs: Preprocessed inputs.
            generation_cfg: Generation configuration. ``None`` uses a fresh default
                ``GenerationConfig()`` for each call, rather than one shared instance.

        Returns:
            Dict with key ``"generated_token_ids"`` holding the output of ``generate``.
        """
        if generation_cfg is None:
            generation_cfg = GenerationConfig()
        model_outputs = self.model.generate(
            **model_inputs,
            pad_token_id=self.tokenizer.eos_token_id,
            generation_config=generation_cfg,
        )
        return {"generated_token_ids": model_outputs}

    def postprocess(
        self,
        model_outputs: Dict[str, torch.Tensor],
        return_chat_history: Optional[bool] = False
    ) -> Union[str, ChatHistory]:
        """
        Decode the first generated sequence and extract the model's reply.

        The sequence is decoded with special tokens kept, so the ChatML markers survive,
        and it is then parsed with :meth:`ChatHistory.parse_chat`.

        Args:
            model_outputs: Output of :meth:`_forward`.
            return_chat_history: If True, return the whole parsed conversation (useful
                for chaining multi-turn steps) instead of only the last message.

        Returns:
            The content of the last parsed message, or the full ChatHistory if
            ``return_chat_history`` is True.

        Raises:
            ValueError: If no ChatML message could be parsed and only the reply was requested.
        """
        outputs_detok = self.tokenizer.decode(model_outputs["generated_token_ids"][0])
        output_chat_history = ChatHistory.parse_chat(outputs_detok)

        if return_chat_history:
            return output_chat_history
        if len(output_chat_history) == 0:
            raise ValueError("Could not parse any ChatML message from the generated text.")
        return output_chat_history[-1]["content"]