#!/usr/bin/env python # -*- coding: utf-8 -*- import json import logging import math import os import tempfile import threading import time import types from copy import deepcopy from dataclasses import dataclass from threading import Thread from typing import Dict from typing import List from typing import Literal from typing import Optional from typing import Tuple from typing import Union import numpy as np import soundfile as sf import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.utils.parametrize as P from PIL import Image from pydantic import BaseModel from torch.nn.utils.parametrizations import weight_norm from tqdm import tqdm from transformers import LlamaConfig from transformers import LlamaModel from transformers import PreTrainedModel from transformers import Qwen3ForCausalLM from transformers import Qwen3PreTrainedModel from transformers import TextIteratorStreamer from transformers.activations import ACT2FN from transformers.cache_utils import Cache from transformers.cache_utils import DynamicCache from transformers.cache_utils import EncoderDecoderCache from transformers.cache_utils import StaticCache from transformers.generation.logits_process import TopKLogitsWarper from transformers.generation.logits_process import TopPLogitsWarper from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import ModelOutput from transformers.models.whisper.configuration_whisper import WhisperConfig from transformers.models.whisper.modeling_whisper import WhisperEncoder from .chunk_prefill_generate import ChunkPrefillChunkGenerate from .configuration_minicpmo import MiniCPMOConfig from .configuration_minicpmo import MiniCPMODuplexConfig from .configuration_minicpmtts import MiniCPMTTSConfig from .modeling_navit_siglip import SiglipVisionTransformer from .processing_minicpmo import MiniCPMOProcessor from .resampler import Resampler from .sliding_utils import as_dynamic_cache from .sliding_utils import drop_tokens_from_cache from .sliding_utils import get_kv_cache_length from .sliding_utils import realign_rotary_suffix from .sliding_utils import StreamingWindowConfig from .tts_streaming_generate import TTSStreamingGenerator from .utils import streaming_token_decoder from .utils import torch_clone_recursive # NOTE: These imports are added to work around a Transformers bug where secondary # relative imports are not copied to the cache when loading from a local path. # See: https://github.com/huggingface/transformers/issues/XXXXX from .audio_utils import process_audio_batch as _ # noqa: F401 from .processing_audio_minicpma import MiniCPMAAudioProcessor as __ # noqa: F401 from .processing_image_minicpmv import MiniCPMOBatchFeature as ___ # noqa: F401 from .processing_streaming_mel import StreamingMelProcessorExact as ____ # noqa: F401 logger = logging.getLogger(__name__) @dataclass class SpeculativeSnapshot: """Speculative snapshot for VAD speculative rollback. Used in VAD speculative execution: creates a snapshot after streaming_prefill and before streaming_generate. If speculation fails (user continues speaking), the state can be restored to continue streaming_prefill. Implementation: - LLM KV Cache: only record length, restore by truncation (zero extra VRAM) - Audio KV Cache: requires cloning, as generate sets it to None - Mel processor: save full state snapshot (including buffer) """ # KV Cache 长度(用于裁剪恢复) llm_cache_length: int audio_cache_length: int # 会话状态 new_user_msg: bool llm_generated: bool llm_generate_completed: bool # Round 管理 next_round_id: int pending_round_id: Optional[int] omni_chunk_history_length: int # TTS 状态(需要克隆,但通常很小) tts_last_turn_tokens: Optional[torch.Tensor] # Streaming 处理器状态 audio_chunk_idx: int # Mel processor 状态快照(包括 buffer) mel_processor_snapshot: Optional[dict] = None # Audio encoder KV cache(需要克隆以确保恢复后 continue prefill 的确定性) audio_past_key_values: Optional[tuple] = None # 时间戳(调试用) timestamp: float = 0.0 # 调试字段:用于验证恢复的正确性 llm_cache_checksum: Optional[float] = None # LLM KV Cache 第一层 K 的 sum audio_cache_checksum: Optional[float] = None # Audio KV Cache 第一层 K 的 sum mel_buffer_checksum: Optional[float] = None # Mel buffer 的 sum # RNG 状态(关键:用于恢复后确保 dithering 等随机操作的确定性) rng_state_cpu: Optional[torch.Tensor] = None # torch CPU RNG state rng_state_cuda: Optional[torch.Tensor] = None # torch CUDA RNG state (if on GPU) def summary(self) -> str: """返回快照摘要,用于日志""" mel_buf_len = 0 if self.mel_processor_snapshot: buf = self.mel_processor_snapshot.get("buffer") if buf is not None: mel_buf_len = len(buf) return ( f"llm_cache={self.llm_cache_length}, " f"audio_cache={self.audio_cache_length}, " f"audio_chunk_idx={self.audio_chunk_idx}, " f"mel_buffer={mel_buf_len}, " f"history_len={self.omni_chunk_history_length}, " f"new_user_msg={self.new_user_msg}, " f"llm_generated={self.llm_generated}" ) class TTSSamplingParams(BaseModel): top_p: float = 0.85 min_p: float = 0.01 top_k: int = 25 repetition_penalty: float = 1.05 temperature: float = 0.8 win_size: int = 16 tau_r: float = 0.1 class MiniCPMOPreTrainedModel(Qwen3PreTrainedModel): config_class = MiniCPMOConfig class MiniCPMO(MiniCPMOPreTrainedModel): def __init__(self, config): super().__init__(config) self.llm = Qwen3ForCausalLM(config) self.embed_dim = self.llm.config.hidden_size self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm) # patch llm # init vision module if self.config.init_vision: self.vpm = self.init_vision_module() self.vision_dim = self.vpm.embed_dim self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) # init audio module if self.config.init_audio: self.apm = self.init_audio_module() audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4) self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step) self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim) self.audio_encoder_layer = -1 # init tts module if self.config.init_tts: self.tts = self.init_tts_module() self.terminators = ["<|im_end|>", "<|endoftext|>"] self.think_str = "" if self.llm.__class__.__name__ == "Qwen3ForCausalLM": self.think_str = "\\n\\n\\n\\n" self.default_tts_chat_template = ( "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n" + self.think_str + "<|tts_bos|>' }}{% endif %}" ) # for streaming self.reset_session(reset_token2wav_cache=True) # streaming audio processing constants self.SAMPLE_RATE = 16000 self.CHUNK_MS = 1000 # regular chunk length (ms) self.FIRST_CHUNK_MS = 1035 # first chunk length (ms) self.CNN_REDUNDANCY_MS = 0 # CNN redundancy (ms) # for sliding window self.streaming_window_config = StreamingWindowConfig() self.streaming_require_system_prompt = True self.streaming_window_enabled = True self.force_rope_reindex = False # RoPE reindex testing switch def init_streaming_processor(self): if not hasattr(self, "processor") or self.processor is None: self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) if hasattr(self.processor, "set_streaming_mode"): self.processor.set_streaming_mode( mode="exact", chunk_ms=self.CHUNK_MS, verbose=False, first_chunk_ms=self.FIRST_CHUNK_MS, cnn_redundancy_ms=self.CNN_REDUNDANCY_MS, enable_sliding_window=True, slide_trigger_seconds=30.0, slide_stride_seconds=10.0, ) self.processor.reset_streaming() self.audio_chunk_idx = 0 def reset_session(self, reset_token2wav_cache=True): self.llm_past_key_values = None self.audio_past_key_values = None self.tts_last_turn_tokens = None self.llm_generated = False # last turn generated by llm or not self.llm_generate_completed = False self.new_user_msg = True self.session_id = None if reset_token2wav_cache: self.token2wav_cache = None # for sliding window self.streaming_text_preserve = 0 self.streaming_position_offset = 0 self._rope_inv_freq_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {} self._next_round_id = 0 self._pending_round_id = None self._omni_chunk_history: List[Dict[str, Union[str, int]]] = [] self._round_history: List[Dict[str, Union[int, str, torch.Tensor, Optional[int]]]] = [] def init_vision_module(self): if self.config._attn_implementation == "flash_attention_2": self.config.vision_config._attn_implementation = "flash_attention_2" else: self.config.vision_config._attn_implementation = "eager" model = SiglipVisionTransformer(self.config.vision_config) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] setattr(model, "embed_dim", model.embeddings.embed_dim) setattr(model, "patch_size", model.embeddings.patch_size) return model def init_resampler(self, embed_dim, vision_dim): return Resampler( num_queries=self.config.query_num, embed_dim=embed_dim, num_heads=embed_dim // 128, kv_dim=vision_dim, adaptive=True, ) def init_audio_module(self): if self.config._attn_implementation == "eager": self.config.audio_config._attn_implementation = "eager" else: # using flash_attention_2 will cause: RuntimeError: cu_seqlens_q must have shape (batch_size + 1) self.config.audio_config._attn_implementation = "sdpa" return MiniCPMWhisperEncoder(self.config.audio_config) def init_tts_module(self): if self.config._attn_implementation == "flash_attention_2": self.config.tts_config.attn_implementation = "flash_attention_2" else: self.config.tts_config.attn_implementation = "eager" return MiniCPMTTS(config=self.config.tts_config, audio_tokenizer=None) def init_tts(self, streaming=False, model_dir=None, enable_float16=False, n_timesteps=10): if streaming: if self.config.tts_config.audio_tokenizer_type != "s3tokenizer_step_audio": logger.warning("audio tokenizer type is set to s3tokenizer_step_audio") self.tts.config.audio_tokenizer_type = "s3tokenizer_step_audio" try: from stepaudio2 import Token2wav except ImportError: raise ImportError(f"please install Token2wav via: pip install stepaudio2-minicpmo") model_dir = model_dir or os.path.join(self.config._name_or_path, "assets/token2wav") self.tts.audio_tokenizer = Token2wav(model_dir, float16=enable_float16, n_timesteps=n_timesteps) return self.tts.audio_tokenizer else: if self.config.tts_config.audio_tokenizer_type != "s3tokenizer": logger.warning("audio tokenizer type is set to s3tokenizer") self.tts.config.audio_tokenizer_type = "s3tokenizer" try: from cosyvoice.cli.cosyvoice import CosyVoice2 except ImportError: raise ImportError(f"please install cosyvoice via: pip install cosyvoice-minicpmo") model_dir = model_dir or os.path.join(self.config._name_or_path, "assets/CosyVoice2-0.5B") self.tts.audio_tokenizer = CosyVoice2(model_dir=model_dir, load_jit=False, load_trt=False, fp16=False) return self.tts.audio_tokenizer def get_input_embeddings(self): return self.llm.get_input_embeddings() def set_input_embeddings(self, value): self.llm.embed_tokens = value def get_output_embeddings(self): return self.llm.lm_head def set_output_embeddings(self, new_embeddings): self.llm.lm_head = new_embeddings def set_decoder(self, decoder): self.llm = decoder def get_decoder(self): return self.llm @staticmethod def get_sys_prompt(ref_audio=None, mode="default", language="en", ref_audio_max_ms=None): if ref_audio is not None: if isinstance(ref_audio, str): if ref_audio == "assets/demo.wav": import librosa duration = ref_audio_max_ms / 1000.0 if ref_audio_max_ms else None ref_audio, _ = librosa.load(ref_audio, sr=16000, mono=True, duration=duration) else: import os import librosa if os.path.isfile(ref_audio) and os.path.exists(ref_audio): duration = ref_audio_max_ms / 1000.0 if ref_audio_max_ms else None ref_audio, _ = librosa.load(ref_audio, sr=16000, mono=True, duration=duration) else: logger.error(f"Could not find {ref_audio}") ref_audio = None assert isinstance(ref_audio, np.ndarray), "ref_audio error" if mode == "omni": if language == "zh": sys_prompt = "" vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。" vc_prompt_suffix = ( "请用这种声音风格来为用户提供帮助。 请认真、高质量地回复用户的问题。 请用高自然度的方式和用户聊天。" ) else: sys_prompt = "" vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt." vc_prompt_suffix = "As an assistant, you will speak using this voice style." if ref_audio is not None: sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} else: sys_msgs = {"role": "system", "content": [sys_prompt]} return sys_msgs elif mode == "audio_assistant": if language == "zh": vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。" vc_prompt_suffix = "你的任务是用这种声音模式来当一个助手。请认真、高质量地回复用户的问题。请用高自然度的方式和用户聊天。你是由面壁智能开发的人工智能助手:面壁小钢炮。" else: vc_prompt_prefix = "Use the voice in the audio prompt to synthesize new content." vc_prompt_suffix = "You are a helpful assistant with the above voice style." if ref_audio is not None: sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} else: logger.warning( "Warning: ref_audio is None, speech generation will be performed based on the default voice." ) sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]} return sys_msgs elif mode == "audio_roleplay": if language == "zh": vc_prompt_prefix = "模仿输入音频中的声音特征。" vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。" else: vc_prompt_prefix = "Clone the voice in the provided audio prompt." vc_prompt_suffix = "Try to role-play the character based on the audio prompt above." if ref_audio is not None: sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} else: print("Warning: ref_audio is None, speech generation will be performed based on the default voice.") sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]} return sys_msgs elif mode == "voice_cloning": if language == "zh": vc_prompt_prefix = "模仿输入音频中的声音特征。" else: vc_prompt_prefix = "Clone the voice in the provided audio prompt." if ref_audio is not None: sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio]} else: raise ValueError("ref_audio con't be None in voice_cloning mode.") return sys_msgs else: sys_prompt = "You are a helpful assistant. You can accept audio and text input and output voice and text." sys_msgs = {"role": "system", "content": [sys_prompt]} return sys_msgs @staticmethod def subsequent_chunk_mask( size: int, chunk_size: int, num_left_chunks: int = -1, device: torch.device = torch.device("cpu"), num_lookhead: int = 0, ) -> torch.Tensor: """Create mask for subsequent steps (size, size) with chunk size, this is for streaming encoder Args: size (int): size of mask chunk_size (int): size of chunk num_left_chunks (int): number of left chunks <0: use full chunk >=0: use num_left_chunks device (torch.device): "cpu" or "cuda" or torch.Tensor.device num_lookhead: Returns: torch.Tensor: mask """ ret = torch.zeros(size, size, device=device, dtype=torch.bool) for i in range(size): if num_left_chunks < 0: start = 0 else: start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size) ret[i, start:ending] = True return ret def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """Computes the output length of the convolutional layers and the output length of the audio encoder""" input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 input_lengths_after_pooling = ( input_lengths_after_cnn - self.config.audio_pool_step ) // self.config.audio_pool_step + 1 input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) return input_lengths_after_cnn, input_lengths_after_pooling def get_vision_embedding(self, data): if "vision_hidden_states" not in data: dtype = self.llm.model.embed_tokens.weight.dtype device = self.llm.model.embed_tokens.weight.device tgt_sizes = data["tgt_sizes"] pixel_values_list = data["pixel_values"] vision_hidden_states = [] all_pixel_values = [] img_cnt = [] for pixel_values in pixel_values_list: img_cnt.append(len(pixel_values)) all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values]) # exist image if all_pixel_values: tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)] tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) all_pixel_values = torch.nn.utils.rnn.pad_sequence( all_pixel_values, batch_first=True, padding_value=0.0 ) B, L, _ = all_pixel_values.shape all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) for i in range(B): patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True vision_batch_size = self.config.vision_batch_size all_pixel_values = all_pixel_values.type(dtype) if B > vision_batch_size: hs = [] for i in range(0, B, vision_batch_size): start_idx = i end_idx = i + vision_batch_size tmp_hs = self.vpm( all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx], ).last_hidden_state hs.append(tmp_hs) vision_embedding = torch.cat(hs, dim=0) else: vision_embedding = self.vpm( all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes, ).last_hidden_state vision_embedding = self.resampler(vision_embedding, tgt_sizes) start = 0 for pixel_values in pixel_values_list: img_cnt = len(pixel_values) if img_cnt > 0: vision_hidden_states.append(vision_embedding[start : start + img_cnt]) start += img_cnt else: vision_hidden_states.append([]) else: # no image if self.training: dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype) tgt_sizes = torch.Tensor( [ [ (224 // self.config.patch_size), math.ceil(224 / self.config.patch_size), ] ] ).type(torch.int32) dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes) else: dummy_feature = [] for _ in range(len(pixel_values_list)): vision_hidden_states.append(dummy_feature) else: vision_hidden_states = data["vision_hidden_states"] return vision_hidden_states def get_vllm_embedding(self, data): vision_hidden_states = self.get_vision_embedding(data) if hasattr(self.llm.config, "scale_emb"): vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb else: vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) vision_hidden_states = [ i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states ] bs = len(data["input_ids"]) for i in range(bs): cur_vs_hs = vision_hidden_states[i] if len(cur_vs_hs) > 0: cur_vllm_emb = vllm_embedding[i] cur_image_bound = data["image_bound"][i] if len(cur_image_bound) > 0: image_indices = torch.stack( [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound] ).to(vllm_embedding.device) cur_vllm_emb.scatter_( 0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), cur_vs_hs.view(-1, cur_vs_hs.shape[-1]), ) elif self.training: cur_vllm_emb += cur_vs_hs[0].mean() * 0 return vllm_embedding, vision_hidden_states def get_audio_embedding_streaming( self, data, use_extra_context=False, prefix_extra_frames=1, suffix_extra_frames=1, return_debug=False, cnn_min_length=None, ): """Extract audio embeddings in a streaming manner using cached key-value pairs. This method processes incoming audio features incrementally and stores/updates `past_key_values` for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended for streaming scenarios. Args: data (dict): - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch. use_extra_context (bool): If True, assumes input contains extra frames for CNN context. prefix_extra_frames (int): Number of prefix extra frames. suffix_extra_frames (int): Number of suffix extra frames. return_debug (bool): Whether to return debug information. cnn_min_length (int): Minimum length for CNN input padding. Returns: List[List[torch.Tensor]]: audio embeddings """ wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]] # exist audio if len(wavforms) > 0: audio_feature_lens = torch.hstack(audio_feature_lens_raw) batch_size, _, max_mel_seq_len = wavforms.shape assert batch_size == 1 max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # whisper's past_key_values management (core) if self.audio_past_key_values is not None: cache_length = self.audio_past_key_values[0][0].shape[2] apm_max_len = self.apm.embed_positions.weight.shape[0] if cache_length + max_seq_len >= apm_max_len: logger.warning( f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset." ) self.audio_past_key_values = None # build attention mask (bidirectional attention, same as offline mode) batch_size, _, max_mel_seq_len = wavforms.shape current_seq_len = (max_mel_seq_len - 1) // 2 + 1 # if use extra context, need to adjust sequence length if use_extra_context: # calculate actual sequence length after removing redundancy # conv2's stride=2, so the mapping from mel frames to output frames is ceil(x/2) prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0 suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0 current_seq_len = current_seq_len - prefix_to_remove - suffix_to_remove # calculate history length (if there is KV cache) if self.audio_past_key_values is not None: past_len = self.audio_past_key_values[0][0].shape[2] # get history sequence length total_seq_len = past_len + current_seq_len else: past_len = 0 total_seq_len = current_seq_len # create bidirectional attention mask (full attention) audio_attention_mask = torch.zeros( (batch_size, 1, current_seq_len, total_seq_len), dtype=self.apm.conv1.weight.dtype, device=wavforms.device, ) debug_info = {} if return_debug else None # DEBUG: 打印输入和 past_key_values 状态 if hasattr(self, "_debug_prefill") and self._debug_prefill: wavform_sum = wavforms.sum().item() past_kv_info = ( "None" if self.audio_past_key_values is None else f"len={self.audio_past_key_values[0][0].shape[2]}" ) print( f"[DEBUG audio_embed] wavforms sum={wavform_sum:.6f}, shape={wavforms.shape}, past_kv={past_kv_info}" ) # Step 1: APM processing ret_apm = self.apm( wavforms, past_key_values=self.audio_past_key_values, use_cache=True, output_hidden_states=True, attention_mask=audio_attention_mask, use_extra_context=use_extra_context, prefix_extra_frames=prefix_extra_frames, suffix_extra_frames=suffix_extra_frames, return_debug=return_debug, cnn_min_length=cnn_min_length, ) if return_debug: audio_outputs, debug_info["apm_debug"] = ret_apm else: audio_outputs = ret_apm if hasattr(self, "audio_encoder_layer"): audio_states = audio_outputs.hidden_states[self.audio_encoder_layer] else: audio_states = audio_outputs.last_hidden_state # DEBUG: 打印 apm 输出的 checksum if hasattr(self, "_debug_prefill") and self._debug_prefill: apm_out_sum = audio_outputs.last_hidden_state.sum().item() audio_states_sum = audio_states.sum().item() print(f"[DEBUG audio_embed] apm_output sum={apm_out_sum:.6f}, audio_states sum={audio_states_sum:.6f}") if return_debug: debug_info["apm_output"] = audio_outputs.last_hidden_state.clone() debug_info["audio_states_before_proj"] = audio_states.clone() debug_info["attention_mask"] = audio_attention_mask.clone() debug_info["use_extra_context"] = use_extra_context debug_info["prefix_extra_frames"] = prefix_extra_frames debug_info["suffix_extra_frames"] = suffix_extra_frames self.audio_past_key_values = audio_outputs.past_key_values # Step 2: Projection audio_embeds = self.audio_projection_layer(audio_states) # DEBUG: 打印 projection 和 pooling 后的 checksum if hasattr(self, "_debug_prefill") and self._debug_prefill: proj_sum = audio_embeds.sum().item() print(f"[DEBUG audio_embed] after_projection sum={proj_sum:.6f}, shape={audio_embeds.shape}") if return_debug: debug_info["after_projection"] = audio_embeds.clone() # Step 3: Pooling audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) # DEBUG: 打印 pooling 后的 checksum if hasattr(self, "_debug_prefill") and self._debug_prefill: pool_sum = audio_embeds.sum().item() print(f"[DEBUG audio_embed] after_pooling sum={pool_sum:.6f}, shape={audio_embeds.shape}") if return_debug: debug_info["after_pooling"] = audio_embeds.clone() _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) num_audio_tokens = feature_lens_after_pooling final_audio_embeds = [] idx = 0 for i in range(len(audio_feature_lens_raw)): target_audio_embeds = [] for _ in range(len(audio_feature_lens_raw[i])): target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) idx += 1 final_audio_embeds.append(target_audio_embeds) if return_debug: debug_info["final_embeddings"] = final_audio_embeds return final_audio_embeds, debug_info else: return final_audio_embeds else: return [] if not return_debug else ([], {}) def get_audio_embedding(self, data, chunk_length=-1, dummy=True): dtype = self.apm.embed_positions.weight.dtype device = self.apm.embed_positions.weight.device wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]] if len(wavforms) > 0: audio_feature_lens = torch.hstack(audio_feature_lens_raw) batch_size, _, max_mel_seq_len = wavforms.shape max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) seq_range = ( torch.arange( 0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device, ) .unsqueeze(0) .expand(batch_size, max_seq_len) ) lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) # Create mask padding_mask = seq_range >= lengths_expand # 1 for padded values audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( batch_size, 1, max_seq_len, max_seq_len ) audio_attention_mask = audio_attention_mask_.to( dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device ) if chunk_length > 0: chunk_num_frame = int(chunk_length * 50) chunk_mask = self.subsequent_chunk_mask( size=max_seq_len, chunk_size=chunk_num_frame, num_left_chunks=-1, device=audio_attention_mask_.device, ) audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask)) audio_attention_mask[audio_attention_mask_] = float("-inf") audio_states = self.apm( wavforms, output_hidden_states=True, attention_mask=audio_attention_mask ).hidden_states[self.audio_encoder_layer] audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) num_audio_tokens = feature_lens_after_pooling final_audio_embeds = [] idx = 0 for i in range(len(audio_feature_lens_raw)): target_audio_embeds = [] for _ in range(len(audio_feature_lens_raw[i])): target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) idx += 1 final_audio_embeds.append(target_audio_embeds) return final_audio_embeds elif self.training and dummy: dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype) audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer] audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) return [audio_embeds] else: return [] def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False): """ Args: data: input_embeddings: chunk_length: whisper use full attention or chunk attention stream_input: use streaming audio embedding or not Returns: final embeddings with audio feature """ if stream_input: audio_embeddings = self.get_audio_embedding_streaming(data) else: audio_embeddings = self.get_audio_embedding(data, chunk_length) bs = len(input_embeddings) if len(data.get("audio_features", [])) > 0: assert len(audio_embeddings) == len(input_embeddings) if len(audio_embeddings) > 0: audio_bounds = data["audio_bounds"] if self.config.stream_input: assert bs == 1, "audio stream_input mode only support batch size 1" for i in range(bs): audio_embs = torch.cat(audio_embeddings[i], dim=0).to( device=input_embeddings.device, dtype=input_embeddings.dtype ) audio_start_pos = 0 for bound in audio_bounds[i]: audio_len = bound[1] - bound[0] input_embeddings[i, bound[0] : bound[1]] = audio_embs[ audio_start_pos : audio_start_pos + audio_len, : ] audio_start_pos += audio_len else: for i in range(bs): audio_embs = audio_embeddings[i] bounds = audio_bounds[i] for embs, bound in zip(audio_embs, bounds): audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to( input_embeddings.device ) if embs.shape[0] != len(audio_indices): logger.error(f"Sample {i}:") logger.error(f" Bounds: {bound}, Indices Length: {len(audio_indices)}") logger.error(f" Embeddings Shape: {embs.shape}") logger.error( f" Input Embedding Shape at Indices: {input_embeddings[i, audio_indices].shape}" ) raise ValueError( f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} " f"to input indices of length {len(audio_indices)}" ) input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype) elif self.training: for i in range(bs): # dummy audio_embedings input_embeddings += audio_embeddings[0].mean() * 0 return input_embeddings def forward(self, data, **kwargs): vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data) vllm_embedding = self.get_omni_embedding( data, input_embeddings=vllm_embedding, chunk_length=self.config.audio_chunk_length, ) position_ids = data["position_ids"] if position_ids.dtype != torch.int64: position_ids = position_ids.long() return self.llm( input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs, ) def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs): terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] outputs = self.llm.generate( inputs_embeds=inputs_embeds, pad_token_id=0, eos_token_id=terminators, attention_mask=attention_mask, output_hidden_states=True, return_dict_in_generate=True, **kwargs, ) return outputs def _decode_stream(self, inputs_embeds, tokenizer, **kwargs): terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] streamer = TextIteratorStreamer(tokenizer=tokenizer) generation_config = { "inputs_embeds": inputs_embeds, "pad_token_id": 0, "eos_token_id": terminators, "streamer": streamer, } generation_config.update(kwargs) thread = Thread(target=self.llm.generate, kwargs=generation_config) thread.start() return streamer def _decode_text(self, result_ids, tokenizer): terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] result_text = [] for result in result_ids: result = result[result != 0] if result[0] == tokenizer.bos_id: result = result[1:] if result[-1] in terminators: result = result[:-1] result_text.append(tokenizer.decode(result)) return result_text @torch.inference_mode() def generate( self, input_ids=None, pixel_values=None, tgt_sizes=None, audio_features=None, audio_feature_lens=None, image_bound=None, audio_bounds=None, spk_bounds=None, attention_mask=None, tokenizer=None, vision_hidden_states=None, stream=False, **kwargs, ): assert input_ids is not None assert len(input_ids) == len(pixel_values) model_inputs = { "input_ids": input_ids, "audio_features": audio_features, "audio_feature_lens": audio_feature_lens, "image_bound": image_bound, "audio_bounds": audio_bounds, "spk_bounds": spk_bounds, } if vision_hidden_states is None: model_inputs["pixel_values"] = pixel_values model_inputs["tgt_sizes"] = tgt_sizes else: model_inputs["vision_hidden_states"] = vision_hidden_states with torch.inference_mode(): model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs) model_inputs["inputs_embeds"] = self.get_omni_embedding( model_inputs, input_embeddings=model_inputs["inputs_embeds"], chunk_length=self.config.audio_chunk_length, ) if stream: result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs) outputs = {} # if stream return TextIteratorStreamer and output is empty else: outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs) result = self._decode_text(outputs.sequences, tokenizer) return result, outputs def _build_streaming_mask(self, tts_tokens_len): tts_sequence_full_length = 1 + self.tts.streaming_text_reserved_len + 1 streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8) streaming_attention_mask[0 : 1 + 1 + tts_tokens_len + 1] = 1 streaming_attention_mask[-1] = 1 return streaming_attention_mask def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048): spk_embeds = self._get_last_spk_embeds(inputs, outputs) text = text.split("<|tts_bos|>")[-1] gen_text = text.split("<|tts_eos|>")[0] tts_text, tts_token_lens = self.prepare_tts_text(gen_text) tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False) tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long) streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) logits_warpers, logits_processors = gen_logits( num_code=626, top_p=self.tts.top_p, top_k=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty, ) condition_length = 1 + self.tts.streaming_text_reserved_len + 1 dtype = self.tts.emb_text.weight.dtype emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device) past_key_values = [ ( torch.zeros( 1, self.tts.config.num_attention_heads, condition_length - 1, self.tts.config.hidden_size // self.tts.config.num_attention_heads, dtype=emb.dtype, device=self.tts.device, ), torch.zeros( 1, self.tts.config.num_attention_heads, condition_length - 1, self.tts.config.hidden_size // self.tts.config.num_attention_heads, dtype=emb.dtype, device=self.tts.device, ), ) for _ in range(self.tts.config.num_hidden_layers) ] audio_input_ids = torch.zeros( 1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device, ) eos_lab = False for chunk_idx in range(math.ceil(emb.shape[1] / self.tts.streaming_text_chunk_size)): if chunk_idx == 0: begin = chunk_idx * self.tts.streaming_text_chunk_size + 0 end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1 else: begin = chunk_idx * self.tts.streaming_text_chunk_size + 1 end = min( (chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1, condition_length - 1, ) if end - begin > 0: text_input_ids = tts_input_ids[:, begin:end] position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) if begin == 0: past_key_values = self.tts.prefill_text( input_ids=text_input_ids, position_ids=position_ids, past_key_values=past_key_values, lm_spk_emb_last_hidden_states=spk_embeds, ) else: past_key_values = self.tts.prefill_text( input_ids=text_input_ids, position_ids=position_ids, past_key_values=past_key_values, ) outputs = self.tts.generate( input_ids=audio_input_ids, past_key_values=past_key_values, streaming_tts_text_mask=streaming_tts_text_mask, max_new_token=output_chunk_size, force_no_stop=self.force_no_stop, temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), logits_warpers=logits_warpers, logits_processors=logits_processors, ) audio_input_ids = outputs.audio_input_ids past_key_values = outputs.past_key_values if outputs.finished: logger.debug("Generation finished.") eos_lab = True break if not eos_lab: logger.debug("eos_lab False, Generation continue.") while True: outputs = self.tts.generate( input_ids=audio_input_ids, past_key_values=past_key_values, streaming_tts_text_mask=streaming_tts_text_mask, max_new_token=output_chunk_size, force_no_stop=self.force_no_stop, temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), logits_warpers=logits_warpers, logits_processors=logits_processors, ) audio_input_ids = outputs.audio_input_ids past_key_values = outputs.past_key_values if outputs.finished: logger.debug("Generation finished.") break if outputs.new_ids.shape[1] > tts_max_new_tokens: logger.debug(f"Generation length > {tts_max_new_tokens}, stopped.") break @staticmethod def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs): num_beams = kwargs.get("num_beams", 3) generation_config = { "num_beams": num_beams, "top_p": 0.8, "top_k": 100, "temperature": 0.7, "do_sample": True, "repetition_penalty": 1.02, } if do_sample: generation_config.update( { "top_p": 0.8, "top_k": 100, "temperature": 0.7, "do_sample": True, "repetition_penalty": 1.02, } ) elif num_beams > 1: generation_config.update({"num_beams": num_beams, "repetition_penalty": 1.2, "do_sample": False}) else: generation_config.update({"do_sample": False, "repetition_penalty": 1.02}) generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()) generation_config["min_new_tokens"] = min_new_tokens generation_config["max_new_tokens"] = max_new_tokens generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()) return generation_config @torch.inference_mode() def chat( self, image=None, msgs=None, tokenizer=None, # deprecated processor=None, # deprecated vision_hidden_states=None, max_new_tokens=4096, min_new_tokens=0, do_sample=True, sampling=None, # deprecated, please use do_sample max_inp_length=8192, stream=False, stream_input=False, max_slice_nums=None, use_image_id=None, enable_thinking=False, use_tts_template=False, generate_audio=False, output_audio_path=None, output_tts_inputs_embeds_path=None, # add omni_mode=False, omni_input=None, # deprecated, please use omni_mode teacher_forcing=False, return_prompt=False, tts_proj_layer=-1, tts_sampling_params: TTSSamplingParams = TTSSamplingParams(), merge_audio_from_same_content=True, **kwargs, ): # todo: deprecated if sampling is not None: do_sample = sampling if omni_input is not None: omni_mode = omni_input batched = isinstance(msgs[0], list) msgs_list = msgs images_list = image if not batched: images_list, msgs_list = [images_list], [msgs_list] else: assert images_list is None, "Please integrate image to msgs when using batch inference." images_list = [None] * len(msgs_list) assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same." if not hasattr(self, "processor") or self.processor is None: self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) prompts_lists = [] input_images_list = [] input_audios_list = [] audio_parts_list = [] for image, msgs in zip(images_list, msgs_list): if isinstance(msgs, str): msgs = json.loads(msgs) copy_msgs = deepcopy(msgs) assert len(msgs) > 0, "msgs is empty" assert do_sample or not stream, "if use stream mode, make sure do_sample=True" if image is not None and isinstance(copy_msgs[0]["content"], str): copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]] images = [] audios = [] audio_parts = [] for i, msg in enumerate(copy_msgs): role = msg["role"] content = msg["content"] assert role in ["system", "user", "assistant"] if i == 0: assert role in ["user", "system"], "The role of first msg should be user" if isinstance(content, str): content = [content] cur_msgs = [] for c in content: if isinstance(c, Image.Image): images.append(c) cur_msgs.append("./") elif isinstance(c, np.ndarray): # audio audios.append(c) audio_parts.append(i) cur_msgs.append("") use_tts_template = True elif isinstance(c, str): cur_msgs.append(c) if omni_mode or stream_input: msg["content"] = "".join(cur_msgs) else: msg["content"] = "\n".join(cur_msgs) prompts_lists.append( self.processor.tokenizer.apply_chat_template( copy_msgs, tokenize=False, add_generation_prompt=False if teacher_forcing else True, use_tts_template=use_tts_template, enable_thinking=enable_thinking, ) ) input_images_list.append(images) input_audios_list.append(audios) audio_parts_list.append(audio_parts) if not merge_audio_from_same_content: audio_parts_list = None inputs = self.processor( prompts_lists, input_images_list, input_audios_list, audio_parts_list, max_slice_nums=max_slice_nums, use_image_id=use_image_id, stream_input=stream_input, return_tensors="pt", max_length=max_inp_length, ).to(self.device) generation_config = self.prepare_generation_config( do_sample=do_sample, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, **kwargs ) generation_config.pop("max_new_tokens", None) inputs.pop("image_sizes") # teacher_forcing = True => generate audio with given text with torch.inference_mode(): res, outputs = self.generate( **inputs, tokenizer=self.processor.tokenizer, max_new_tokens=1 if teacher_forcing else max_new_tokens, vision_hidden_states=vision_hidden_states, stream=stream, **generation_config, ) # spk bound and tts bound tts_bos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_bos|>") tts_eos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_eos|>") # Combine input_ids and generated sequences to get complete sequence input_ids = inputs["input_ids"][0] generated_ids = outputs.sequences[0] # Combine by concatenating input_ids with the new tokens from generated sequence full_sequence = torch.cat([input_ids, generated_ids]) # Update the sequences in outputs full_sequences = full_sequence.unsqueeze(0) outputs["full_sequences"] = full_sequences tts_bos_indices = [] tts_eos_indices = [] for i, x in enumerate(full_sequences[0]): if x == tts_bos_token: tts_bos_indices.append(i + 1) # tts_bos + 1 才是第一个tts的位置,这样方便直接给tts去slice hidden states elif x == tts_eos_token: if teacher_forcing and i == len(full_sequences[0]) - 1: continue tts_eos_indices.append(i) tts_bos_idx = tts_bos_indices[-1] if tts_bos_indices else -1 # Use None instead of -1 when no EOS token found, so that slice [start:None] # means "to the end" rather than [start:-1] which excludes the last element tts_eos_idx = tts_eos_indices[-1] if tts_eos_indices else None tts_bound = (tts_bos_idx, tts_eos_idx) answer = res[0] if answer is not None: answer = answer.rstrip("<|tts_eos|>") if use_tts_template and generate_audio and output_audio_path: try: generated_waveform = self._generate_speech_non_streaming( outputs=outputs, tts_bound=tts_bound, tts_proj_layer=tts_proj_layer, audio_prompt=( input_audios_list[0][0] if len(input_audios_list) > 0 and len(input_audios_list[0]) > 0 else None ), output_tts_inputs_embeds_path=output_tts_inputs_embeds_path, tts_sampling_params=tts_sampling_params, ) if isinstance(generated_waveform, torch.Tensor): sf.write(output_audio_path, generated_waveform.cpu().numpy(), samplerate=24000) elif isinstance(generated_waveform, np.ndarray): sf.write(output_audio_path, generated_waveform, samplerate=24000) logger.debug(f"audio saved to {output_audio_path}") except: import traceback traceback.print_exc() if return_prompt: return answer, prompts_lists[0] else: return answer @torch.inference_mode() def _generate_speech_non_streaming( self, outputs, tts_bound, tts_proj_layer, audio_prompt, output_tts_inputs_embeds_path=None, tts_sampling_params: TTSSamplingParams = TTSSamplingParams(), ): last_hidden_states = [hs[tts_proj_layer] for hs in outputs.hidden_states] last_hidden_states = torch.vstack([i[0] for i in last_hidden_states]) spk_embeds = ( torch.ones([0, self.tts.config.hidden_size]).to(last_hidden_states.device).to(last_hidden_states.dtype) ) if self.tts.condition_type == "hidden_text_merge": llm_tokens = outputs["full_sequences"][0][tts_bound[0] : tts_bound[1]] llm_tokens = torch.tensor(llm_tokens, device=self.tts.emb_text.weight.device, dtype=torch.long) llm_embeds = self.tts.emb_text(llm_tokens) # make sure emb_text is compatible with llm vocab size hidden_embeds = last_hidden_states[tts_bound[0] : tts_bound[1]] hidden_embeds = self.tts.projector_semantic(hidden_embeds) if self.tts.config.normalize_projected_hidden: hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1) tts_embeds = llm_embeds + hidden_embeds if self.tts.interleaved: chunks = [] cond_length = tts_embeds.shape[0] for i in range(0, cond_length, 10): chunks.append(tts_embeds[i : i + 10]) tts_embeds = chunks else: raise NotImplementedError audio_bos = [self.tts.audio_bos_token_id] audio_bos = torch.tensor(audio_bos, device=self.tts.emb_text.weight.device, dtype=torch.long) audio_bos_embeds = self.tts.emb_text(audio_bos) if self.tts.interleaved: text_eos_embed = self.tts.emb_text( torch.tensor( [self.tts.config.text_eos_token_id], device=self.tts.emb_text.weight.device, dtype=torch.long, ) ) tts_embeds[-1] = torch.cat([tts_embeds[-1], text_eos_embed], dim=0) for i in range(len(tts_embeds)): tts_embeds[i] = torch.cat([tts_embeds[i], audio_bos_embeds], dim=0).unsqueeze(0) outputs = self.tts.interleaved_generate( spk_embeds=spk_embeds, conditions=tts_embeds, temperature=0.8, repetition_penalty=1.05, eos_token=torch.tensor( [self.tts.config.num_audio_tokens - 1], dtype=torch.long, device=self.tts.device, ), ) else: if self.tts.condition_type == "tts_token": inputs_embeds = torch.cat([spk_embeds, tts_embeds, audio_bos_embeds], dim=0).unsqueeze(0) elif self.tts.condition_type == "tts_token_streaming": tts_embeds[1] = spk_embeds.squeeze(0) # apply speaker embedding inputs_embeds = tts_embeds.unsqueeze(0) else: # modern case inputs_embeds = torch.cat([spk_embeds, tts_embeds, audio_bos_embeds], dim=0).unsqueeze(0) # save inputs_embeds to file if output_tts_inputs_embeds_path: torch.save(inputs_embeds, output_tts_inputs_embeds_path) outputs = self.tts.generate( inputs_embeds=inputs_embeds, sampling_params=tts_sampling_params, eos_token=torch.tensor( [self.tts.config.num_audio_tokens - 1], dtype=torch.long, device=self.tts.device, ), ) if self.tts.config.audio_tokenizer_type == "s3tokenizer": generated_tokens = outputs.new_ids.squeeze(-1) reference_audio = audio_prompt if reference_audio is not None: logger.debug("use reference audio in data to generate waveform") prompt_speech_16k = torch.tensor(reference_audio).unsqueeze(0) if self.tts.config.s3_stream_generate: waveform_pred = self.tts.audio_tokenizer.inference_token2wav( speech_tokens=generated_tokens, prompt_speech_16k=prompt_speech_16k, prompt_speech=None, stream=True, n_timesteps=self.tts.config.s3_stream_n_timesteps, code_chunk_size=self.tts.config.s3_stream_chunk_size, chunk_prelook_size=self.tts.config.s3_stream_prelook_size, use_attn_idx=False, ) return waveform_pred[0] else: for i, j in enumerate( self.tts.audio_tokenizer.token2wav( speech_token=generated_tokens, speech_token_len=torch.tensor([generated_tokens.shape[1]], device=generated_tokens.device), prompt_speech_16k=prompt_speech_16k, stream=False, ) ): waveform_pred = j["tts_speech"] waveform_sample_rate = self.tts.audio_tokenizer.sample_rate # 24000 here, not 16000 input. return waveform_pred[0] else: raise NotImplementedError @torch.inference_mode() def init_token2wav_cache(self, prompt_speech_16k): if hasattr(self.tts.audio_tokenizer, "set_stream_cache"): self.tts.audio_tokenizer.cache = None with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav: prompt_wav_path = tmp_wav.name sf.write(prompt_wav_path, prompt_speech_16k, 16000) flow_cache_base, hift_cache_base = self.tts.audio_tokenizer.set_stream_cache(prompt_wav_path) self.token2wav_cache = { "flow_cache_base": torch_clone_recursive(flow_cache_base), "hift_cache_base": torch_clone_recursive(hift_cache_base), } else: model_input = self.tts.audio_tokenizer.frontend.frontend_token2wav( speech_tokens=torch.zeros(1, 1, dtype=torch.long, device=self.tts.device), speech_16k=None, prompt_speech_16k=prompt_speech_16k, resample_rate=self.tts.audio_tokenizer.sample_rate, prompt_speech=None, ) prompt_token = model_input["flow_prompt_speech_token"] prompt_feat = model_input["prompt_speech_feat"] embedding = model_input["flow_embedding"] if self.tts.audio_tokenizer.fp16: prompt_feat = prompt_feat.to(torch.half) embedding = embedding.to(torch.half) prepared_cache = self.tts.audio_tokenizer.model.prepare_cache_from_prompt( prompt_token=prompt_token, prompt_feat=prompt_feat, embedding=embedding, n_timesteps=self.tts.config.s3_stream_n_timesteps, code_chunk_size=self.tts.config.s3_stream_chunk_size, chunk_prelook_size=self.tts.config.s3_stream_prelook_size, use_attn_idx=False, ) self.token2wav_cache = prepared_cache # for sliding window def _ensure_dynamic_cache(self): cache = self.llm_past_key_values if cache is None: return None cache = as_dynamic_cache(cache) if isinstance(cache, DynamicCache): self.llm_past_key_values = cache return cache return None def _get_kv_cache_length(self, cache=None): cache = cache if cache is not None else self.llm_past_key_values return get_kv_cache_length(cache) # todo: not-used del? def _rebuild_cache_from_history(self): preserved_ids: List[torch.Tensor] = [] for entry in self._omni_chunk_history: ids = entry.get("input_ids") if ids is None or not isinstance(ids, torch.Tensor) or ids.numel() == 0: continue preserved_ids.append(ids.to(self.device)) if not preserved_ids: self.llm_past_key_values = None self.streaming_position_offset = 0 self._rope_inv_freq_cache.clear() return concat_ids = torch.cat(preserved_ids, dim=1) attention_mask = torch.ones((1, concat_ids.shape[1]), dtype=torch.bool, device=self.device) outputs = self.llm( input_ids=concat_ids, attention_mask=attention_mask, use_cache=True, return_dict=True, ) self.llm_past_key_values = outputs.past_key_values self.streaming_position_offset = 0 self._rope_inv_freq_cache.clear() def _get_rope_theta(self) -> float: return float(getattr(self.llm.config, "rope_theta", 10000.0)) def _realign_rotary_suffix( self, suffix_keys: torch.Tensor, old_positions: torch.Tensor, new_positions: torch.Tensor, ) -> torch.Tensor: """Wrapper for realign_rotary_suffix using instance's rope_theta and cache.""" return realign_rotary_suffix( suffix_keys, old_positions, new_positions, rope_theta=self._get_rope_theta(), inv_freq_cache=self._rope_inv_freq_cache, ) def _encode_text(self, tokenizer, text) -> Optional[torch.Tensor]: if tokenizer is None or not text: return None ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"] return ids.to(self.device) @staticmethod def _safe_decode(tokenizer, input_ids): if tokenizer is None or input_ids is None: return None if isinstance(input_ids, torch.Tensor): ids = input_ids.cpu().tolist() if ids and isinstance(ids[0], list): ids = ids[0] else: ids = input_ids try: return tokenizer.decode(ids, skip_special_tokens=False) except Exception: return None def _finalize_round( self, round_id: Optional[int], cache_before: int, assistant_input_ids: Optional[torch.Tensor] = None ): if round_id is None: self._pending_round_id = None return cache_after = self._get_kv_cache_length() if assistant_input_ids is not None: assistant_len = assistant_input_ids.shape[1] else: assistant_len = max(cache_after - cache_before, 0) if assistant_len > 0: self._register_chunk( assistant_len, "assistant", round_id=round_id, input_ids=assistant_input_ids, tokenizer=self.processor.tokenizer if hasattr(self, "processor") else None, ) logger.info( "Finalized round=%s cache len before=%s after=%s assistant_len=%s", round_id, cache_before, cache_after, assistant_len, ) self._pending_round_id = None self._next_round_id += 1 def _register_chunk( self, seq_len: int, chunk_type: str, *, round_id: int, input_ids=None, tokenizer=None, ) -> None: if seq_len <= 0: return entry = {"length": int(seq_len), "type": chunk_type, "round": round_id} if input_ids is not None: entry["input_ids"] = input_ids.clone().detach() entry["decoded"] = self._safe_decode(tokenizer, entry["input_ids"]) else: entry["input_ids"] = None entry["decoded"] = None self._omni_chunk_history.append(entry) logger.info( "Registered chunk round=%s type=%s len=%s decoded=%s", round_id, chunk_type, entry["length"], entry["decoded"], ) if chunk_type == "system": self.streaming_text_preserve = max(self.streaming_text_preserve, entry["length"]) def _drop_tokens_from_cache(self, length: int, cache: DynamicCache) -> bool: """Drop tokens from cache using the utility function.""" _, new_offset, success = drop_tokens_from_cache( cache=cache, length=length, preserve=self.streaming_text_preserve, position_offset=self.streaming_position_offset, rope_theta=self._get_rope_theta(), inv_freq_cache=self._rope_inv_freq_cache, ) if success: self.streaming_position_offset = new_offset return success def _drop_next_round(self, cache: DynamicCache) -> bool: seen_rounds = set() for entry in self._omni_chunk_history: round_id = entry.get("round") if round_id is None or round_id in seen_rounds: continue seen_rounds.add(round_id) round_entries = [e for e in self._omni_chunk_history if e.get("round") == round_id] if any(e.get("type") == "system" for e in round_entries): continue if self._drop_round(round_id, cache): return True return False def _drop_round(self, round_id: int, cache: DynamicCache) -> bool: entries = [e for e in self._omni_chunk_history if e.get("round") == round_id] if not entries: return False total_len = sum(e["length"] for e in entries) if total_len <= 0: for e in entries: self._omni_chunk_history.remove(e) return False if not self._drop_tokens_from_cache(total_len, cache): return False for e in entries: logger.info( "Dropped round=%s chunk type=%s len=%s decoded=%s", round_id, e["type"], e["length"], e.get("decoded"), ) self._omni_chunk_history.remove(e) return True def _enforce_text_window(self) -> None: if not self.streaming_window_enabled: return cache = self._ensure_dynamic_cache() if cache is None: return high_limit = max(0, int(self.streaming_window_config.text_window_high_tokens)) low_limit = max(0, int(self.streaming_window_config.text_window_low_tokens)) if high_limit <= 0: return target = max(0, low_limit) total_len = self._get_kv_cache_length(cache) if total_len <= high_limit: return dropped_any = False while total_len > target: if not self._drop_next_round(cache): break dropped_any = True total_len = self._get_kv_cache_length(cache) # for sliding window # ============== 抢跑快照/恢复接口 ============== def _save_speculative_snapshot(self) -> SpeculativeSnapshot: """Internal method: save speculative snapshot. Called at the start of streaming_generate, saves to self._speculative_snapshot. Save strategy: - LLM KV Cache: only record length (restore by truncation, zero extra VRAM) - Audio KV Cache: deep clone (as generate sets it to None) - Mel processor: full state snapshot (including buffer) """ # 1. 获取 LLM cache 信息 llm_cache_length = self._get_kv_cache_length() llm_cache_checksum = None if self.llm_past_key_values is not None and hasattr(self.llm_past_key_values, "key_cache"): if len(self.llm_past_key_values.key_cache) > 0: llm_cache_checksum = self.llm_past_key_values.key_cache[0].sum().item() # 2. 获取 audio cache 长度并克隆 audio_past_key_values audio_cache_length = 0 audio_cache_checksum = None audio_past_key_values_clone = None if self.audio_past_key_values is not None: # 处理 DynamicCache 格式(Whisper encoder 可能返回此格式) if isinstance(self.audio_past_key_values, DynamicCache): if hasattr(self.audio_past_key_values, "key_cache") and len(self.audio_past_key_values.key_cache) > 0: audio_cache_length = self.audio_past_key_values.key_cache[0].shape[2] audio_cache_checksum = self.audio_past_key_values.key_cache[0].sum().item() # 深度克隆 DynamicCache cloned_cache = DynamicCache() for k, v in zip(self.audio_past_key_values.key_cache, self.audio_past_key_values.value_cache): cloned_cache.update(k.clone(), v.clone(), layer_idx=len(cloned_cache.key_cache)) audio_past_key_values_clone = cloned_cache logger.debug(f"[Speculative] Cloned DynamicCache with length {audio_cache_length}") # 处理 EncoderDecoderCache 格式 elif isinstance(self.audio_past_key_values, EncoderDecoderCache): self_attn_cache = self.audio_past_key_values.self_attention_cache if hasattr(self_attn_cache, "key_cache") and len(self_attn_cache.key_cache) > 0: audio_cache_length = self_attn_cache.key_cache[0].shape[2] audio_cache_checksum = self_attn_cache.key_cache[0].sum().item() # 深度克隆 EncoderDecoderCache cloned_self_attn = DynamicCache() if hasattr(self_attn_cache, "key_cache"): for k, v in zip(self_attn_cache.key_cache, self_attn_cache.value_cache): cloned_self_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_self_attn.key_cache)) cross_attn_cache = self.audio_past_key_values.cross_attention_cache cloned_cross_attn = DynamicCache() if hasattr(cross_attn_cache, "key_cache"): for k, v in zip(cross_attn_cache.key_cache, cross_attn_cache.value_cache): cloned_cross_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_cross_attn.key_cache)) audio_past_key_values_clone = EncoderDecoderCache(cloned_self_attn, cloned_cross_attn) logger.debug(f"[Speculative] Cloned EncoderDecoderCache with length {audio_cache_length}") # 处理 tuple 格式(兼容旧格式) elif isinstance(self.audio_past_key_values, tuple) and len(self.audio_past_key_values) > 0: audio_cache_length = self.audio_past_key_values[0][0].shape[2] audio_cache_checksum = self.audio_past_key_values[0][0].sum().item() # 深度克隆 audio_past_key_values(tuple of tuples of tensors) audio_past_key_values_clone = tuple( tuple(t.clone() for t in layer_cache) for layer_cache in self.audio_past_key_values ) # 3. 获取 mel processor 快照 mel_processor_snapshot = None mel_buffer_checksum = None if hasattr(self, "processor") and self.processor is not None: mel_processor_snapshot = self.processor.get_streaming_snapshot() if mel_processor_snapshot: buf = mel_processor_snapshot.get("buffer") if buf is not None and len(buf) > 0: mel_buffer_checksum = float(buf.sum()) # 4. 保存 RNG 状态(关键:用于恢复后确保 dithering 等随机操作的确定性) rng_state_cpu = torch.get_rng_state() rng_state_cuda = None if torch.cuda.is_available() and self.device.type == "cuda": rng_state_cuda = torch.cuda.get_rng_state(self.device) # 5. 创建快照 snapshot = SpeculativeSnapshot( llm_cache_length=llm_cache_length, audio_cache_length=audio_cache_length, new_user_msg=self.new_user_msg, llm_generated=self.llm_generated, llm_generate_completed=self.llm_generate_completed, next_round_id=self._next_round_id, pending_round_id=self._pending_round_id, omni_chunk_history_length=len(self._omni_chunk_history), tts_last_turn_tokens=self.tts_last_turn_tokens.clone() if self.tts_last_turn_tokens is not None else None, audio_chunk_idx=self.audio_chunk_idx, mel_processor_snapshot=mel_processor_snapshot, audio_past_key_values=audio_past_key_values_clone, timestamp=time.time(), # 调试字段 llm_cache_checksum=llm_cache_checksum, audio_cache_checksum=audio_cache_checksum, mel_buffer_checksum=mel_buffer_checksum, # RNG 状态 rng_state_cpu=rng_state_cpu, rng_state_cuda=rng_state_cuda, ) logger.info("[Speculative] Saved snapshot: %s", snapshot.summary()) logger.debug( "[Speculative] Snapshot checksums: llm=%.6f, audio=%.6f, mel_buf=%.6f", llm_cache_checksum or 0.0, audio_cache_checksum or 0.0, mel_buffer_checksum or 0.0, ) return snapshot def restore_speculative_snapshot(self) -> bool: """Restore speculative snapshot - called when VAD speculation fails. Restores model state to before streaming_generate was called, allowing continued streaming_prefill for newly arrived audio. Notes: - Snapshot is saved when streaming_generate is called with enable_speculative_snapshot=True - This method uses the most recent snapshot for restoration - Snapshot is cleared after restore, cannot be called repeatedly Returns: bool: Whether restoration was successful """ snapshot = getattr(self, "_speculative_snapshot", None) if snapshot is None: logger.warning("[Speculative] No snapshot to restore") return False try: # 记录恢复前的状态(用于日志对比) current_cache_length = self._get_kv_cache_length() current_history_length = len(self._omni_chunk_history) logger.info( "[Speculative] Restoring snapshot: target=%s", snapshot.summary(), ) logger.info( "[Speculative] Current state before restore: llm_cache=%d, history_len=%d, " "audio_chunk_idx=%d, new_user_msg=%s, llm_generated=%s", current_cache_length, current_history_length, self.audio_chunk_idx, self.new_user_msg, self.llm_generated, ) # 1. 裁剪 LLM KV Cache if current_cache_length > snapshot.llm_cache_length: self._truncate_llm_cache(snapshot.llm_cache_length) logger.debug( "[Speculative] Truncated LLM cache: %d -> %d", current_cache_length, snapshot.llm_cache_length, ) # 2. 恢复 Audio KV Cache(关键:从克隆的副本恢复) # 因为 streaming_generate 会将 audio_past_key_values 设为 None self.audio_past_key_values = snapshot.audio_past_key_values if snapshot.audio_past_key_values is not None: logger.debug( "[Speculative] Restored audio cache: length=%d, checksum=%.6f", snapshot.audio_cache_length, snapshot.audio_cache_checksum or 0.0, ) else: logger.debug("[Speculative] Audio cache restored to None") # 3. 恢复会话状态 self.new_user_msg = snapshot.new_user_msg self.llm_generated = snapshot.llm_generated self.llm_generate_completed = snapshot.llm_generate_completed # 4. 恢复 Round 管理 self._next_round_id = snapshot.next_round_id self._pending_round_id = snapshot.pending_round_id # 5. 截断 chunk 历史 if current_history_length > snapshot.omni_chunk_history_length: self._omni_chunk_history = self._omni_chunk_history[: snapshot.omni_chunk_history_length] logger.debug( "[Speculative] Truncated chunk history: %d -> %d", current_history_length, snapshot.omni_chunk_history_length, ) # 6. 恢复 TTS 状态 self.tts_last_turn_tokens = snapshot.tts_last_turn_tokens # 7. 恢复 streaming 处理器状态 self.audio_chunk_idx = snapshot.audio_chunk_idx # 8. 恢复 mel processor 状态(关键!否则后续 prefill 会因帧数不匹配而失败) if ( snapshot.mel_processor_snapshot is not None and hasattr(self, "processor") and self.processor is not None ): self.processor.restore_streaming_snapshot(snapshot.mel_processor_snapshot) mel_snap = snapshot.mel_processor_snapshot logger.info( "[Speculative] Restored mel processor: buffer_len=%d, chunk_count=%d, " "last_emitted_T=%d, total_samples=%d", len(mel_snap.get("buffer", [])), mel_snap.get("chunk_count", 0), mel_snap.get("last_emitted_T", 0), mel_snap.get("total_samples_processed", 0), ) # 9. 恢复 RNG 状态(关键:确保 dithering 等随机操作的确定性) if snapshot.rng_state_cpu is not None: torch.set_rng_state(snapshot.rng_state_cpu) logger.debug("[Speculative] Restored CPU RNG state") if snapshot.rng_state_cuda is not None and torch.cuda.is_available(): torch.cuda.set_rng_state(snapshot.rng_state_cuda, self.device) logger.debug("[Speculative] Restored CUDA RNG state") # 11. 清理生成过程中产生的临时状态 if hasattr(self, "_streaming_generated_token_ids"): del self._streaming_generated_token_ids if hasattr(self, "_last_streaming_text"): del self._last_streaming_text # 12. 验证恢复后的状态 restored_cache_length = self._get_kv_cache_length() if restored_cache_length != snapshot.llm_cache_length: logger.warning( "[Speculative] LLM cache length mismatch after restore: expected=%d, actual=%d", snapshot.llm_cache_length, restored_cache_length, ) # 验证 LLM cache checksum(如果有) if snapshot.llm_cache_checksum is not None and self.llm_past_key_values is not None: if hasattr(self.llm_past_key_values, "key_cache") and len(self.llm_past_key_values.key_cache) > 0: current_checksum = self.llm_past_key_values.key_cache[0].sum().item() if abs(current_checksum - snapshot.llm_cache_checksum) > 1e-3: logger.warning( "[Speculative] LLM cache checksum mismatch: expected=%.6f, actual=%.6f", snapshot.llm_cache_checksum, current_checksum, ) else: logger.debug( "[Speculative] LLM cache checksum verified: %.6f", current_checksum, ) # 11. 清除快照(只能恢复一次) self._speculative_snapshot = None logger.info( "[Speculative] Restore completed: llm_cache %d -> %d, elapsed=%.3fs", current_cache_length, snapshot.llm_cache_length, time.time() - snapshot.timestamp, ) return True except Exception as e: import traceback logger.error("[Speculative] Failed to restore snapshot: %s", e) logger.error("[Speculative] Traceback: %s", traceback.format_exc()) return False def has_speculative_snapshot(self) -> bool: return getattr(self, "_speculative_snapshot", None) is not None def clear_speculative_snapshot(self) -> None: if hasattr(self, "_speculative_snapshot"): self._speculative_snapshot = None def _truncate_llm_cache(self, target_length: int) -> None: if self.llm_past_key_values is None: return cache = self._ensure_dynamic_cache() if cache is None: return current_length = self._get_kv_cache_length(cache) if current_length <= target_length: return # 裁剪每一层的 cache for layer_idx in range(len(cache.key_cache)): if cache.key_cache[layer_idx].numel() > 0: cache.key_cache[layer_idx] = cache.key_cache[layer_idx][:, :, :target_length, :].contiguous() cache.value_cache[layer_idx] = cache.value_cache[layer_idx][:, :, :target_length, :].contiguous() # 更新 cache 元数据 cache.crop(target_length) cache._seen_tokens = target_length logger.debug("[Speculative] Truncated LLM cache: %d -> %d", current_length, target_length) # ============== 抢跑快照/恢复接口 结束 ============== @torch.inference_mode() def streaming_prefill( self, session_id, msgs, tokenizer=None, # deprecated omni_mode=True, max_slice_nums=None, use_tts_template=True, enable_thinking=False, is_last_chunk=False, # for audio chunk, if is the last chunk, set to True **kwargs, ): assert session_id is not None, "session_id cannot be None" self.is_first = self.session_id is None or session_id != self.session_id if not hasattr(self, "processor") or self.processor is None: self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) images = [] audios = [] assert len(msgs) == 1 copy_msgs = deepcopy(msgs) msg = copy_msgs[0] assert msg["role"] in ["system", "user", "assistant"] is_not_system_prefill = msg["role"] != "system" content = msg["content"] cur_msgs = [] for j, c in enumerate(content): if isinstance(c, Image.Image): images.append(c) cur_msgs.append("./") elif isinstance(c, np.ndarray): audios.append(c) cur_msgs.append("") elif isinstance(c, str): cur_msgs.append(c) else: logger.error(f"Invalid content type: {c}, ignore it.") cur_contents = "".join(cur_msgs) if omni_mode else "\n".join(cur_msgs) if msg["role"] in ["system", "assistant"]: self.new_user_msg = True self.audio_past_key_values = None if self.is_first: self.reset_session(reset_token2wav_cache=False) self.session_id = session_id self.init_streaming_processor() if msg["role"] == "user": # 没有 system prefill,第一个 user turn 的第一个 segment # 不使用 apply_chat_template,手动构建 prompt 以避免自动添加 <|im_end|> prompt = "<|im_start|>user\n" + cur_contents self.new_user_msg = False # 标记后续 segments 不需要再添加 user 前缀 else: # system 或 assistant prefill,使用 apply_chat_template msg["content"] = cur_contents prompt = self.processor.tokenizer.apply_chat_template( copy_msgs, tokenize=False, add_generation_prompt=False, use_tts_template=use_tts_template, enable_thinking=enable_thinking, ) add_special_tokens = True # add bos else: # 非首次 prefill if self.new_user_msg and msg["role"] == "user": # 新的 user turn 的第一个 segment if self.llm_generated: # todo: when to set llm_generate_completed? if self.llm_generate_completed: prompt = "<|im_end|>\n<|im_start|>user\n" + cur_contents else: prompt = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents else: prompt = "<|im_start|>user\n" + cur_contents self.new_user_msg = False else: # 同一个 turn 的后续 segments,直接使用内容 prompt = cur_contents add_special_tokens = False # when first user audio prefill, ensure audio length satisfies FIRST_CHUNK_MS requirements if is_not_system_prefill and len(audios) > 0 and self.audio_chunk_idx == 0: assert len(audios) == 1, f"streaming mode only supports single audio, currently {len(audios)}" first_chunk_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000) if len(audios[0]) < first_chunk_samples: pad_len = first_chunk_samples - len(audios[0]) audios[0] = np.concatenate([np.zeros(pad_len, dtype=audios[0].dtype), audios[0]]) model_inputs = self.processor( [prompt], [images], [audios], max_slice_nums=1 if max_slice_nums is None else max_slice_nums, use_image_id=False, chunk_input=True, return_tensors="pt", max_length=None, sampling_rate=16000, add_special_tokens=add_special_tokens, online_streaming=is_not_system_prefill, audio_chunk_idx=self.audio_chunk_idx, is_last_chunk=is_last_chunk, ).to(self.device) # DEBUG: 打印 mel 特征的 checksum(用于诊断 rollback 不一致问题) if len(audios) > 0 and is_not_system_prefill and hasattr(self, "_debug_prefill") and self._debug_prefill: audio_feats = model_inputs.get("audio_features", None) if audio_feats is not None and hasattr(audio_feats, "sum"): mel_sum = audio_feats.sum().item() mel_shape = audio_feats.shape print( f"[DEBUG prefill] audio_chunk_idx={self.audio_chunk_idx}, mel_sum={mel_sum:.6f}, mel_shape={mel_shape}" ) else: print(f"[DEBUG prefill] audio_chunk_idx={self.audio_chunk_idx}, audio_feats type={type(audio_feats)}") if len(audios) > 0 and is_not_system_prefill: self.audio_chunk_idx += 1 # 1. prepare input embeddings model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs) # get audio embedding with audio_past_key_values # todo: should pass chunk_length=self.config.audio_chunk_length ? inputs_embeds = self.get_omni_embedding( model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=is_not_system_prefill ) # DEBUG: 打印 inputs_embeds 的 checksum if len(audios) > 0 and is_not_system_prefill and hasattr(self, "_debug_prefill") and self._debug_prefill: embed_sum = inputs_embeds.sum().item() embed_shape = inputs_embeds.shape print(f"[DEBUG prefill] inputs_embeds sum={embed_sum:.6f}, shape={embed_shape}") if self.is_first: self.audio_past_key_values = None # clean audio_past_key_values after first prefill round_id = self._next_round_id self._pending_round_id = round_id chunk_type = "system" if msg["role"] == "system" else ("user" if msg["role"] == "user" else "assistant") seq_len = inputs_embeds.shape[1] self._enforce_text_window() cache_length = self._get_kv_cache_length() attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device) # 2. do prefill outputs = self.llm( past_key_values=self.llm_past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=None, use_cache=True, return_dict=True, ) self.llm_past_key_values = as_dynamic_cache(outputs["past_key_values"]) self._register_chunk( seq_len, chunk_type, round_id=round_id, input_ids=model_inputs["input_ids"], tokenizer=self.processor.tokenizer, ) self._enforce_text_window() if self.force_rope_reindex: self._force_reindex_all_cache() return prompt @torch.inference_mode() def streaming_generate( self, session_id, tokenizer=None, # deprecated bos_input=None, generate_audio=True, audio_token_chunk_size=25, # 25 token/s tts_sampling_params: TTSSamplingParams = TTSSamplingParams(), max_new_tokens=256, fn="chunk_generate", enable_thinking=False, use_tts_template=True, do_sample=True, enable_speculative_snapshot=False, **kwargs, ): # 保存抢跑快照(在修改任何状态之前) # 用于 VAD 抢跑场景:如果抢跑失败,可调用 restore_speculative_snapshot() 恢复 # enable_speculative_snapshot=True 时启用,False 时跳过(节省少量开销) if enable_speculative_snapshot: self._speculative_snapshot = self._save_speculative_snapshot() # reset buf self.new_user_msg = True self.llm_generated = True self.llm_generate_completed = False self.audio_past_key_values = None if not hasattr(self, "processor") or self.processor is None: self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) # reset current turn generated token IDs if hasattr(self, "_streaming_generated_token_ids"): del self._streaming_generated_token_ids # reset full generated text if hasattr(self, "_last_streaming_text"): del self._last_streaming_text cache = self._ensure_dynamic_cache() cache_length = self._get_kv_cache_length(cache) host_round_id = self._pending_round_id logger.info("streaming_generate kv cache length before= %s", cache_length) ## 单工情况每调用一次 streaming_generate 需要重新初始化 streaming_processor, 进入下一个 turn self.init_streaming_processor() # 1) llm generate token and hidden states per chunk=10, 2) tts generate audio token chunk per chunk=25, 3) yield 1 chunk audio token def audio_chunk_generator( bos_input, tokenizer, generate_audio, tts_sampling_params, max_new_tokens, do_sample, **kwargs, ): generate_chunk_size = 10 if bos_input is None: bos_input = "".join( [ "<|im_end|>\n<|im_start|>assistant\n", "" if enable_thinking else self.think_str.replace("\\n", "\n"), "<|tts_bos|>" if use_tts_template else "", ] ) bos_input_ids = tokenizer.encode(bos_input) bos_input_ids = torch.tensor(bos_input_ids, dtype=torch.long, device=self.device).unsqueeze(0) # DEBUG: 打印生成开始时的状态 _cache_len = self._get_kv_cache_length() _cache_sum = self.llm_past_key_values.key_cache[0].sum().item() if self.llm_past_key_values else 0 # 检查 KV Cache 最后几个位置的值 _k_last = ( self.llm_past_key_values.key_cache[0][0, 0, -5:, :3].flatten().tolist() if self.llm_past_key_values else [] ) print(f"[DEBUG streaming_generate] cache_len={_cache_len}, cache_sum={_cache_sum:.6f}, k_last={_k_last}") bos_input_embeds = self.llm.get_input_embeddings()(bos_input_ids) generation_inputs_embeds = bos_input_embeds generated_ids = torch.empty((1, 0), dtype=torch.long, device=self.device) num_chunks_decode = (max_new_tokens + generate_chunk_size - 1) // generate_chunk_size conditions = [] # generate chunk by chunk, each chunk has 10 tokens, each chunk takes last hidden states, and pass tokens to tts llm_streaming_generator = ChunkPrefillChunkGenerate( model=self.llm, tokenizer=tokenizer, terminators=["<|tts_eos|>", "<|im_end|>", ""], ) if generate_audio: logits_warpers, logits_processors = gen_logits( num_code=self.tts.config.num_audio_tokens, repetition_penalty=tts_sampling_params.repetition_penalty, top_p=tts_sampling_params.top_p, top_k=tts_sampling_params.top_k, ) tts_streaming_generator = TTSStreamingGenerator( model=self.tts, temperature=tts_sampling_params.temperature, eos_token=torch.tensor( [self.tts.config.num_audio_tokens - 1], dtype=torch.long, device=self.tts.device, ), chunk_size=audio_token_chunk_size, # s3tokenizer 1s = 25token tts_last_turn_tokens=self.tts_last_turn_tokens, logits_processors=logits_processors, logits_warpers=logits_warpers, ) # LLM chunk generate outer loop for chunk_idx in range(num_chunks_decode): is_first_generate_chunk = chunk_idx == 0 if fn == "chunk_generate": output = llm_streaming_generator.chunk_generate( inputs_embeds=generation_inputs_embeds, past_key_values=self.llm_past_key_values, is_first_generate_chunk=is_first_generate_chunk, return_hidden_states=True, chunk_size=generate_chunk_size + 1 * is_first_generate_chunk, do_sample=do_sample, temperature=kwargs.get("temperature", 0.7), top_p=kwargs.get("top_p", 0.8), top_k=kwargs.get("top_k", 100), repetition_penalty=kwargs.get("repetition_penalty", 1.02), all_input_ids=generated_ids, ) else: output = llm_streaming_generator.chunk_generate_hf( inputs_embeds=generation_inputs_embeds, past_key_values=self.llm_past_key_values, is_first_generate_chunk=is_first_generate_chunk, return_hidden_states=True, chunk_size=generate_chunk_size + 1 * is_first_generate_chunk, do_sample=do_sample, **kwargs, ) if output.chunk_token_ids is None: break # DEBUG: 打印第一个 chunk 生成的 token if chunk_idx == 0: print(f"[DEBUG streaming_generate] first_chunk_tokens={output.chunk_token_ids.tolist()}") if is_first_generate_chunk: if generate_audio: spk_emb = torch.empty( (bos_input_embeds.shape[0], 0, bos_input_embeds.shape[2]), dtype=bos_input_embeds.dtype, device=bos_input_embeds.device, ) tts_streaming_generator.spk_emb = spk_emb if output.finished: yield_chunk_token_ids = output.chunk_token_ids else: # the first chunk generated chunk_size + 1 tokens, we only take the first chunk_size tokens, # the last token is not prefilled, and last hidden states is not obtained yield_chunk_token_ids = output.chunk_token_ids[:, :-1] elif output.finished: yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids], dim=1) else: # in the chunk that is not the first chunk, we need to add the token at the end of the previous chunk, # it is not prefilled into the model to get last hidden states # similarly, the last generated token of subsequent chunks is not prefilled, and last hidden states is not obtained, # so it is not passed out yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids[:, :-1]], dim=1) if not generate_audio: chunk_generated_text = tokenizer.decode(yield_chunk_token_ids[0]) yield yield_chunk_token_ids, output.finished else: # TTS inner loop # dense connection here is hardcoded to use text-hidden merged as condition llm_embeds = self.tts.emb_text(yield_chunk_token_ids) hidden_embeds = output.last_hidden_states hidden_embeds = self.tts.projector_semantic(hidden_embeds) if self.tts.config.normalize_projected_hidden: # default should be opened hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1) tts_embeds = llm_embeds + hidden_embeds conditions.append(tts_embeds) # Store token IDs instead of decoded text to avoid UTF-8 multi-byte character truncation if not hasattr(self, "_streaming_generated_token_ids"): self._streaming_generated_token_ids = [] self._streaming_generated_token_ids.extend(yield_chunk_token_ids[0].tolist()) # there is buffer generated, each time exactly returns 25 audio tokens, # the last audio chunk returns audio tokens of variable length, length [0, 25] tts_generator = tts_streaming_generator.generate_with_buffer( condition=tts_embeds, text_finished=output.finished ) for audio_token_chunk, is_last_audio_chunk in tts_generator: yield audio_token_chunk, is_last_audio_chunk generated_ids = torch.cat([generated_ids, output.chunk_token_ids], dim=1) generation_inputs_embeds = output.current_inputs_embeds self.llm_past_key_values = output.past_key_values if output.finished: if generate_audio: self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens break # IMPORTANT: Flush remaining TTS buffer when LLM generation ends # This handles BOTH cases: # 1. LLM finished with terminator (output.finished=True) - buffer may still have tokens # 2. LLM hit max chunks limit (output.finished=False) - buffer definitely has tokens if generate_audio: if len(tts_streaming_generator._token_buffer) > 0: batch = torch.cat(tts_streaming_generator._token_buffer, dim=1) yield batch, True tts_streaming_generator._token_buffer = [] if generate_audio: if hasattr(self, "_streaming_generated_token_ids"): try: self._last_streaming_text = tokenizer.decode(self._streaming_generated_token_ids) assistant_input_ids = self._encode_text(tokenizer=tokenizer, text=self._last_streaming_text) self._finalize_round( round_id=host_round_id, cache_before=cache_length, assistant_input_ids=assistant_input_ids ) except Exception: self._last_streaming_text = None else: self._last_streaming_text = None yield None, None else: return # iter for generating text chunk and audio chunk audio_chunk_generator_iter = audio_chunk_generator( bos_input=bos_input, tokenizer=self.processor.tokenizer, generate_audio=generate_audio, tts_sampling_params=tts_sampling_params, max_new_tokens=max_new_tokens, do_sample=do_sample, **kwargs, ) if generate_audio: if self.tts.config.audio_tokenizer_type == "s3tokenizer_step_audio": self.tts.audio_tokenizer.stream_cache = torch_clone_recursive(self.token2wav_cache["flow_cache_base"]) self.tts.audio_tokenizer.hift_cache_dict = torch_clone_recursive( self.token2wav_cache["hift_cache_base"] ) # pre-insert 3-5 prefix 4218 silence tokens, each token corresponds to 0.04s, # adding 5 tokens means introducing 0.2s of silence buffer = [4218] * 3 pre_lookahead = 3 CHUNK_SIZE = 25 chunk_idx = 0 prev_text_len = 0 # track text position for streaming text output for audio_token_chunk, is_last_audio_chunk in audio_chunk_generator_iter: if audio_token_chunk is None: break buffer += audio_token_chunk.reshape(-1).tolist() if len(buffer) >= CHUNK_SIZE + pre_lookahead: waveform_chunk = self.tts.audio_tokenizer.stream( buffer[: CHUNK_SIZE + pre_lookahead], prompt_wav=None, last_chunk=is_last_audio_chunk, return_waveform=True, ) waveform_chunk = torch.from_numpy(waveform_chunk) # get new text chunk corresponding to this waveform # Decode from accumulated token IDs to avoid UTF-8 multi-byte truncation new_text = "" if hasattr(self, "_streaming_generated_token_ids"): current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids) # Filter out trailing replacement characters (incomplete UTF-8 sequences) safe_end = len(current_text) while safe_end > 0 and current_text[safe_end - 1] == "\ufffd": safe_end -= 1 safe_text = current_text[:safe_end] new_text = safe_text[prev_text_len:] prev_text_len = len(safe_text) yield waveform_chunk, new_text buffer = buffer[CHUNK_SIZE:] chunk_idx += 1 # flush rest if len(buffer) > 0: waveform_chunk = self.tts.audio_tokenizer.stream( buffer, prompt_wav=None, last_chunk=True, return_waveform=True, ) waveform_chunk = torch.from_numpy(waveform_chunk) # get remaining new text for the final chunk # Final chunk: decode all remaining text without filtering new_text = "" if hasattr(self, "_streaming_generated_token_ids"): current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids) new_text = current_text[prev_text_len:] prev_text_len = len(current_text) yield waveform_chunk, new_text # maybe the buffer is empty, and text is not empty, should we flush text without wave? else: raise NotImplementedError(f"not supported audio tokenizer: {self.tts.config.audio_tokenizer_type}") else: # For text-only generation, decode tokens and handle partial multi-byte characters yield from streaming_token_decoder( audio_chunk_generator_iter, self.processor.tokenizer, skip_special_tokens=False, ) class MiniCPMODuplex: def __init__( self, name_or_path: str, duplex_config: Optional[MiniCPMODuplexConfig] = None, generate_audio: bool = True, ls_mode: str = "explicit", device: str = "cuda", pt_path: Optional[str] = None, **kwargs, ): """Initialize MiniCPMODuplex. Args: name_or_path: Path to the pretrained model or model identifier. duplex_config: Optional MiniCPMODuplexConfig. If provided, overrides other params. generate_audio: Whether to generate audio output. ls_mode: Listen/Speak mode, e.g., "explicit". device: Device to load the model on. pt_path: Optional path to additional checkpoint weights. **kwargs: Additional generation config parameters (overrides duplex_config if provided). """ self.session_logs = [] self.session_start_time = None self.log_file_path = None self.name_or_path = name_or_path if duplex_config is not None: self.duplex_config = duplex_config self.generate_audio = kwargs.get("generate_audio", duplex_config.generate_audio) self.ls_mode = kwargs.get("ls_mode", duplex_config.ls_mode) attn_implementation = kwargs.get("attn_implementation", duplex_config.attn_implementation) else: self.duplex_config = None self.generate_audio = generate_audio self.ls_mode = ls_mode attn_implementation = kwargs.get("attn_implementation", "flash_attention_2") self.device = device from transformers import AutoConfig from transformers import AutoTokenizer from .processing_minicpmo import MiniCPMOProcessor from .stream_decoder import StreamDecoder self.processor = MiniCPMOProcessor.from_pretrained(name_or_path, trust_remote_code=True) self.tokenizer = AutoTokenizer.from_pretrained(name_or_path, trust_remote_code=True) self.processor.tokenizer = self.tokenizer config = AutoConfig.from_pretrained(name_or_path, trust_remote_code=True) vision_batch_size = kwargs.pop("vision_batch_size", None) audio_pool_step = kwargs.pop("audio_pool_step", None) audio_chunk_length = kwargs.pop("audio_chunk_length", None) max_slice_nums = kwargs.pop("max_slice_nums", None) if vision_batch_size is not None and hasattr(config, "vision_batch_size"): config.vision_batch_size = vision_batch_size if audio_pool_step is not None and hasattr(config, "audio_pool_step"): config.audio_pool_step = audio_pool_step if audio_chunk_length is not None and hasattr(config, "audio_chunk_length"): config.audio_chunk_length = audio_chunk_length if max_slice_nums is not None and hasattr(config.slice_config, "max_slice_nums"): config.slice_config.max_slice_nums = max_slice_nums self.model = MiniCPMO.from_pretrained( name_or_path, config=config, trust_remote_code=True, attn_implementation=attn_implementation ) self.model.to(torch.bfloat16) self.model.processor = self.processor if pt_path is not None: logger.info(f"Loading checkpoint from {pt_path}") state_dict = torch.load(pt_path, map_location="cpu") info = self.model.load_state_dict(state_dict, strict=False) logger.warning(info) del state_dict self.model.eval().to(device=device) self.model.init_tts( streaming=True, enable_float16=kwargs.get("enable_float16", False), n_timesteps=kwargs.get("n_timesteps", 10), ) self.break_event = threading.Event() self.session_stop_event = threading.Event() # llm generation_config - from duplex_config or defaults self.max_new_speak_tokens_per_chunk = self._get_value( "max_new_speak_tokens_per_chunk", duplex_config, kwargs, 20 ) self.text_repetition_penalty = self._get_value("text_repetition_penalty", duplex_config, kwargs, 1.05) self.temperature = self._get_value("temperature", duplex_config, kwargs, 0.7) self.top_k = self._get_value("top_k", duplex_config, kwargs, 20) self.top_p = self._get_value("top_p", duplex_config, kwargs, 0.8) self.text_repetition_window_size = self._get_value("text_repetition_window_size", duplex_config, kwargs, 512) self.listen_prob_scale = self._get_value("listen_prob_scale", duplex_config, kwargs, 1.0) self.force_listen_count = self._get_value("force_listen_count", duplex_config, kwargs, 0) # tts generation_config tts_temp_value = self._get_value("tts_temperature", duplex_config, kwargs, 0.8) self.tts_temperature = torch.tensor([tts_temp_value], dtype=torch.float, device=self.device) self.tts_repetition_penalty = self._get_value("tts_repetition_penalty", duplex_config, kwargs, 1.05) # stream config self.CHUNK_MS = self._get_value("chunk_ms", duplex_config, kwargs, 1000) self.FIRST_CHUNK_MS = self._get_value("first_chunk_ms", duplex_config, kwargs, 1035) self.CNN_REDUNDANCY_MS = self._get_value("cnn_redundancy_ms", duplex_config, kwargs, 20) self.SAMPLE_RATE = self._get_value("sample_rate", duplex_config, kwargs, 16000) self.model.CHUNK_MS = self.CHUNK_MS self.model.FIRST_CHUNK_MS = self.FIRST_CHUNK_MS self.model.CNN_REDUNDANCY_MS = self.CNN_REDUNDANCY_MS self.model.SAMPLE_RATE = self.SAMPLE_RATE # special tokens self.unit_token_id = self.tokenizer.convert_tokens_to_ids("") self.image_start_token_id = self.tokenizer.convert_tokens_to_ids("") self.image_end_token_id = self.tokenizer.convert_tokens_to_ids("") self.slice_start_token_id = self.tokenizer.convert_tokens_to_ids("") self.slice_end_token_id = self.tokenizer.convert_tokens_to_ids("") self.listen_token_id = self.tokenizer.convert_tokens_to_ids("<|listen|>") self.speak_token_id = self.tokenizer.convert_tokens_to_ids("<|speak|>") self.tts_bos_token_id = self.tokenizer.convert_tokens_to_ids("<|tts_bos|>") self.tts_eos_token_id = self.tokenizer.convert_tokens_to_ids("<|tts_eos|>") self.chunk_eos_token_id = self.tokenizer.convert_tokens_to_ids("<|chunk_eos|>") self.chunk_tts_eos_token_id = self.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>") self.turn_eos_token_id = self.tokenizer.convert_tokens_to_ids("<|turn_eos|>") self.chunk_terminator_token_ids = [self.listen_token_id, self.chunk_eos_token_id, self.chunk_tts_eos_token_id] self.turn_terminator_token_ids = [self.turn_eos_token_id] self.chunk_speak_token_ids = [self.speak_token_id] self.tts_pad_id = self.tokenizer.convert_tokens_to_ids("<|tts_pad|>") bad_token_ids = getattr(self.tokenizer, "bad_token_ids", []) self.forbidden_token_ids = [self.tts_pad_id] + list(bad_token_ids) self.decoder = StreamDecoder( llm=self.model.llm, tokenizer=self.tokenizer, forbidden_token_ids=self.forbidden_token_ids ) # 配置滑窗参数 from .stream_decoder import DuplexWindowConfig # 滑窗模式: "off" / "basic" / "context" sliding_window_mode = self._get_value("sliding_window_mode", duplex_config, kwargs, "off") # 不带 Context 的滑窗参数 basic_window_high_tokens = self._get_value("basic_window_high_tokens", duplex_config, kwargs, 4000) basic_window_low_tokens = self._get_value("basic_window_low_tokens", duplex_config, kwargs, 3500) # 带 Context 的滑窗参数 context_previous_max_tokens = self._get_value("context_previous_max_tokens", duplex_config, kwargs, 500) context_max_units = self._get_value("context_max_units", duplex_config, kwargs, 24) self.decoder.set_window_config( DuplexWindowConfig( sliding_window_mode=sliding_window_mode, basic_window_high_tokens=basic_window_high_tokens, basic_window_low_tokens=basic_window_low_tokens, context_previous_max_tokens=context_previous_max_tokens, context_max_units=context_max_units, ) ) # 根据 mode 设置滑窗开关 window_enabled = sliding_window_mode != "off" self.decoder.set_window_enabled(window_enabled) logger.info( "[Duplex] Sliding window: mode=%s, high=%d, low=%d, prev_max=%d, max_units=%d", sliding_window_mode, basic_window_high_tokens, basic_window_low_tokens, context_previous_max_tokens, context_max_units, ) self.tts_logits_processors = None self.tts_eos_token = None if self.generate_audio: self.tts_logits_processors = gen_logits( num_code=self.model.tts.config.num_audio_tokens, repetition_penalty=self.tts_repetition_penalty, ) self.tts_eos_token = torch.tensor( [self.model.tts.config.num_audio_tokens - 1], dtype=torch.long, device=self.device, ) # for offline_generate self.step_map = { "sp_text": ["", "", "", "", "<|listen|>"], "image": [""], "audio": ["<|audio|>"], } self.loop = [] self._reset_streaming_state() import gc gc.collect() torch.cuda.empty_cache() @staticmethod def _get_value(key, config, kwargs, default): return kwargs.get(key, getattr(config, key, default)) @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str, device: str = "cuda", pt_path: Optional[str] = None, **kwargs, ) -> "MiniCPMODuplex": """Load MiniCPMODuplex from a pretrained model directory. This method loads the duplex_config.json if available, and uses it to configure the duplex model. Any kwargs provided will override the config values. Args: pretrained_model_name_or_path: Path to the model directory. device: Device to load the model on. pt_path: Optional path to additional checkpoint weights. **kwargs: Override parameters for generation config. Returns: MiniCPMODuplex instance. """ try: duplex_config = MiniCPMODuplexConfig.from_pretrained(pretrained_model_name_or_path) except FileNotFoundError: duplex_config = None return cls( name_or_path=pretrained_model_name_or_path, duplex_config=duplex_config, device=device, pt_path=pt_path, **kwargs, ) def save_duplex_config(self, save_directory: str): """Save the current duplex configuration to a directory. Args: save_directory: Path to save the duplex_config.json. """ config = MiniCPMODuplexConfig( generate_audio=self.generate_audio, ls_mode=self.ls_mode, max_new_speak_tokens_per_chunk=self.max_new_speak_tokens_per_chunk, text_repetition_penalty=self.text_repetition_penalty, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, text_repetition_window_size=self.text_repetition_window_size, listen_prob_scale=self.listen_prob_scale, tts_temperature=self.tts_temperature.item(), tts_repetition_penalty=self.tts_repetition_penalty, chunk_ms=self.CHUNK_MS, first_chunk_ms=self.FIRST_CHUNK_MS, cnn_redundancy_ms=self.CNN_REDUNDANCY_MS, sample_rate=self.SAMPLE_RATE, ) config.save_pretrained(save_directory) def set_break_event(self): self.break_event.set() def clear_break_event(self): self.break_event.clear() def set_session_stop(self): self.session_stop_event.set() self.break_event.set() def clear_session_stop(self): self.session_stop_event.clear() def is_break_set(self) -> bool: return self.break_event.is_set() def is_session_stop_set(self) -> bool: return self.session_stop_event.is_set() def _init_token2wav_cache(self, prompt_wav_path: str): self.model.tts.audio_tokenizer.cache = None flow_cache, hift_cache = self.model.tts.audio_tokenizer.set_stream_cache(prompt_wav_path) self.flow_cache_base = torch_clone_recursive(flow_cache) self.hift_cache_base = torch_clone_recursive(hift_cache) self.pre_lookahead = int(self.model.tts.audio_tokenizer.flow.pre_lookahead_len) self.token2wav_initialized = True def _reset_token2wav_for_new_turn(self): if self.token2wav_initialized: self.model.tts.audio_tokenizer.stream_cache = torch_clone_recursive(self.flow_cache_base) self.model.tts.audio_tokenizer.hift_cache_dict = torch_clone_recursive(self.hift_cache_base) self.token2wav_buffer = [4218] * 3 # silence token prefix def _reset_streaming_state(self): self.audio_chunk_idx = 0 self.current_turn_ended = True self.speak_count = 0 self.res_ids = [] self.total_ids = [] self.total_hidden = [] # TTS state self.tts_text_start_pos = 0 self.tts_past_key_values = None self.tts_current_turn_start_time = None # token2wav state self.token2wav_initialized = False self.token2wav_buffer = [] self.flow_cache_base = None self.hift_cache_base = None # Audio prefill state self.audio_buffer = np.array([], dtype=np.float32) self.pending_logits: Optional[torch.Tensor] = None self.current_mode: Optional[str] = None # Force listen state self._streaming_generate_count = 0 def prepare( self, prefix_system_prompt: Optional[str] = None, suffix_system_prompt: Optional[str] = None, ref_audio: Optional[np.ndarray] = None, prompt_wav_path: Optional[str] = None, mode: Literal["omni", "video", "audio"] = "omni", context_previous_marker: str = "\n\nprevious: ", ): self.clear_break_event() self.clear_session_stop() self.session_start_time = time.time() self._reset_streaming_state() self.decoder.reset() self.model.init_streaming_processor() if prompt_wav_path is not None and prompt_wav_path and self.generate_audio: self._init_token2wav_cache(prompt_wav_path) self._reset_token2wav_for_new_turn() self._update_loop(mode) # Prefill system prompt prefix if prefix_system_prompt: tokens = self.tokenizer.encode(prefix_system_prompt, add_special_tokens=False) for token_id in tokens: self.decoder.feed(self.decoder.embed_token(token_id)) # Prefill reference audio if ref_audio is not None: data = self.processor.process_audio([ref_audio]) embeds_nested = self.model.get_audio_embedding(data, chunk_length=self.model.config.audio_chunk_length) embeds = torch.cat([t for g in embeds_nested for t in g], dim=0) if embeds_nested else None if embeds is not None: self.decoder.feed(embeds) # 注册 system prompt 保护长度(滑窗时保护这部分不被移除) if prefix_system_prompt or suffix_system_prompt or ref_audio is not None: logger.info("[Duplex] prepare: registering system prompt protection") if self.decoder._window_config.sliding_window_mode == "context": # Context 保留模式: # 初始化时布局: [prefix] [suffix] [units...] # 首次滑窗后布局: [prefix] [context_previous_marker + content] [suffix] [units...] # 此时先注册 prefix 长度,再 feed suffix self._prefix_system_prompt = prefix_system_prompt self._suffix_system_prompt = suffix_system_prompt self._ref_audio = ref_audio # 获取 suffix token ids suffix_token_ids = [] if suffix_system_prompt: suffix_token_ids = self.tokenizer.encode(suffix_system_prompt, add_special_tokens=False) # 注册(此时 cache 只有 prefix,还没有 suffix,也没有 previous) self.decoder.register_system_prompt_with_context( suffix_token_ids=suffix_token_ids, context_previous_marker=context_previous_marker, # 首次滑窗时动态添加 ) # 现在 feed suffix for token_id in suffix_token_ids: self.decoder.feed(self.decoder.embed_token(token_id)) logger.info( "[Duplex] prepare: context-preserve mode, prefix=%d, suffix=%d tokens, marker='%s'", self.decoder._preserve_prefix_length, len(suffix_token_ids), context_previous_marker.replace("\n", "\\n"), ) else: # 非 context 保留模式:先 feed suffix,再注册总长度 if suffix_system_prompt: tokens = self.tokenizer.encode(suffix_system_prompt, add_special_tokens=False) for token_id in tokens: self.decoder.feed(self.decoder.embed_token(token_id)) self.decoder.register_system_prompt() if prefix_system_prompt or suffix_system_prompt: if ref_audio is not None: full_prompt = (prefix_system_prompt or "") + "[音频嵌入]" + (suffix_system_prompt or "") else: full_prompt = (prefix_system_prompt or "") + (suffix_system_prompt or "") return full_prompt return "" @torch.no_grad() def streaming_prefill( self, audio_waveform: Optional[np.ndarray] = None, frame_list: Optional[list] = None, max_slice_nums: Union[int, List[int]] = 1, ): """Streaming prefill - called once per second, processing audio/video data Args: audio_waveform: audio waveform data frame_list: image frame list max_slice_nums: maximum number of slices for HD image encoding (default 1, no slicing) Can be an int (same for all images) or a list matching frame_list length Process: 0. determine mode based on input: AUDIO / VISION / OMNI 1. feed token 2. get and feed image embed (if frame_list) - return pending logits in VISION MODE 3. get and feed audio embed (if audio_waveform) - return pending logits in AUDIO/OMNI MODE Returns: dict with keys: - success: bool - cost_vision_process: float (image processing time) - cost_vision_embed: float (vision embedding time) - cost_vision_feed: float (vision feed time) - cost_audio_process: float (audio processing time) - cost_audio_embed: float (audio embedding time) - cost_audio_feed: float (audio feed time) - cost_all: float (total time) """ start_time = time.time() cost_vision_process = 0.0 cost_vision_embed = 0.0 cost_vision_feed = 0.0 cost_audio_process = 0.0 cost_audio_embed = 0.0 cost_audio_feed = 0.0 def _make_result(success, reasons=""): reason = reasons if isinstance(reasons, list): reason = "; ".join(reasons) return { "success": success, "reason": reason, "cost_vision_process": cost_vision_process, "cost_vision_embed": cost_vision_embed, "cost_vision_feed": cost_vision_feed, "cost_audio_process": cost_audio_process, "cost_audio_embed": cost_audio_embed, "cost_audio_feed": cost_audio_feed, "cost_all": time.time() - start_time, } if self.is_session_stop_set() or self.is_break_set(): return _make_result(False) has_frames = frame_list is not None and len(frame_list) > 0 has_audio = audio_waveform is not None and len(audio_waveform) > 0 if has_frames and has_audio: mode = "OMNI" elif has_frames: mode = "VISION" elif has_audio: mode = "AUDIO" else: return _make_result(False) self.pending_logits = None # 滑窗:记录 unit 开始位置 logger.info( "[Duplex] streaming_prefill: mode=%s, has_frames=%s, has_audio=%s, starting unit", mode, has_frames, has_audio, ) self.decoder.register_unit_start() # Step 1: Feed token self.decoder.feed(self.decoder.embed_token(self.unit_token_id)) # Step 2: process image if has_frames: t0 = time.time() # Normalize max_slice_nums to a list matching frame_list length if isinstance(max_slice_nums, int): max_slice_nums_list = [max_slice_nums] * len(frame_list) else: max_slice_nums_list = list(max_slice_nums) if len(max_slice_nums_list) != len(frame_list): raise ValueError( f"max_slice_nums list length ({len(max_slice_nums_list)}) " f"must match frame_list length ({len(frame_list)})" ) # Check if all max_slice_nums are the same (can use batch processing) all_same = len(set(max_slice_nums_list)) == 1 if all_same: # All images use the same max_slice_nums, use batch processing processed_frames = self.processor.process_image(frame_list, max_slice_nums=max_slice_nums_list[0]) if self.device: processed_frames = processed_frames.to(self.device) else: # Different max_slice_nums per image, process individually and merge all_pixel_values = [] all_tgt_sizes = [] for frame, max_slices in zip(frame_list, max_slice_nums_list): pf = self.processor.process_image([frame], max_slice_nums=max_slices) if self.device: pf = pf.to(self.device) # pf["pixel_values"][0] is the list of slices for this image all_pixel_values.extend(pf["pixel_values"][0]) # pf["tgt_sizes"][0] is the array of target sizes for this image's slices if hasattr(pf["tgt_sizes"][0], "tolist"): all_tgt_sizes.extend(pf["tgt_sizes"][0].tolist()) else: all_tgt_sizes.extend(list(pf["tgt_sizes"][0])) # Reconstruct processed_frames with merged data processed_frames = { "pixel_values": [all_pixel_values], "tgt_sizes": [torch.tensor(all_tgt_sizes) if all_tgt_sizes else []], } cost_vision_process = time.time() - t0 t0 = time.time() # Get vision embeddings for all images (each may have multiple slices) # vision_hidden_states is a list, one entry per input image # Each entry contains embeddings for [source_image, slice_1, slice_2, ...] vision_hidden_states = self.model.get_vision_embedding(processed_frames) cost_vision_embed = time.time() - t0 if vision_hidden_states is not None and len(vision_hidden_states) > 0: t0 = time.time() # vision_hidden_states[0] contains ALL slices from ALL images (flattened) # Shape: [total_slices, 64, D] where total_slices = sum of slices across all images # We need to know how many slices each image has to correctly group them # Calculate slice counts for each image using get_sliced_grid (lightweight, no actual slicing) slice_counts = [] # e.g., [5, 9] means img1 has 5 slices (1 source + 4 HD), img2 has 9 for frame_idx, frame in enumerate(frame_list): max_slices = max_slice_nums_list[frame_idx] if hasattr(frame, "size"): # get_sliced_grid returns [M, N] grid or None if no slicing needed # Total images = 1 (source) + M * N (HD slices) grid = self.processor.image_processor.get_sliced_grid( frame.size, max_slices, nerver_split=False ) if grid is not None: slice_counts.append(1 + grid[0] * grid[1]) # 1 source + M*N slices else: slice_counts.append(1) # No slicing, only source image else: slice_counts.append(1) # Default: single image, no slicing # Get the flattened embeddings tensor # vision_hidden_states is a list with one element (the batch) # vision_hidden_states[0] shape: [total_slices, 64, D] all_embeds = vision_hidden_states[0] # Collect all feed operations first, then execute # This allows us to identify the last token for VISION mode logits feed_operations = [] # List of (embed, is_last_for_vision_mode) embed_idx = 0 # Current index in all_embeds for img_idx, num_slices in enumerate(slice_counts): if num_slices == 0: continue # First embedding is always the source image (downsampled overview) # Feed token feed_operations.append((self.decoder.embed_token(self.image_start_token_id), False)) # Feed source image embedding (shape: [64, D]) feed_operations.append((all_embeds[embed_idx], False)) # Feed token feed_operations.append((self.decoder.embed_token(self.image_end_token_id), False)) embed_idx += 1 # Remaining embeddings are HD slices (if num_slices > 1) if num_slices > 1: for slice_i in range(1, num_slices): # Feed token feed_operations.append((self.decoder.embed_token(self.slice_start_token_id), False)) # Feed slice embedding (shape: [64, D]) feed_operations.append((all_embeds[embed_idx], False)) # Feed token feed_operations.append((self.decoder.embed_token(self.slice_end_token_id), False)) embed_idx += 1 # Mark the last operation for VISION mode logits if feed_operations: feed_operations[-1] = (feed_operations[-1][0], True) # Execute all feed operations for embed, is_last in feed_operations: if mode == "VISION" and is_last: # Get logits from the last token self.pending_logits, _ = self.decoder.feed(embed, return_logits=True) else: self.decoder.feed(embed) # For OMNI MODE, no pending logits needed here (wait for audio) cost_vision_feed = time.time() - t0 # Step 3: process audio (if any) if has_audio: # accumulate audio to buffer self.audio_buffer = np.concatenate([self.audio_buffer, audio_waveform]) # calculate required audio length if self.audio_chunk_idx == 0: required_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000) if len(self.audio_buffer) < required_samples: padding_samples = required_samples - len(self.audio_buffer) padding = np.zeros(padding_samples, dtype=np.float32) self.audio_buffer = np.concatenate([padding, self.audio_buffer]) else: required_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000) need_samples = self.processor.get_streaming_chunk_size() if len(self.audio_buffer) < need_samples: return _make_result(False, f"音频不足: 需要 {need_samples} 样本, 只有 {len(self.audio_buffer)}") audio_chunk = self.audio_buffer[:need_samples] t0 = time.time() batch_feature = self.processor.process_audio_streaming( audio_chunk, reset=False, return_batch_feature=True, ) if batch_feature is None or batch_feature.audio_features.shape[-1] == 0: return _make_result(False, "流式音频处理返回空") # metadata batch_feature.chunk_idx = self.audio_chunk_idx batch_feature.use_extra_context = True batch_feature.prefix_extra_frames = 0 if self.audio_chunk_idx == 0 else 2 batch_feature.suffix_extra_frames = 2 batch_feature = batch_feature.to(self.device) cost_audio_process = time.time() - t0 t0 = time.time() embeds_nested = self.model.get_audio_embedding_streaming( batch_feature, use_extra_context=batch_feature.use_extra_context, prefix_extra_frames=batch_feature.prefix_extra_frames, suffix_extra_frames=batch_feature.suffix_extra_frames, ) audio_embeds = torch.cat([t for g in embeds_nested for t in g], dim=0) cost_audio_embed = time.time() - t0 t0 = time.time() self.pending_logits, _ = self.decoder.feed(audio_embeds, return_logits=True) cost_audio_feed = time.time() - t0 if self.audio_chunk_idx == 0: cfg = self.processor._streaming_mel_processor.get_config() consumed_ms = int(cfg.get("effective_first_chunk_ms", self.FIRST_CHUNK_MS)) consumed_samples = int(consumed_ms * self.SAMPLE_RATE / 1000) else: consumed_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000) self.audio_buffer = self.audio_buffer[consumed_samples:] self.audio_chunk_idx += 1 self.current_mode = mode # for VISION mode, need to manually increase chunk count (AUDIO and OMNI modes already increased in _process_audio_buffer) if mode == "VISION": self.audio_chunk_idx += 1 return _make_result(True) @torch.no_grad() def streaming_generate( self, prompt_wav_path=None, max_new_speak_tokens_per_chunk=20, decode_mode: str = "sampling", temperature=0.7, top_k=20, top_p=0.8, listen_prob_scale=1.0, listen_top_k=None, text_repetition_penalty=1.05, text_repetition_window_size=512, ): start_time = time.time() if self.is_session_stop_set() or self.is_break_set(): return { "is_listen": True, "text": "", "audio_waveform": self._generate_silence_waveform(), "end_of_turn": True, "current_time": self.audio_chunk_idx, "cost_llm": 0.0, "cost_tts_prep": 0.0, "cost_tts": 0.0, "cost_token2wav": 0.0, "cost_all": time.time() - start_time, "n_tokens": 0, "n_tts_tokens": 0, } # check if there are pending logits to process if not hasattr(self, "pending_logits") or self.pending_logits is None: return { "is_listen": True, "text": "", "audio_waveform": self._generate_silence_waveform(), "end_of_turn": False, "current_time": self.audio_chunk_idx, "cost_llm": 0.0, "cost_tts_prep": 0.0, "cost_tts": 0.0, "cost_token2wav": 0.0, "cost_all": time.time() - start_time, "n_tokens": 0, "n_tts_tokens": 0, } # use pending logits generated in streaming_prefill logits = self.pending_logits self.pending_logits = None # Force listen: check if we should force listen for first N calls force_listen = self._streaming_generate_count < self.force_listen_count self._streaming_generate_count += 1 if force_listen: print(f"[Duplex] streaming_generate: force_listen=True (call #{self._streaming_generate_count})") total_hidden_in_unit = [] total_ids_in_unit = [] current_time = self.audio_chunk_idx is_listen = False end_of_turn = False llm_start_time = time.time() for j in range(max_new_speak_tokens_per_chunk): if j == max_new_speak_tokens_per_chunk - 1: if self.ls_mode == "explicit": self.decoder.feed(self.decoder.embed_token(self.chunk_eos_token_id)) self.total_ids.append(self.chunk_eos_token_id) break if force_listen: last_id = torch.tensor([self.listen_token_id], dtype=torch.long, device=self.device) else: last_id = self.decoder.decode( logits=logits, mode=decode_mode, temperature=temperature, top_k=top_k, top_p=top_p, listen_top_k=listen_top_k, listen_prob_scale=listen_prob_scale, text_repetition_penalty=text_repetition_penalty, text_repetition_window_size=text_repetition_window_size, debug_print_top5=False, ) # if current turn not ended, not allowed to listen (only check when not force_listen) if last_id.item() == self.listen_token_id and (not self.current_turn_ended): last_id = torch.tensor([self.tts_bos_token_id], dtype=torch.long, device=self.device) self.total_ids.append(last_id.item()) is_listen = last_id.item() == self.listen_token_id # termination condition detection if last_id.item() in self.chunk_terminator_token_ids: if self.ls_mode == "explicit": logits, _ = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True) break else: # normal speak self.current_turn_ended = False if last_id.item() in self.chunk_speak_token_ids: pass else: self.res_ids.append(last_id.item()) self.speak_count += 1 logits, hidden = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True) assert len(hidden.shape) == 3 assert hidden.shape[0] == 1 assert hidden.shape[1] == 1 end_of_turn = last_id.item() in self.turn_terminator_token_ids if end_of_turn: self.current_turn_ended = True if j != 0: total_hidden_in_unit.append([last_id.item(), hidden, end_of_turn]) total_ids_in_unit.append(last_id.item()) # Prefill token unit_end_id = self.tokenizer.convert_tokens_to_ids("") self.decoder.feed(self.decoder.embed_token(unit_end_id)) self.total_ids.append(unit_end_id) # 计算生成的文本(用于滑窗 context 保留,过滤掉特殊 token) generated_text = self.tokenizer.decode(total_ids_in_unit, skip_special_tokens=True) if total_ids_in_unit else "" # 滑窗:注册 unit 结束,并检查是否需要滑窗 input_type = self.current_mode.lower() if self.current_mode else "audio" logger.info( "[Duplex] streaming_generate: completing unit, mode=%s, is_listen=%s, generated=%d tokens", input_type, is_listen, len(total_ids_in_unit), ) self.decoder.register_unit_end( input_type=input_type, generated_tokens=total_ids_in_unit, is_listen=is_listen, generated_text=generated_text, ) # 根据滑窗模式选择滑窗方法 if self.decoder._window_config.sliding_window_mode == "context": self.decoder.enforce_window_with_context() elif self.decoder._window_config.sliding_window_mode == "basic": self.decoder.enforce_window() llm_end_time = time.time() if is_listen: self.total_hidden.append([]) return { "is_listen": True, "text": "", "audio_waveform": self._generate_silence_waveform(), "end_of_turn": False, "current_time": current_time, "cost_llm": llm_end_time - llm_start_time, "cost_tts_prep": 0.0, "cost_tts": 0.0, "cost_token2wav": 0.0, "cost_all": time.time() - start_time, "n_tokens": len(total_ids_in_unit), "n_tts_tokens": 0, } self.total_hidden.append(total_hidden_in_unit) text = generated_text # 复用已计算的文本 print(f"> speak: {text}") if not self.generate_audio: return { "is_listen": False, "text": text, "audio_waveform": None, "end_of_turn": end_of_turn, "current_time": current_time, "cost_llm": llm_end_time - llm_start_time, "cost_tts_prep": 0.0, "cost_tts": 0.0, "cost_token2wav": 0.0, "cost_all": time.time() - start_time, "n_tokens": len(total_ids_in_unit), "n_tts_tokens": 0, } # TTS generate tts_start_time = time.time() tts_prep_start_time = time.time() tts_condition = self._convert_results_to_tts_input(total_hidden_in_unit) tts_prep_end_time = time.time() max_token_per_chunk = 25 + 1 min_token_per_chunk = 25 + 1 if end_of_turn: min_token_per_chunk = 0 force_flush = False if self.tts_text_start_pos == 0: # 这是turn的开始 min_token_per_chunk = 0 # 可以允许解码<1s的音频 force_flush = True if self.tts_current_turn_start_time is None: self.tts_current_turn_start_time = current_time new_tokens, old_kv = self.model.tts.generate_chunk( inputs_embeds=tts_condition, temperature=self.tts_temperature, repetition_penalty=self.tts_repetition_penalty, eos_token=self.tts_eos_token, force_no_stop=False, max_new_token=max_token_per_chunk, min_new_tokens=min_token_per_chunk, past_key_values=self.tts_past_key_values, logits_processors=self.tts_logits_processors, text_start_pos=self.tts_text_start_pos, ) tts_end_time = time.time() # 更新 TTS 状态(注意:token2wav 的重置必须在音频生成之后,否则会丢失 buffer 中的 tokens) if end_of_turn: self.tts_text_start_pos = 0 self.tts_past_key_values = None self.tts_current_turn_start_time = None # 注意:_reset_token2wav_for_new_turn() 移到下面音频生成之后 else: self.tts_past_key_values = old_kv self.tts_text_start_pos += tts_condition.shape[1] + new_tokens.shape[1] # Token2Wav 生成(必须在 reset 之前,否则 buffer 中倒数第二个 chunk 的 tokens 会丢失) token2wav_start_time = time.time() audio_waveform = self._generate_waveform_from_tokens(new_tokens, prompt_wav_path, end_of_turn, force_flush=force_flush) token2wav_end_time = time.time() # 在音频生成完成后再重置 token2wav 状态,确保 buffer 中的 tokens 都被处理 if end_of_turn: self._reset_token2wav_for_new_turn() end_time = time.time() return { "is_listen": False, "text": text, "audio_waveform": audio_waveform, "end_of_turn": end_of_turn, "current_time": current_time, "cost_llm": llm_end_time - llm_start_time, "cost_tts_prep": tts_prep_end_time - tts_prep_start_time, "cost_tts": tts_end_time - tts_start_time, "cost_token2wav": token2wav_end_time - token2wav_start_time, "cost_all": end_time - start_time, "n_tokens": len(total_ids_in_unit), "n_tts_tokens": new_tokens.numel(), } def _convert_results_to_tts_input(self, results): """convert LLM hidden states to TTS input""" if len(results) == 0: audio_bos = self.model.tts.emb_text( torch.tensor( [self.model.tts.audio_bos_token_id], device=self.model.tts.emb_text.weight.device, dtype=torch.long, ) ) return audio_bos.unsqueeze(0) llm_tokens = [] llm_hidden = [] for hidden in results: llm_tokens.append(hidden[0]) llm_hidden.append(hidden[1].squeeze(0)) llm_tokens_tensor = torch.Tensor(llm_tokens).to(self.device, dtype=torch.long) llm_embeds = self.model.tts.emb_text(llm_tokens_tensor) llm_hidden_tensor = torch.cat(llm_hidden, dim=0) llm_hidden_tensor = self.model.tts.projector_semantic(llm_hidden_tensor) llm_hidden_tensor = torch.nn.functional.normalize(llm_hidden_tensor, p=2, dim=-1) tts_embeds = llm_embeds + llm_hidden_tensor audio_bos = self.model.tts.emb_text( torch.tensor( [self.model.tts.audio_bos_token_id], device=self.model.tts.emb_text.weight.device, dtype=torch.long, ) ) tts_embeds = torch.cat([tts_embeds, audio_bos], dim=0) return tts_embeds.unsqueeze(0) def _generate_waveform_from_tokens( self, new_tokens: torch.Tensor, prompt_wav_path: Optional[str], is_last_chunk: bool = False, force_flush: bool = False ) -> Optional[np.ndarray]: """从 audio tokens 生成波形""" if not self.token2wav_initialized: print("⚠️ token2wav 未初始化") return None CHUNK_SIZE = 25 # 将新 tokens 添加到 buffer token_ids = torch.reshape(new_tokens, (-1,)).tolist() self.token2wav_buffer += token_ids # 检测是否有 chunk_eos token has_chunk_eos = any(tid in self.chunk_terminator_token_ids for tid in token_ids) self._log( "AUDIO", f"Token2Wav buffer size: {len(self.token2wav_buffer)}, new tokens: {len(token_ids)}, has_chunk_eos: {has_chunk_eos}", ) pcm_bytes_list = [] # process enough tokens # if there is chunk_eos, try to flush more content if has_chunk_eos or force_flush: # when there is chunk_eos, try to flush more content while len(self.token2wav_buffer) >= self.pre_lookahead + 5: # at least keep some lookahead chunk_to_process = min(CHUNK_SIZE + self.pre_lookahead, len(self.token2wav_buffer)) pcm_bytes = self.model.tts.audio_tokenizer.stream( self.token2wav_buffer[:chunk_to_process], prompt_wav=prompt_wav_path, ) pcm_bytes_list.append(pcm_bytes) self.token2wav_buffer = self.token2wav_buffer[min(CHUNK_SIZE, chunk_to_process - self.pre_lookahead) :] else: while len(self.token2wav_buffer) >= CHUNK_SIZE + self.pre_lookahead: pcm_bytes = self.model.tts.audio_tokenizer.stream( self.token2wav_buffer[: CHUNK_SIZE + self.pre_lookahead], prompt_wav=prompt_wav_path, ) pcm_bytes_list.append(pcm_bytes) self.token2wav_buffer = self.token2wav_buffer[CHUNK_SIZE:] # if is the last chunk, flush remaining tokens if is_last_chunk and len(self.token2wav_buffer) > 0: self._log("AUDIO", f"Flushing final {len(self.token2wav_buffer)} tokens") pcm_bytes = self.model.tts.audio_tokenizer.stream( self.token2wav_buffer, prompt_wav=prompt_wav_path, last_chunk=True, ) pcm_bytes_list.append(pcm_bytes) self.token2wav_buffer = [] if not pcm_bytes_list: return None # merge PCM and convert to numpy array (24kHz, int16 -> float32) all_pcm = b"".join(pcm_bytes_list) if len(all_pcm) == 0: self._log("AUDIO", "No audio bytes generated") return None pcm_np = np.frombuffer(all_pcm, dtype=" np.ndarray: """generate silence waveform (24kHz)""" sample_rate = 24000 num_samples = int(duration_sec * sample_rate) return np.zeros(num_samples, dtype=np.float32) def get_generated_text(self) -> str: return self.tokenizer.decode(self.res_ids) def get_current_time(self) -> int: return self.audio_chunk_idx def _log(self, event_type: str, message: str, data: Optional[dict] = None): if self.session_start_time is None: self.session_start_time = time.time() timestamp = time.time() - self.session_start_time log_entry = {"timestamp": timestamp, "type": event_type, "message": message} if data: log_entry["data"] = data self.session_logs.append(log_entry) prefix = f"[{timestamp:6.3f}s]" if event_type == "FEED": print(f"{prefix} [FEED] {message}") elif event_type == "DECODE": print(f"{prefix} [DECODE] {message}") elif event_type == "SYS": print(f"{prefix} [SYS] {message}") elif event_type == "AUDIO": print(f"{prefix} [AUDIO] {message}") else: print(f"{prefix} [{event_type}] {message}") def save_session_log(self, output_dir: str = "tmp_session_logs"): import json import os from datetime import datetime os.makedirs(output_dir, exist_ok=True) timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") log_file = os.path.join(output_dir, f"session_log_{timestamp_str}.json") with open(log_file, "w", encoding="utf-8") as f: json.dump(self.session_logs, f, ensure_ascii=False, indent=2) print(f"session log saved to: {log_file}") return log_file def _update_loop(self, mode): modes = { "omni": [ "", "", "", "", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>:decode", "", ], "video": ["", "", "", ":decode", ""], "audio": [ "", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>", "<|audio|>:decode", "", ], } self.loop = modes.get(mode, None) @torch.no_grad() def offline_generate( self, frame_path: Optional[str] = None, audio_path: Optional[str] = None, max_new_speak_tokens=600, max_new_speak_tokens_per_chunk=20, max_loop_time=300, decode_mode: str = "sampling", temperature=0.7, top_k=20, top_p=0.8, listen_top_k=None, listen_prob_scale=1.0, text_repetition_penalty=1.05, text_repetition_window_size=512, ) -> Dict: from .stream_providers import LoopPlanner from .stream_providers import MockStreamProvider from .stream_providers import StreamProvider res_ids = [] total_ids = [] total_hidden = [] per_second_results = [] total_res_with_time = [] iter_data = MockStreamProvider(frame_path, audio_path, self.processor) audio_data, frame_data = iter_data.all_data(device=self.device) frame_emb_list = None if frame_data is not None: frame_emb_list = self.model.get_vision_embedding(frame_data) if frame_emb_list and len(frame_emb_list) > 0: frame_emb_list = frame_emb_list[0] else: frame_emb_list = None audio_emb_list = None if audio_data is not None: audio_emb_list = self.model.get_audio_embedding( audio_data, chunk_length=self.model.config.audio_chunk_length ) if audio_emb_list and len(audio_emb_list) > 0: audio_emb_list = torch.cat(audio_emb_list[0], dim=0) else: audio_emb_list = None if audio_emb_list is None and frame_emb_list is None: logger.warning("Warning: No audio or frame data available, returning empty result") return {"total_ids": [], "res_ids": []} loop = LoopPlanner(self.step_map, self.loop, self.tokenizer) provider = StreamProvider(audio_emb_list, frame_emb_list, audio_chunk_len=1, image_step=1, audio_chunk_ms=100) speak_count = 0 while ( speak_count < max_new_speak_tokens and not provider.finished() and provider.get_current_time() < max_loop_time ): modal, tid, action = loop.next_plan() total_ids.append(tid) if action == "sliding": self.decoder.sliding_embeds() continue if modal == "audio": embeds = provider.next_audio() elif modal == "image": embeds = provider.next_image() if embeds is not None: embeds = embeds.squeeze(0) elif modal == "sp_text": embeds = self.decoder.embed_token(tid) else: raise ValueError(f"Unknown modal: {modal}") if embeds is None: break if action == "prefill": self.decoder.feed(embeds) elif action == "decode": total_hidden_in_unit = [] current_second_ids = [] logits, _ = self.decoder.feed(embeds, return_logits=True) current_time = provider.get_current_time() end_of_turn = False for j in range(max_new_speak_tokens_per_chunk): if j == max_new_speak_tokens_per_chunk - 1: if self.ls_mode == "explicit": self.decoder.feed(self.decoder.embed_token(self.chunk_eos_token_id)) total_ids.append(self.chunk_eos_token_id) break last_id = self.decoder.decode( logits=logits, mode=decode_mode, temperature=temperature, top_k=top_k, top_p=top_p, listen_top_k=listen_top_k, listen_prob_scale=listen_prob_scale, text_repetition_penalty=text_repetition_penalty, text_repetition_window_size=text_repetition_window_size, debug_print_top5=False, ) is_listen = last_id.item() == self.listen_token_id end_of_turn = last_id.item() in self.turn_terminator_token_ids total_res_with_time.append( { "text": self.tokenizer.decode([last_id.item()]), "timestamp": provider.get_current_time(), } ) total_ids.append(last_id.item()) if last_id.item() in self.chunk_terminator_token_ids: if self.ls_mode == "explicit": self.decoder.feed(self.decoder.embed_token(last_id.item())) break else: if last_id.item() in self.chunk_speak_token_ids: pass else: res_ids.append(last_id.item()) current_second_ids.append(last_id.item()) speak_count += 1 logits, hidden = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True) assert len(hidden.shape) == 3 assert hidden.shape[0] == 1 assert hidden.shape[1] == 1 if j != 0: total_hidden_in_unit.append([last_id.item(), hidden, end_of_turn]) per_second_results.append( { "time": current_time, "ids": current_second_ids, "text": self.tokenizer.decode(current_second_ids) if current_second_ids else "", } ) if is_listen: pass else: total_hidden.append(total_hidden_in_unit) if last_id.item() in self.turn_terminator_token_ids: break return { "total_ids": total_ids, "res_ids": res_ids, "total_res_with_time": total_res_with_time, "total_hidden": total_hidden, "per_second_results": per_second_results, } class MiniCPMWhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig, layer_idx: int = None): super().__init__() self.embed_dim = config.d_model try: # compatible old transformers from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, layer_idx=layer_idx, ) except: from transformers.models.whisper.modeling_whisper import WhisperAttention self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, layer_head_mask: torch.Tensor, output_attentions: bool = False, past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = False, ) -> torch.Tensor: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, past_key_values = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, past_key_value=past_key_values, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) if use_cache: outputs += (past_key_values,) return outputs # Copied from from transformers.models.whisper.modeling_whisper.WhisperEncoder and add use_cache for streaming inference class MiniCPMWhisperEncoder(WhisperEncoder): def __init__(self, config: WhisperConfig): super().__init__(config) self.layers = nn.ModuleList( [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)] ) def forward( self, input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = None, use_extra_context: Optional[bool] = False, prefix_extra_frames: Optional[int] = 1, suffix_extra_frames: Optional[int] = 1, return_debug: Optional[bool] = False, cnn_min_length: Optional[int] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # 初始化debug字典 debug_info = {} if return_debug else None # Ignore copy input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) # 可选:将短输入pad到最小长度以确保CNN计算一致性 original_length = input_features.shape[2] padded_for_cnn = False if cnn_min_length is not None and original_length < cnn_min_length: padded_features = torch.zeros( input_features.shape[0], input_features.shape[1], cnn_min_length, dtype=input_features.dtype, device=input_features.device, ) padded_features[:, :, :original_length] = input_features input_features = padded_features padded_for_cnn = True if return_debug: debug_info["padded_for_cnn"] = True debug_info["original_length"] = original_length debug_info["padded_length"] = cnn_min_length if return_debug: debug_info["input_features"] = input_features.clone() conv1_output = self.conv1(input_features) if return_debug: debug_info["after_conv1"] = conv1_output.clone() inputs_embeds = nn.functional.gelu(conv1_output) if return_debug: debug_info["after_gelu1"] = inputs_embeds.clone() conv2_output = self.conv2(inputs_embeds) if return_debug: debug_info["after_conv2"] = conv2_output.clone() inputs_embeds = nn.functional.gelu(conv2_output) if return_debug: debug_info["after_gelu2"] = inputs_embeds.clone() # 如果之前进行了padding,现在需要移除padding的影响 if padded_for_cnn: # Conv1: stride=1, 输出长度=输入长度 # Conv2: stride=2, 输出长度=(输入长度+1)//2 actual_cnn_output_length = (original_length + 1) // 2 inputs_embeds = inputs_embeds[:, :, :actual_cnn_output_length] if return_debug: debug_info["after_unpad"] = inputs_embeds.clone() # 如果使用额外上下文,CNN操作后需要移除冗余帧 # conv2的stride=2,所以输入的冗余帧在输出中会变成一半(向上取整) if use_extra_context: # 输入有prefix_extra_frames前缀帧和suffix_extra_frames后缀帧 # conv2 stride=2,输出长度 = ceil(输入长度 / 2) # 对于2帧冗余,输出是1帧(ceil(2/2) = 1) prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0 # 向上取整 suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0 # 向上取整 # 移除前后冗余帧 (batch, channels, time) if prefix_to_remove > 0: inputs_embeds = inputs_embeds[:, :, prefix_to_remove:] if 0 < suffix_to_remove < inputs_embeds.shape[2]: inputs_embeds = inputs_embeds[:, :, :-suffix_to_remove] if return_debug and use_extra_context: debug_info["after_redundancy_removal"] = inputs_embeds.clone() inputs_embeds = inputs_embeds.permute(0, 2, 1) if return_debug: debug_info["after_permute"] = inputs_embeds.clone() embed_pos = self.embed_positions.weight past_key_values_length = 0 if use_cache: if past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) elif isinstance(past_key_values, list): past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache()) elif isinstance(past_key_values, DynamicCache): past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) else: pass past_key_values_length = past_key_values.self_attention_cache.get_usable_length(inputs_embeds.shape[1]) if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]: logger.warning("seems the audio is longer than 30s. repeating the last part of the audio") embed_pos_front = embed_pos[past_key_values_length:, :] embed_pos = torch.cat( ( embed_pos_front, torch.repeat_interleave( embed_pos[-1, :].unsqueeze(0), inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length, dim=0, ), ) ) else: embed_pos = embed_pos[past_key_values_length : inputs_embeds.shape[1] + past_key_values_length, :] else: embed_pos = embed_pos[: inputs_embeds.shape[1], :] if return_debug: debug_info["positional_embedding"] = embed_pos.clone() debug_info["past_key_values_length"] = past_key_values_length hidden_states = inputs_embeds + embed_pos if return_debug: debug_info["after_add_pos_embed"] = hidden_states.clone() hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) if return_debug: debug_info["after_dropout"] = hidden_states.clone() encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( len(self.layers) ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) to_drop = False if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: # skip the layer to_drop = True # Ignore copy if to_drop: layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), output_attentions, past_key_values, use_cache, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, past_key_values=past_key_values, use_cache=use_cache, ) hidden_states = layer_outputs[0] if return_debug and idx < 3: # 只记录前3层以节省内存 debug_info[f"after_layer_{idx}"] = hidden_states.clone() if use_cache: next_encoder_cache = layer_outputs[2 if output_attentions else 1] else: next_encoder_cache = None if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) hidden_states = self.layer_norm(hidden_states) if return_debug: debug_info["after_layer_norm"] = hidden_states.clone() if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: result = tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) if return_debug: return result, debug_info return result result = BaseModelOutputWithPast( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, past_key_values=next_encoder_cache, ) if return_debug: return result, debug_info return result class MultiModalProjector(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) self.relu = nn.ReLU() self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) def forward(self, audio_features): hidden_states = self.relu(self.linear1(audio_features)) hidden_states = self.linear2(hidden_states) return hidden_states class MiniCPMMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.in_dim = config.llm_hidden_size self.out_dim = config.hidden_size self.intermediate_size = config.llm_intermediate_size self.gate_proj = nn.Linear(self.in_dim, self.intermediate_size, bias=True) self.up_proj = nn.Linear(self.in_dim, self.intermediate_size, bias=True) self.down_proj = nn.Linear(self.intermediate_size, self.out_dim, bias=True) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @dataclass class MiniCPMTTSGenerationOutput(ModelOutput): """ Output class for MiniCPMTTS generation. Args: new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq). audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq). past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head). finished (bool): Boolean indicating whether generation is complete. """ new_ids: torch.LongTensor = None audio_input_ids: torch.LongTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_input_ids: Optional[torch.LongTensor] = None finished: bool = None def make_streaming_chunk_mask_inference( tts_text_scope: List[int], tts_text_mask: torch.Tensor, streaming_audio_chunk_size: int = 50, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cuda"), max_sequence_length: int = 4096, ): """ Example: Input sequence: [t1, t2, t3, t4, t5, [Ptts], a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, ...] Output 4D causal mask: ------- text positions ------- [0] <- here is [Stts] [0, 0] <- here is [spk_emb] * N [0, 0, 0] [0, 0, 0, 0] [0, 0, 0, 0, 0] ------- audio positions -------- [0, 0, -inf, -inf, -inf, 0] <- here is [Ptts], [Ptts]'s last hidden state should predict the first audio token v- here is [Ptts] [0, 0, -inf, -inf, -inf, 0, 0] [0, 0, -inf, -inf, -inf, 0, 0, 0] [0, 0, -inf, -inf, -inf, 0, 0, 0, 0] [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0] [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0, 0] # end of first 1s audio chunk [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] """ # Create a complete attention mask for input embeds [batch_size, seq_len], without considering audio mask as audio is always at the end assert tts_text_mask.dtype == torch.int8 padding_mask = torch.ones(max_sequence_length, dtype=torch.int8, device=device) padding_mask[tts_text_scope[0] : tts_text_scope[1]] = tts_text_mask # Initialize a standard upper triangular causal mask min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (max_sequence_length, max_sequence_length), fill_value=min_dtype, dtype=dtype, device=device, ) if max_sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) else: raise ValueError("max_sequence_length of tts could not be 1.") # For each data sample audio_token_start = tts_text_scope[1] audio_duration = max_sequence_length - tts_text_scope[1] # Record which text chunk the current audio chunk can see up to text_pivot = 0 num_valid_text_tokens = torch.sum(tts_text_mask).item() - 1 # [Ptts] excluded # How many audio chunks are in total, the num of buckets should be smaller as possible num_text_tokens_per_audio_chunk = 10 # For each chunk of audio for chunk_idx in range(math.ceil(audio_duration / streaming_audio_chunk_size)): audio_chunk_start = audio_token_start + chunk_idx * streaming_audio_chunk_size audio_chunk_end = audio_token_start + (chunk_idx + 1) * streaming_audio_chunk_size # New text seen by this new audio chunk new_text_this_chunk = num_text_tokens_per_audio_chunk # The right bound of visible text tokens text_pivot = min(new_text_this_chunk + text_pivot, num_valid_text_tokens) # Mask all text chunks after the visible ones # -> [text_pivot, len(tts_text_scope)-1] excluding [Ptts] # print("audio_chunk_start-1", audio_chunk_start-1, "audio_chunk_end-1", audio_chunk_end-1, "tts_text_scope[0] + text_pivot", tts_text_scope[0] + text_pivot, "tts_text_scope[1] - 1", tts_text_scope[1] - 1) causal_mask[ audio_chunk_start - 1 : audio_chunk_end - 1, # tts_text_scope[0] + text_pivot: tts_text_scope[1], tts_text_scope[0] + text_pivot : tts_text_scope[1] - 1, ] = min_dtype # Mask the padding parts in tts_text_masks (no position will attend to it) causal_mask[:, padding_mask == 0] = min_dtype # Add extra dimensions, [batch_size, seq_len, seq_len] -> [batch_size, 1, seq_len, seq_len] causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) return causal_mask class MiniCPMTTS(PreTrainedModel): config_class = MiniCPMTTSConfig def __init__(self, config: MiniCPMTTSConfig, audio_tokenizer: None): super().__init__(config) self.use_llm_hidden_state = config.use_llm_hidden_state self.use_text = config.use_text self.streaming = config.streaming self.streaming_text_chunk_min = config.streaming_text_chunk_min self.streaming_text_chunk_max = config.streaming_text_chunk_max self.streaming_audio_chunk_size = config.streaming_audio_chunk_size self.streaming_text_reserved_len = config.streaming_text_reserved_len # streaming tts self.streaming_text_chunk_size = config.streaming_text_chunk_max self.audio_bos_token_id = config.audio_bos_token_id self.num_mel_bins = config.num_mel_bins self.num_vq = config.num_vq self.num_audio_tokens = config.num_audio_tokens self.top_p = config.top_p self.top_k = config.top_k self.repetition_penalty = config.repetition_penalty self.interleaved = config.interleaved self.attention_type = config.attention_type self.recomputed_chunks = config.recomputed_chunks self.window_size = config.window_size if self.attention_type == "sliding_recompute" and self.window_size <= self.recomputed_chunks: raise ValueError( f"Sliding recompute requires window_size > recomputed_chunks, but got window_size={self.window_size} and recomputed_chunks={self.recomputed_chunks}" ) if config.backbone_model == "llama": model_config = LlamaConfig( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, num_attention_heads=config.num_attention_heads, num_hidden_layers=config.num_hidden_layers, num_key_value_heads=config.num_key_value_heads, max_position_embeddings=config.max_position_embeddings, attn_implementation=config.attn_implementation, ) self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size) model = LlamaModel(model_config) self.model = model else: raise ValueError(f"Unsupported backbone model: {config.backbone_model}") self.projector_spk = self.create_projector(config) self.projector_semantic = self.create_projector(config) self.audio_tokenizer = audio_tokenizer self.emb_code = nn.ModuleList( [nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq)] ) self.head_code = nn.ModuleList( [ weight_norm( nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False), name="weight", ) for _ in range(config.num_vq) ] ) self.condition_type = config.condition_type return @staticmethod def create_projector(config): if config.projector_type == "mlp": return MultiModalProjector(config.llm_dim, config.hidden_size) elif config.projector_type == "minicpm": return MiniCPMMLP(config) elif config.projector_type == "default": return nn.Linear(config.llm_dim, config.hidden_size, bias=False) else: raise ValueError(f"Unsupported projector type: {config.projector_type}") # non-streaming @torch.inference_mode() def generate( self, inputs_embeds: torch.Tensor, eos_token: Union[int, torch.Tensor], force_no_stop=False, min_new_token=50, max_new_token=2048, show_tqdm=True, streaming=False, text_lengths=None, sampling_params: TTSSamplingParams = TTSSamplingParams(), ): temperature = torch.tensor( [sampling_params.temperature] * self.config.num_vq, dtype=torch.float, device=self.device, ) temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to( inputs_embeds.device ) logits_warpers, logits_processors = gen_logits( num_code=self.config.num_audio_tokens, repetition_penalty=sampling_params.repetition_penalty, top_p=sampling_params.top_p, top_k=sampling_params.top_k, ) # We only support batch size `1` for now assert inputs_embeds.shape[0] == 1 eos_token = eos_token.to(inputs_embeds.device) finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool() condition_length = inputs_embeds.shape[1] pbar: Optional[tqdm] = None if show_tqdm: pbar = tqdm( total=max_new_token, desc="code", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]", ) if streaming: raise NotImplementedError("this kind of streaming is not supported yet") new_tokens = torch.zeros( inputs_embeds.shape[0], max_new_token, self.num_vq, device=inputs_embeds.device, dtype=torch.long, ) past_key_values = None for t in range(max_new_token): audio_bos = False # If this is the first audio token, the case is special if t == 0: audio_bos = True inputs_embeds = inputs_embeds position_ids = torch.tensor( list(range(0, condition_length)), dtype=torch.long, device=self.device, ).unsqueeze(0) if streaming: raise NotImplementedError("this kind of streaming is not supported yet") else: causal_mask_4d = None else: code_emb = [] for q in range(self.num_vq): x = self.emb_code[q](new_tokens[:, t - 1 : t, q]) code_emb.append(x) inputs_embeds = torch.stack(code_emb, 3).sum(3) position_ids = torch.tensor([condition_length + t - 1], dtype=torch.long, device=self.device).unsqueeze( 0 ) if streaming: raise NotImplementedError("this kind of streaming is not supported yet") else: causal_mask_4d = None if self.config.backbone_model == "llama": outputs: BaseModelOutputWithPast = self.model( position_ids=position_ids, cache_position=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=causal_mask_4d, use_cache=True, output_attentions=False, # return_dict=True, # Add this to ensure returns dict with past_key_values ) else: raise ValueError(f"Unsupported backbone model: {self.config.backbone_model}") del position_ids del inputs_embeds hidden_states = outputs.last_hidden_state past_key_values = outputs.past_key_values with P.cached(): logits = torch.empty( hidden_states.size(0), hidden_states.size(1), self.num_audio_tokens, self.num_vq, dtype=torch.float, device=self.device, ) for num_vq_iter in range(self.num_vq): x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) logits[..., num_vq_iter] = x del x del hidden_states logits = logits[:, -1].float() logits = logits.permute(0, 2, 1) logits = logits.reshape(-1, logits.size(2)) logits /= temperature if not audio_bos: input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens logits_token = input_ids_sliced.reshape( input_ids_sliced.size(0) * input_ids_sliced.size(1), -1, ).to(self.device) del input_ids_sliced for logitsProcessors in logits_processors: logits = logitsProcessors(logits_token, logits) for logitsWarpers in logits_warpers: logits = logitsWarpers(logits_token, logits) del logits_token if t < min_new_token: logits[:, eos_token] = -torch.inf if force_no_stop: logits[:, eos_token] = -torch.inf scores = F.softmax(logits, dim=-1) del logits idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) del scores idx_next = idx_next.view(-1, self.num_vq) finish_or = idx_next.eq(eos_token).any(1) finish.logical_or_(finish_or) del finish_or new_tokens[:, t] = idx_next if t == 0 and finish.any(): break del idx_next if finish.all(): break if pbar is not None: pbar.update(1) if pbar is not None: pbar.close() if not finish.all(): logger.warning(f"incomplete result. hit max_new_token: {max_new_token}") genrated_input_ids = new_tokens[:, 0:t, :] return MiniCPMTTSGenerationOutput( new_ids=genrated_input_ids, audio_input_ids=None, # for update purpose past_key_values=None, # for update purpose past_input_ids=None, # for update purpose finished=finish.all(), ) # fake streaming @torch.inference_mode() def generate_mock_legacy_streaming( self, inputs_embeds: torch.Tensor, eos_token: Union[int, torch.Tensor], force_no_stop=False, min_new_token=50, max_new_token=2048, show_tqdm=True, streaming=False, text_lengths=None, sampling_params: TTSSamplingParams = TTSSamplingParams(), valid_text_length=None, ): assert valid_text_length is not None, "valid_text_length should be not None" tts_text_scope = [0, inputs_embeds.shape[1]] tts_text_mask = torch.zeros(inputs_embeds.shape[1], dtype=torch.int8, device=inputs_embeds.device) tts_text_mask[0:valid_text_length] = 1 tts_text_mask[-1] = 1 # [Ptts] streaming_mask_4d_full = make_streaming_chunk_mask_inference( tts_text_scope=tts_text_scope, tts_text_mask=tts_text_mask, dtype=torch.bfloat16, device=self.device, streaming_audio_chunk_size=50, max_sequence_length=4096, ) temperature = torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.device) temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to( inputs_embeds.device ) logits_warpers, logits_processors = gen_logits( num_code=self.config.num_audio_tokens, repetition_penalty=sampling_params.repetition_penalty, top_p=sampling_params.top_p, top_k=sampling_params.top_k, ) # We only support batch size `1` for now assert inputs_embeds.shape[0] == 1 eos_token = eos_token.to(inputs_embeds.device) finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool() condition_length = inputs_embeds.shape[1] pbar: Optional[tqdm] = None if show_tqdm: pbar = tqdm( total=max_new_token, desc="code", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]", ) new_tokens = torch.zeros( inputs_embeds.shape[0], max_new_token, self.num_vq, device=inputs_embeds.device, dtype=torch.long, ) past_key_values = None for t in range(max_new_token): audio_bos = False if t == 0: audio_bos = True inputs_embeds = inputs_embeds position_ids = torch.tensor( list(range(0, condition_length)), dtype=torch.long, device=self.device, ).unsqueeze(0) causal_mask_4d = streaming_mask_4d_full[:, :, :condition_length, :condition_length] else: code_emb = [] for q in range(self.num_vq): x = self.emb_code[q](new_tokens[:, t - 1 : t, q]) code_emb.append(x) inputs_embeds = torch.stack(code_emb, 3).sum(3) position_ids = torch.tensor([condition_length + t - 1], dtype=torch.long, device=self.device).unsqueeze( 0 ) causal_mask_4d = streaming_mask_4d_full[ :, :, condition_length + t : condition_length + t + 1, : condition_length + t, ] # get length of past_key_values past_key_values_length = past_key_values[0][0].shape[2] assert causal_mask_4d.shape[-1] == (past_key_values_length + 1) if self.config.backbone_model == "llama": outputs: BaseModelOutputWithPast = self.model( position_ids=position_ids, cache_position=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=causal_mask_4d, use_cache=True, output_attentions=False, # return_dict=True, # Add this to ensure returns dict with past_key_values ) else: raise ValueError(f"Unsupported backbone model: {self.config.backbone_model}") del position_ids del inputs_embeds hidden_states = outputs.last_hidden_state past_key_values = outputs.past_key_values with P.cached(): logits = torch.empty( hidden_states.size(0), hidden_states.size(1), self.num_audio_tokens, self.num_vq, dtype=torch.float, device=self.device, ) for num_vq_iter in range(self.num_vq): x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) logits[..., num_vq_iter] = x del x del hidden_states logits = logits[:, -1].float() logits = logits.permute(0, 2, 1) logits = logits.reshape(-1, logits.size(2)) logits /= temperature if not audio_bos: input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens logits_token = input_ids_sliced.reshape( input_ids_sliced.size(0) * input_ids_sliced.size(1), -1, ).to(self.device) del input_ids_sliced for logitsProcessors in logits_processors: logits = logitsProcessors(logits_token, logits) for logitsWarpers in logits_warpers: logits = logitsWarpers(logits_token, logits) del logits_token if t < min_new_token: logits[:, eos_token] = -torch.inf if force_no_stop: logits[:, eos_token] = -torch.inf scores = F.softmax(logits, dim=-1) del logits idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) del scores idx_next = idx_next.view(-1, self.num_vq) finish_or = idx_next.eq(eos_token).any(1) finish.logical_or_(finish_or) del finish_or new_tokens[:, t] = idx_next if t == 0 and finish.any(): break del idx_next if finish.all(): break if pbar is not None: pbar.update(1) if pbar is not None: pbar.close() if not finish.all(): logger.warning(f"incomplete result. hit max_new_token: {max_new_token}") genrated_input_ids = new_tokens[:, 0:t, :] return MiniCPMTTSGenerationOutput( new_ids=genrated_input_ids, audio_input_ids=None, # for update purpose past_key_values=None, # for update purpose past_input_ids=None, # for update purpose finished=finish.all(), ) # non-streaming, interleave @torch.inference_mode() def generate_chunk( self, inputs_embeds: torch.Tensor, temperature: torch.Tensor, repetition_penalty: float, eos_token: Union[int, torch.Tensor], force_no_stop=False, max_new_token=500, min_new_tokens=0, past_key_values=None, logits_processors=None, text_start_pos=None, ): """For inputs_embeds, it should be like [bs=1, seq_len, hidden_dim], its content is like: |Text BOS|Spk embeds|Text-Hidden states Interleave (if applicable)|Audio BOS| where the last position is the audio BOS token. So, the first iteration in generation directly forward the model with inputs_embeds, and the last hidden states of the last position (Audio BOS) will be decoded to get the first audio token. """ logits_warpers, logits_processors = gen_logits( num_code=self.config.num_audio_tokens, repetition_penalty=repetition_penalty ) # We only support batch size `1` for now assert inputs_embeds.shape[0] == 1 eos_token = eos_token.to(inputs_embeds.device) finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool() temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to( inputs_embeds.device ) condition_length = inputs_embeds.shape[1] new_tokens = torch.zeros( inputs_embeds.shape[0], max_new_token, self.num_vq, device=inputs_embeds.device, dtype=torch.long, ) for t in range(max_new_token): audio_bos = False # If this is the first audio token, the case is special if t == 0: audio_bos = True inputs_embeds_ = inputs_embeds position_ids = torch.tensor( list(range(text_start_pos, text_start_pos + condition_length)), dtype=torch.long, device=self.device, ).unsqueeze(0) else: # Generate the following audio tokens, it is applicable to all other cases, including second and the following calling of `generate` inputs_embeds_ = self.emb_code[0](new_tokens[:, t - 1 : t, 0]) position_ids = torch.tensor( [text_start_pos + condition_length + t - 1], # 把上一个token prefill进去 dtype=torch.long, device=self.device, ).unsqueeze(0) outputs: BaseModelOutputWithPast = self.model( position_ids=position_ids, # cache_position=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds_, use_cache=True, output_attentions=False, # return_dict=True, # Add this to ensure returns dict with past_key_values ) del position_ids del inputs_embeds_ hidden_states = outputs.last_hidden_state past_key_values = outputs.past_key_values with P.cached(): logits = torch.empty( hidden_states.size(0), hidden_states.size(1), self.num_audio_tokens, self.num_vq, dtype=torch.float, device=self.device, ) for num_vq_iter in range(self.num_vq): x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) logits[..., num_vq_iter] = x del x del hidden_states logits = logits[:, -1].float() logits = logits.permute(0, 2, 1) logits = logits.reshape(-1, logits.size(2)) logits /= temperature if not audio_bos: input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens logits_token = input_ids_sliced.reshape( input_ids_sliced.size(0) * input_ids_sliced.size(1), -1, ).to(self.device) del input_ids_sliced for logitsProcessors in logits_processors: logits = logitsProcessors(logits_token, logits) del logits_token if force_no_stop or t < min_new_tokens: logits[:, eos_token] = -torch.inf scores = F.softmax(logits, dim=-1) del logits idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) del scores idx_next = idx_next.view(-1, self.num_vq) finish_or = idx_next.eq(eos_token).any(1) finish.logical_or_(finish_or) del finish_or new_tokens[:, t] = idx_next if t == 0 and finish.any(): break del idx_next if finish.all(): break if not finish.all(): logger.warning(f"incomplete result. hit max_new_token: {max_new_token}") # 最新生成的那个token不在此次返回的范围内,如果是eos token,不返回,如果是其他正常的token,也不返回。正常! genrated_input_ids = new_tokens[:, 0:t, :] return genrated_input_ids, past_key_values @torch.inference_mode() def interleaved_generate( self, spk_embeds: torch.Tensor, conditions: List[torch.Tensor], temperature: torch.Tensor, repetition_penalty: float, eos_token: Union[int, torch.Tensor], **kwargs, ): """ For inputs_embeds, it should be like [bs=1, seq_len, hidden_dim], its content is like: |Text BOS|Spk embeds|Text-Hidden states Interleave (if applicable)|Audio BOS| where the last position is the audio BOS token. So, the first iteration in generation directly forward the model with inputs_embeds, and the last hidden states of the last position (Audio BOS) will be decoded to get the first audio token. """ temperature = torch.tensor([temperature], dtype=torch.float, device=self.device) logits_warpers, logits_processors = gen_logits( num_code=self.config.num_audio_tokens, repetition_penalty=repetition_penalty, ) eos_token = eos_token.to(conditions[0].device) num_chunks = len(conditions) text_start_pos = 0 last_window_size = 0 past_key_values = None for idx in range(num_chunks): condition = conditions[idx].to(conditions[0].device) if self.attention_type == "sliding_recompute": recomputed_conditions = [] if ( idx >= self.window_size and (idx - self.recomputed_chunks) % (self.window_size - self.recomputed_chunks) == 0 ): for i in range(self.recomputed_chunks): recomputed_conditions.append(conditions[idx - self.recomputed_chunks + i]) recomputed_conditions.append( self.emb_code[0](generated_tokens[-self.recomputed_chunks + i][:, :, 0]) ) recomputed_conditions.append(condition) condition = torch.cat(recomputed_conditions, dim=1) text_start_pos = 0 new_tokens, old_kv = self.generate_chunk( inputs_embeds=condition, temperature=temperature, repetition_penalty=repetition_penalty, eos_token=eos_token, force_no_stop=False, max_new_token=500, past_key_values=None, logits_processors=logits_processors, text_start_pos=text_start_pos, ) else: new_tokens, old_kv = self.generate_chunk( inputs_embeds=condition, temperature=temperature, repetition_penalty=repetition_penalty, eos_token=eos_token, force_no_stop=False, max_new_token=500, past_key_values=past_key_values, logits_processors=logits_processors, text_start_pos=text_start_pos, ) else: new_tokens, old_kv = self.generate_chunk( inputs_embeds=condition, temperature=temperature, repetition_penalty=repetition_penalty, eos_token=eos_token, force_no_stop=False, max_new_token=500, past_key_values=past_key_values, logits_processors=logits_processors, text_start_pos=text_start_pos, ) past_key_values = [] if self.attention_type == "sliding_window" and idx >= 1: for layer_idx in range(len(old_kv)): past_key_values.append( ( old_kv[layer_idx][0][:, :, last_window_size:, :], old_kv[layer_idx][1][:, :, last_window_size:, :], ) ) else: past_key_values = old_kv last_window_size = condition.shape[1] + new_tokens.shape[1] text_start_pos += last_window_size if idx == 0: generated_tokens = [new_tokens] else: generated_tokens.append(new_tokens) return MiniCPMTTSGenerationOutput(new_ids=torch.cat(generated_tokens, dim=1), finished=True) class CustomRepetitionPenaltyLogitsProcessorRepeat: def __init__(self, penalty: float, max_input_ids: int, past_window: int): if not isinstance(penalty, float) or not (penalty > 0): raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") self.penalty = penalty self.max_input_ids = max_input_ids self.past_window = past_window def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if input_ids.size(1) > self.past_window: input_ids = input_ids.narrow(1, -self.past_window, self.past_window) freq = F.one_hot(input_ids, scores.size(1)).sum(1) if freq.size(0) > self.max_input_ids: freq.narrow(0, self.max_input_ids, freq.size(0) - self.max_input_ids).zero_() alpha = torch.pow(self.penalty, freq) scores = scores.contiguous() inp = scores.multiply(alpha) oth = scores.divide(alpha) con = scores < 0 out = torch.where(con, inp, oth) del inp, oth, scores, con, alpha return out def gen_logits(num_code: int, top_p=0.7, top_k=20, repetition_penalty=1.0): logits_warpers = [] if top_p is not None: logits_warpers.append(TopPLogitsWarper(top_p, min_tokens_to_keep=3)) if top_k is not None: logits_warpers.append(TopKLogitsWarper(top_k, min_tokens_to_keep=3)) logits_processors = [] if repetition_penalty is not None and repetition_penalty != 1: logits_processors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16)) return logits_warpers, logits_processors # Copy and modified from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, position_ids=None, use_cache=True, **kwargs, ): if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens else: cache_length = past_length = past_key_values[0][0].shape[2] # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # This clo≠clo≠clone call is needed to avoid recapturing cuda graphs with →rch.comπ≤→rch.comπ≤torch.compile's mode=reduce−overheadmode=reduce-overheadmode="reduce-overhead, as otherwise the input positionidspositionidsposition_ids would have various stride during the decoding. Here, simply using .contiguous().contiguous().contiguous() is not sufficient as in the batch size = 1 case, positionidspositionidsposition_ids is already contiguous but with varying stride which retriggers a capture. position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if ∈putsembeds∈putsembedsinputs_embeds are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: # The clone here is for the same reason as for positionidspositionidsposition_ids. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape device = model_inputs["inputs_embeds"].device else: batch_size, sequence_length = model_inputs["input_ids"].shape device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min from transformers.models.paligemma.modeling_paligemma import ( _prepare_4d_causal_attention_mask_with_cache_position, ) attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_length(), dtype=dtype, device=device, min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) model_inputs.update( { "position_ids": position_ids, # "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, } ) return model_inputs