| | import dataclasses |
| | from typing import Any, Dict, List, Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import transformers |
| |
|
| | from .ultravox_config import UltravoxConfig |
| |
|
| |
|
@dataclasses.dataclass
class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
    """Seq2seq collator that additionally batches variable-length audio fields.

    Each feature may carry per-chunk lists under ``audio_values``,
    ``audio_lens``, ``audio_token_len`` and ``audio_token_start_idx``; these
    are popped off the features (so the parent collator never sees them),
    flattened across the batch, and re-attached to the collated batch as
    stacked tensors. Audio chunks are right-padded to the longest chunk.
    """

    # When True, the features also carry "alt_*" text fields (an alternate
    # text-only version of each sample) that are collated with a second pass
    # through the parent collator.
    include_alt_fields: bool = False

    def __call__(self, features, *args, **kwargs):
        # Pop audio fields before the parent collator runs; each feature may
        # contribute zero or more audio chunks, so the lists are flattened.
        audio_values = [x for f in features for x in f.pop("audio_values", [])]
        audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
        audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
        audio_token_start_idx = [
            x for f in features for x in f.pop("audio_token_start_idx", [])
        ]

        if self.include_alt_fields:
            # Pull the alt text fields out so they can be padded independently.
            alt_features = [
                {
                    "input_ids": f.pop("alt_input_ids"),
                    "attention_mask": f.pop("alt_attention_mask"),
                    "labels": f.pop("alt_labels"),
                }
                for f in features
            ]

        batch = super().__call__(features, *args, **kwargs)
        if self.include_alt_fields:
            alt_batch = super().__call__(alt_features, *args, **kwargs)
            batch["alt_input_ids"] = alt_batch["input_ids"]
            batch["alt_attention_mask"] = alt_batch["attention_mask"]
            batch["alt_labels"] = alt_batch["labels"]

        # FIX: guard each stack — torch.stack([]) raises
        # "stack expects a non-empty TensorList" on a text-only batch, even
        # though audio_values was already guarded below. When there is no
        # audio, the audio keys are simply omitted from the batch.
        if audio_token_start_idx:
            batch["audio_token_start_idx"] = torch.stack(audio_token_start_idx)
        if audio_lens:
            batch["audio_lens"] = torch.stack(audio_lens)
        if audio_token_len:
            batch["audio_token_len"] = torch.stack(audio_token_len)

        if audio_values:
            # Right-pad every audio chunk to the longest chunk in the batch.
            max_len = max(x.shape[-1] for x in audio_values)
            batch["audio_values"] = torch.stack(
                [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
            )
            if self.tokenizer.padding_side == "left":
                # With left padding, each sample's tokens are shifted right by
                # the amount of padding added; shift the audio start indices
                # to match. The per-sample displacement is repeated once per
                # audio chunk of that sample (audio_batch_size holds the
                # per-sample chunk counts — see _chunk_and_pad_audio).
                input_ids_lens = torch.LongTensor(
                    [f["input_ids"].shape[-1] for f in features]
                )
                displacement = batch["input_ids"].shape[-1] - input_ids_lens
                displacement = displacement.repeat_interleave(
                    batch["audio_batch_size"].squeeze(-1)
                )
                batch["audio_token_start_idx"] += displacement.to(
                    batch["audio_token_start_idx"].device
                )
        return batch
| |
|
| |
|
class UltravoxProcessor(transformers.ProcessorMixin):
    """
    Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.

    Args:
        audio_processor: The audio processor for the audio encoder.
        tokenizer: The tokenizer for the language model.
    """

    # ProcessorMixin wiring: the attributes managed by save/load machinery.
    attributes = ["audio_processor", "tokenizer"]
    audio_processor_class = ("WhisperProcessor",)
    tokenizer_class = (
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
    )

    tokenizer: transformers.PreTrainedTokenizerBase
    audio_processor: transformers.ProcessorMixin

    def __init__(
        self,
        audio_processor=None,
        tokenizer=None,
        audio_padding: str = "longest",
        encoder_ds_factor: int = 2,
        stack_factor: int = 8,
        audio_placeholder: str = "<|audio|>",
        # NOTE(review): 3000 presumably corresponds to a Whisper-style
        # 30-second window of feature frames — confirm against the encoder.
        audio_context_size: Optional[int] = 3000,
    ):
        """
        Args:
            audio_processor: The audio processor for the audio encoder.
            tokenizer: The tokenizer for the language model.
            audio_padding: The padding strategy for the audio encoder.
            stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
            encoder_ds_factor: The downsampling factor of the audio encoder.
            audio_placeholder: The placeholder for the audio in the text.
            audio_context_size: The maximum number of frames that the audio encoder can handle.
        """
        self.audio_padding = audio_padding
        self.encoder_ds_factor = encoder_ds_factor
        self.stack_factor = stack_factor
        self.audio_placeholder = audio_placeholder
        self.audio_context_size = audio_context_size
        assert (
            tokenizer.eos_token is not None
        ), "The tokenizer has no EOS token. Cannot recover."
        self.vocab = tokenizer.get_vocab()
        # The EOS token is used as the stand-in token that occupies audio
        # positions in input_ids (replaced by audio embeddings in the model).
        self.audio_token_replacement = tokenizer.eos_token
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Build the processor from an Ultravox checkpoint's config + tokenizer."""
        config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        # Resolve the audio processor from the config, falling back to
        # whisper-tiny when neither audio_model_id nor the audio config's
        # name is set.
        audio_processor = transformers.AutoProcessor.from_pretrained(
            config.audio_model_id
            or config.audio_config._name_or_path
            or "openai/whisper-tiny"
        )

        tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer.padding_side = "left"
        tokenizer.pad_token = tokenizer.eos_token

        # NOTE(review): only stack_factor is taken from the config here;
        # encoder_ds_factor and audio_context_size keep their constructor
        # defaults — confirm the config carries no overrides for them.
        return cls(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=config.stack_factor,
        )

    def _chunk_and_pad_audio(
        self,
        audio_values: torch.Tensor,
        audio_lens: torch.Tensor,
        include_audio_num_chunks: bool = False,
    ) -> Dict[str, Any]:
        """
        Processes the audio batch by chunking any items in the batch according to the audio_context_size,
        padding the last chunk if needed, and returns a dictionary with updated audio data.

        Args:
            audio_values (torch.Tensor): A tensor of audio values (e.g., in B, D, T format).
            audio_lens (torch.Tensor): A tensor of audio lengths.

        Returns:
            Dict[str, Any]: Dictionary with the following keys:
                - "audio_values": The concatenated audio tensor after chunking and padding.
                - "audio_lens": Tensor of lengths for each chunk.
                - "audio_is_continuation": Tensor of booleans indicating if the chunk is a continuation of the previous chunk.
                - "audio_batch_size": A Tensor with one integer representing the number of chunks.

        """
        chunked_audio_values: List[torch.Tensor] = []
        chunked_audio_lens: List[int] = []
        is_continuation_list: List[bool] = []
        num_chunks: List[int] = []
        # With no configured context size, treat the whole (padded) time axis
        # as one context, i.e. no chunking occurs.
        context_size = self.audio_context_size or audio_values.shape[-1]

        for i in range(audio_values.shape[0]):
            num_chunks.append(int(np.ceil(audio_lens[i] / context_size)))
            # Walk the item's real length in context_size steps; each step
            # beyond the first is a "continuation" chunk.
            for offset in range(0, audio_lens[i], context_size):
                is_continuation = offset > 0
                chunk = audio_values[i, :, offset : offset + context_size]
                if is_continuation and chunk.shape[-1] < context_size:
                    # A trailing continuation chunk can be narrower than
                    # context_size (when the batch-padded time axis is not a
                    # multiple of it); pad it so all chunks stack to a
                    # uniform width.
                    chunk = F.pad(chunk, (0, context_size - chunk.shape[-1]))
                chunked_audio_values.append(chunk)
                # Real (unpadded) length of this chunk.
                chunked_audio_lens.append(
                    min(int(audio_lens[i].item()) - offset, context_size)
                )
                is_continuation_list.append(is_continuation)

        data = {
            "audio_values": torch.stack(chunked_audio_values, dim=0),
            "audio_lens": torch.tensor(
                chunked_audio_lens, dtype=torch.int64, device=audio_values.device
            ),
            "audio_is_continuation": torch.tensor(
                is_continuation_list, dtype=torch.bool, device=audio_values.device
            ),
            "audio_batch_size": torch.tensor(
                [len(chunked_audio_values)], device=audio_values.device
            ),
        }
        if include_audio_num_chunks:
            # Per-item chunk counts (before flattening), only when requested.
            data["audio_num_chunks"] = torch.tensor(
                num_chunks, dtype=torch.int64, device=audio_values.device
            )
        return data

    def __call__(
        self,
        text: Optional[str] = None,
        audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
        audios: Optional[
            Union[
                List[Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
            ]
        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[
            Union[str, transformers.TensorType]
        ] = transformers.TensorType.PYTORCH,
        include_audio_num_chunks: bool = False,
        **kwargs,
    ) -> transformers.BatchFeature:
        """
        Main method to prepare for the model one text sequence and audio. This method forwards the `text`
        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
        audio processor's [`~WhisperProcessor.__call__`] if `audio` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:
            text (`str`, `List[str]`):
                The sequence to be encoded. Sequence can be a string or (pretokenized string).
            audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The audio to be prepared. Audio can be a single-channel (1-dimensional) NumPy array or PyTorch tensor.
            audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                A list or two dimensional array of audio to be prepared.
            sampling_rate (`int`, *optional*, defaults to 16000):
                Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
                you are doing.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **audio_values** -- Processed audio values to be fed to a model. Returned when `audio` is not `None`.
            - **audio_token_len** -- Predicted number of audio frames: this value is guaranteed to be a close upper bound.
              Returned when `audio` is not `None`.
            - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
        """
        # Normalize `audio`/`audios` into a single list of audio clips.
        if audio is not None and audios is not None:
            raise ValueError("Only one of `audio` or `audios` should be provided.")
        elif audio is not None:
            audios = audio if isinstance(audio, list) or audio.ndim == 2 else [audio]
        elif audios is None:
            audios = []

        data = {}
        audio_is_continuation = []
        if len(audios) > 0:
            audios = [x.numpy() if isinstance(x, torch.Tensor) else x for x in audios]

            # Pad out very short clips to at least two hops so the feature
            # extractor produces at least one frame — TODO confirm the
            # two-hop minimum against the extractor's framing.
            hop_length = self.audio_processor.feature_extractor.hop_length
            audios = [
                (
                    np.pad(x, (0, 2 * hop_length - len(x)), mode="constant")
                    if len(x) < 2 * hop_length
                    else x
                )
                for x in audios
            ]

            # Extract features for the whole batch (padded to the longest
            # clip, no truncation); chunking to the encoder context happens
            # afterwards in _chunk_and_pad_audio.
            x: transformers.BatchFeature = self.audio_processor(
                audios,
                sampling_rate=sampling_rate,
                padding="longest",
                pad_to_multiple_of=hop_length,
                truncation=False,
                return_attention_mask=True,
                **kwargs,
            )

            data.update(
                self._chunk_and_pad_audio(
                    audio_values=torch.as_tensor(
                        x.input_features if "input_features" in x else x.input_values
                    ),
                    audio_lens=torch.as_tensor(x.attention_mask).sum(-1),
                    include_audio_num_chunks=include_audio_num_chunks,
                )
            )

            audio_is_continuation = data.pop("audio_is_continuation")
            # Number of LM token slots each chunk occupies: frames divided by
            # the combined encoder downsampling and projector stack factor,
            # rounded up (a close upper bound per the class docstring).
            data["audio_token_len"] = torch.ceil(
                data["audio_lens"] / (self.encoder_ds_factor * self.stack_factor)
            ).to(dtype=torch.int)

        if text is not None:
            if not isinstance(text, str):
                raise ValueError("Text must be a string. Batch mode not supported yet.")

            # Split the text on the audio placeholder and tokenize each part
            # separately. NOTE(review): the literal "<|audio|>" is used here
            # rather than self.audio_placeholder — confirm they never diverge.
            tokenized_parts = self.tokenizer(
                text.split(
                    "<|audio|>"
                ),
                add_special_tokens=False,
                **kwargs,
            )

            audio_token_start_idx = []
            placeholder_index = -1
            split_input_ids = tokenized_parts["input_ids"]
            input_ids: List[int] = []

            audio_token_replacement_token_id = self.vocab[self.audio_token_replacement]

            # Interleave text parts with runs of replacement tokens, one run
            # per audio chunk. A continuation chunk does not consume a new
            # placeholder: its tokens follow the previous chunk's directly.
            for i, token_len in enumerate(data.get("audio_token_len", [])):
                if not audio_is_continuation[i]:
                    placeholder_index += 1
                    if placeholder_index >= len(split_input_ids):
                        raise ValueError(
                            f"Text contains too few audio placeholders. (Expected {len(audios)} placeholders)"
                        )

                    # Text tokens preceding this audio.
                    input_ids.extend(split_input_ids[placeholder_index])

                # Where this chunk's audio tokens start in input_ids.
                audio_token_start_idx.append(len(input_ids))

                # Reserve token_len slots for the audio embeddings.
                input_ids.extend([audio_token_replacement_token_id] * token_len)

            # Append any text following the last placeholder; by now exactly
            # one tokenized part must remain.
            placeholder_index += 1
            if placeholder_index != len(split_input_ids) - 1:
                raise ValueError(
                    f"Text contains too many audio placeholders. (Expected {len(audios)} placeholders)"
                )
            input_ids.extend(split_input_ids[placeholder_index])

            if "audio_token_len" in data:
                data["audio_token_start_idx"] = torch.as_tensor(audio_token_start_idx)

            # Single-sequence batch; padding across samples is left to the
            # collator (see DataCollatorForSeq2SeqWithAudio).
            data["input_ids"] = [input_ids]
            data["attention_mask"] = [[1] * len(input_ids)]

        return transformers.BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's batch_decode."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's decode."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        # Union of tokenizer and audio-processor input names.
        # NOTE(review): set() makes the ordering nondeterministic — confirm
        # no caller depends on a stable order.
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(set(tokenizer_input_names + audio_processor_input_names))
| |
|
| |
|
# Register the processor so AutoProcessor can resolve it from an
# UltravoxConfig (e.g. when loading checkpoints with trust_remote_code).
UltravoxProcessor.register_for_auto_class()

transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
| |
|