#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import logging
import math
import os
import tempfile
import threading
import time
import types
from copy import deepcopy
from dataclasses import dataclass
from threading import Thread
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as P
from PIL import Image
from pydantic import BaseModel
from torch.nn.utils.parametrizations import weight_norm
from tqdm import tqdm
from transformers import LlamaConfig
from transformers import LlamaModel
from transformers import PreTrainedModel
from transformers import Qwen3ForCausalLM
from transformers import Qwen3PreTrainedModel
from transformers import TextIteratorStreamer
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.cache_utils import DynamicCache
from transformers.cache_utils import EncoderDecoderCache
from transformers.cache_utils import StaticCache
from transformers.generation.logits_process import TopKLogitsWarper
from transformers.generation.logits_process import TopPLogitsWarper
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_outputs import ModelOutput
from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from .chunk_prefill_generate import ChunkPrefillChunkGenerate
from .configuration_minicpmo import MiniCPMOConfig
from .configuration_minicpmo import MiniCPMODuplexConfig
from .configuration_minicpmtts import MiniCPMTTSConfig
from .modeling_navit_siglip import SiglipVisionTransformer
from .processing_minicpmo import MiniCPMOProcessor
from .resampler import Resampler
from .sliding_utils import as_dynamic_cache
from .sliding_utils import drop_tokens_from_cache
from .sliding_utils import get_kv_cache_length
from .sliding_utils import realign_rotary_suffix
from .sliding_utils import StreamingWindowConfig
from .tts_streaming_generate import TTSStreamingGenerator
from .utils import streaming_token_decoder
from .utils import torch_clone_recursive
# NOTE: These imports are added to work around a Transformers bug where secondary
# relative imports are not copied to the cache when loading from a local path.
# See: https://github.com/huggingface/transformers/issues/XXXXX
from .audio_utils import process_audio_batch as _ # noqa: F401
from .processing_audio_minicpma import MiniCPMAAudioProcessor as __ # noqa: F401
from .processing_image_minicpmv import MiniCPMOBatchFeature as ___ # noqa: F401
from .processing_streaming_mel import StreamingMelProcessorExact as ____ # noqa: F401
logger = logging.getLogger(__name__)
@dataclass
class SpeculativeSnapshot:
"""Speculative snapshot for VAD speculative rollback.
Used in VAD speculative execution: creates a snapshot after streaming_prefill
and before streaming_generate. If speculation fails (user continues speaking),
the state can be restored to continue streaming_prefill.
Implementation:
- LLM KV Cache: only record length, restore by truncation (zero extra VRAM)
- Audio KV Cache: requires cloning, as generate sets it to None
- Mel processor: save full state snapshot (including buffer)
"""
    # KV cache lengths (used for truncation-based restore)
llm_cache_length: int
audio_cache_length: int
    # Session state
new_user_msg: bool
llm_generated: bool
llm_generate_completed: bool
    # Round management
next_round_id: int
pending_round_id: Optional[int]
omni_chunk_history_length: int
    # TTS state (needs cloning, but usually small)
tts_last_turn_tokens: Optional[torch.Tensor]
    # Streaming processor state
audio_chunk_idx: int
    # Mel processor state snapshot (including the buffer)
mel_processor_snapshot: Optional[dict] = None
    # Audio encoder KV cache (cloned so that continued prefill after restore stays deterministic)
audio_past_key_values: Optional[tuple] = None
    # Timestamp (for debugging)
timestamp: float = 0.0
    # Debug fields: used to verify that the restore is correct
    llm_cache_checksum: Optional[float] = None  # sum of the first layer's K in the LLM KV cache
    audio_cache_checksum: Optional[float] = None  # sum of the first layer's K in the audio KV cache
    mel_buffer_checksum: Optional[float] = None  # sum of the mel buffer
    # RNG state (critical: keeps random ops such as dithering deterministic after restore)
rng_state_cpu: Optional[torch.Tensor] = None # torch CPU RNG state
rng_state_cuda: Optional[torch.Tensor] = None # torch CUDA RNG state (if on GPU)
def summary(self) -> str:
"""返回快照摘要,用于日志"""
mel_buf_len = 0
if self.mel_processor_snapshot:
buf = self.mel_processor_snapshot.get("buffer")
if buf is not None:
mel_buf_len = len(buf)
return (
f"llm_cache={self.llm_cache_length}, "
f"audio_cache={self.audio_cache_length}, "
f"audio_chunk_idx={self.audio_chunk_idx}, "
f"mel_buffer={mel_buf_len}, "
f"history_len={self.omni_chunk_history_length}, "
f"new_user_msg={self.new_user_msg}, "
f"llm_generated={self.llm_generated}"
)
class TTSSamplingParams(BaseModel):
top_p: float = 0.85
min_p: float = 0.01
top_k: int = 25
repetition_penalty: float = 1.05
temperature: float = 0.8
win_size: int = 16
tau_r: float = 0.1
class MiniCPMOPreTrainedModel(Qwen3PreTrainedModel):
config_class = MiniCPMOConfig
class MiniCPMO(MiniCPMOPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.llm = Qwen3ForCausalLM(config)
self.embed_dim = self.llm.config.hidden_size
self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm) # patch llm
# init vision module
if self.config.init_vision:
self.vpm = self.init_vision_module()
self.vision_dim = self.vpm.embed_dim
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
# init audio module
if self.config.init_audio:
self.apm = self.init_audio_module()
audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)
self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)
self.audio_encoder_layer = -1
# init tts module
if self.config.init_tts:
self.tts = self.init_tts_module()
self.terminators = ["<|im_end|>", "<|endoftext|>"]
self.think_str = ""
if self.llm.__class__.__name__ == "Qwen3ForCausalLM":
self.think_str = "\\n\\n\\n\\n"
self.default_tts_chat_template = (
"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n"
+ self.think_str
+ "<|tts_bos|>' }}{% endif %}"
)
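        # Illustrative rendering of the template above (assuming think_str == "" and
        # add_generation_prompt=True):
        #   <|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n<|tts_bos|>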
# for streaming
self.reset_session(reset_token2wav_cache=True)
# streaming audio processing constants
self.SAMPLE_RATE = 16000
self.CHUNK_MS = 1000 # regular chunk length (ms)
self.FIRST_CHUNK_MS = 1035 # first chunk length (ms)
self.CNN_REDUNDANCY_MS = 0 # CNN redundancy (ms)
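        # Illustrative arithmetic: at SAMPLE_RATE = 16000 (16 samples/ms), a regular chunk is
        # 16 * 1000 = 16000 samples and the first chunk is 16 * 1035 = 16560 samples.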
# for sliding window
self.streaming_window_config = StreamingWindowConfig()
self.streaming_require_system_prompt = True
self.streaming_window_enabled = True
self.force_rope_reindex = False # RoPE reindex testing switch
def init_streaming_processor(self):
if not hasattr(self, "processor") or self.processor is None:
self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
if hasattr(self.processor, "set_streaming_mode"):
self.processor.set_streaming_mode(
mode="exact",
chunk_ms=self.CHUNK_MS,
verbose=False,
first_chunk_ms=self.FIRST_CHUNK_MS,
cnn_redundancy_ms=self.CNN_REDUNDANCY_MS,
enable_sliding_window=True,
slide_trigger_seconds=30.0,
slide_stride_seconds=10.0,
)
self.processor.reset_streaming()
self.audio_chunk_idx = 0
def reset_session(self, reset_token2wav_cache=True):
self.llm_past_key_values = None
self.audio_past_key_values = None
self.tts_last_turn_tokens = None
self.llm_generated = False # last turn generated by llm or not
self.llm_generate_completed = False
self.new_user_msg = True
self.session_id = None
if reset_token2wav_cache:
self.token2wav_cache = None
# for sliding window
self.streaming_text_preserve = 0
self.streaming_position_offset = 0
self._rope_inv_freq_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
self._next_round_id = 0
self._pending_round_id = None
self._omni_chunk_history: List[Dict[str, Union[str, int]]] = []
self._round_history: List[Dict[str, Union[int, str, torch.Tensor, Optional[int]]]] = []
def init_vision_module(self):
if self.config._attn_implementation == "flash_attention_2":
self.config.vision_config._attn_implementation = "flash_attention_2"
else:
self.config.vision_config._attn_implementation = "eager"
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
setattr(model, "embed_dim", model.embeddings.embed_dim)
setattr(model, "patch_size", model.embeddings.patch_size)
return model
def init_resampler(self, embed_dim, vision_dim):
return Resampler(
num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
adaptive=True,
)
def init_audio_module(self):
if self.config._attn_implementation == "eager":
self.config.audio_config._attn_implementation = "eager"
else:
# using flash_attention_2 will cause: RuntimeError: cu_seqlens_q must have shape (batch_size + 1)
self.config.audio_config._attn_implementation = "sdpa"
return MiniCPMWhisperEncoder(self.config.audio_config)
def init_tts_module(self):
if self.config._attn_implementation == "flash_attention_2":
self.config.tts_config.attn_implementation = "flash_attention_2"
else:
self.config.tts_config.attn_implementation = "eager"
return MiniCPMTTS(config=self.config.tts_config, audio_tokenizer=None)
def init_tts(self, streaming=False, model_dir=None, enable_float16=False, n_timesteps=10):
if streaming:
if self.config.tts_config.audio_tokenizer_type != "s3tokenizer_step_audio":
logger.warning("audio tokenizer type is set to s3tokenizer_step_audio")
self.tts.config.audio_tokenizer_type = "s3tokenizer_step_audio"
try:
from stepaudio2 import Token2wav
except ImportError:
raise ImportError(f"please install Token2wav via: pip install stepaudio2-minicpmo")
model_dir = model_dir or os.path.join(self.config._name_or_path, "assets/token2wav")
self.tts.audio_tokenizer = Token2wav(model_dir, float16=enable_float16, n_timesteps=n_timesteps)
return self.tts.audio_tokenizer
else:
if self.config.tts_config.audio_tokenizer_type != "s3tokenizer":
logger.warning("audio tokenizer type is set to s3tokenizer")
self.tts.config.audio_tokenizer_type = "s3tokenizer"
try:
from cosyvoice.cli.cosyvoice import CosyVoice2
except ImportError:
raise ImportError(f"please install cosyvoice via: pip install cosyvoice-minicpmo")
model_dir = model_dir or os.path.join(self.config._name_or_path, "assets/CosyVoice2-0.5B")
self.tts.audio_tokenizer = CosyVoice2(model_dir=model_dir, load_jit=False, load_trt=False, fp16=False)
return self.tts.audio_tokenizer
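    # Illustrative usage (asset paths are the defaults used above):
    #   model.init_tts(streaming=True)   # Token2wav backend, loads assets/token2wav
    #   model.init_tts(streaming=False)  # CosyVoice2 backend, loads assets/CosyVoice2-0.5B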
def get_input_embeddings(self):
return self.llm.get_input_embeddings()
def set_input_embeddings(self, value):
self.llm.embed_tokens = value
def get_output_embeddings(self):
return self.llm.lm_head
def set_output_embeddings(self, new_embeddings):
self.llm.lm_head = new_embeddings
def set_decoder(self, decoder):
self.llm = decoder
def get_decoder(self):
return self.llm
@staticmethod
def get_sys_prompt(ref_audio=None, mode="default", language="en", ref_audio_max_ms=None):
if ref_audio is not None:
if isinstance(ref_audio, str):
if ref_audio == "assets/demo.wav":
import librosa
duration = ref_audio_max_ms / 1000.0 if ref_audio_max_ms else None
ref_audio, _ = librosa.load(ref_audio, sr=16000, mono=True, duration=duration)
else:
import os
import librosa
if os.path.isfile(ref_audio) and os.path.exists(ref_audio):
duration = ref_audio_max_ms / 1000.0 if ref_audio_max_ms else None
ref_audio, _ = librosa.load(ref_audio, sr=16000, mono=True, duration=duration)
else:
logger.error(f"Could not find {ref_audio}")
ref_audio = None
assert isinstance(ref_audio, np.ndarray), "ref_audio error"
if mode == "omni":
if language == "zh":
sys_prompt = ""
vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。"
vc_prompt_suffix = (
"请用这种声音风格来为用户提供帮助。 请认真、高质量地回复用户的问题。 请用高自然度的方式和用户聊天。"
)
else:
sys_prompt = ""
vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt."
vc_prompt_suffix = "As an assistant, you will speak using this voice style."
if ref_audio is not None:
sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
else:
sys_msgs = {"role": "system", "content": [sys_prompt]}
return sys_msgs
elif mode == "audio_assistant":
if language == "zh":
vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。"
vc_prompt_suffix = "你的任务是用这种声音模式来当一个助手。请认真、高质量地回复用户的问题。请用高自然度的方式和用户聊天。你是由面壁智能开发的人工智能助手:面壁小钢炮。"
else:
vc_prompt_prefix = "Use the voice in the audio prompt to synthesize new content."
vc_prompt_suffix = "You are a helpful assistant with the above voice style."
if ref_audio is not None:
sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
else:
logger.warning(
"Warning: ref_audio is None, speech generation will be performed based on the default voice."
)
sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]}
return sys_msgs
elif mode == "audio_roleplay":
if language == "zh":
vc_prompt_prefix = "模仿输入音频中的声音特征。"
vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。"
else:
vc_prompt_prefix = "Clone the voice in the provided audio prompt."
vc_prompt_suffix = "Try to role-play the character based on the audio prompt above."
if ref_audio is not None:
sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
else:
print("Warning: ref_audio is None, speech generation will be performed based on the default voice.")
sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]}
return sys_msgs
elif mode == "voice_cloning":
if language == "zh":
vc_prompt_prefix = "模仿输入音频中的声音特征。"
else:
vc_prompt_prefix = "Clone the voice in the provided audio prompt."
if ref_audio is not None:
sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio]}
else:
raise ValueError("ref_audio con't be None in voice_cloning mode.")
return sys_msgs
else:
sys_prompt = "You are a helpful assistant. You can accept audio and text input and output voice and text."
sys_msgs = {"role": "system", "content": [sys_prompt]}
return sys_msgs
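    # Illustrative usage (assuming `ref_wav` is a 16 kHz mono np.ndarray):
    #   sys_msg = MiniCPMO.get_sys_prompt(ref_audio=ref_wav, mode="audio_assistant", language="en")
    #   msgs = [sys_msg, {"role": "user", "content": ["..."]}]
    # With mode="default" a plain text-only system prompt is returned.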
@staticmethod
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
num_lookhead: int = 0,
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
            num_lookhead (int): number of extra look-ahead frames appended to each chunk
Returns:
torch.Tensor: mask
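        Example (illustrative):
            subsequent_chunk_mask(4, 2) with num_left_chunks=-1 returns
                [[True, True, False, False],
                 [True, True, False, False],
                 [True, True, True,  True],
                 [True, True, True,  True]]
            i.e. each position attends to every earlier chunk plus its own full chunk.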
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size)
ret[i, start:ending] = True
return ret
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""Computes the output length of the convolutional layers and the output length of the audio encoder"""
input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
input_lengths_after_pooling = (
input_lengths_after_cnn - self.config.audio_pool_step
) // self.config.audio_pool_step + 1
input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)
return input_lengths_after_cnn, input_lengths_after_pooling
def get_vision_embedding(self, data):
if "vision_hidden_states" not in data:
dtype = self.llm.model.embed_tokens.weight.dtype
device = self.llm.model.embed_tokens.weight.device
tgt_sizes = data["tgt_sizes"]
pixel_values_list = data["pixel_values"]
vision_hidden_states = []
all_pixel_values = []
img_cnt = []
for pixel_values in pixel_values_list:
img_cnt.append(len(pixel_values))
all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
# exist image
if all_pixel_values:
tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values, batch_first=True, padding_value=0.0
)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
for i in range(B):
patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
vision_batch_size = self.config.vision_batch_size
all_pixel_values = all_pixel_values.type(dtype)
if B > vision_batch_size:
hs = []
for i in range(0, B, vision_batch_size):
start_idx = i
end_idx = i + vision_batch_size
tmp_hs = self.vpm(
all_pixel_values[start_idx:end_idx],
patch_attention_mask=patch_attn_mask[start_idx:end_idx],
tgt_sizes=tgt_sizes[start_idx:end_idx],
).last_hidden_state
hs.append(tmp_hs)
vision_embedding = torch.cat(hs, dim=0)
else:
vision_embedding = self.vpm(
all_pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
).last_hidden_state
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
start = 0
for pixel_values in pixel_values_list:
img_cnt = len(pixel_values)
if img_cnt > 0:
vision_hidden_states.append(vision_embedding[start : start + img_cnt])
start += img_cnt
else:
vision_hidden_states.append([])
else: # no image
if self.training:
dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
tgt_sizes = torch.Tensor(
[
[
(224 // self.config.patch_size),
math.ceil(224 / self.config.patch_size),
]
]
).type(torch.int32)
dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
else:
dummy_feature = []
for _ in range(len(pixel_values_list)):
vision_hidden_states.append(dummy_feature)
else:
vision_hidden_states = data["vision_hidden_states"]
return vision_hidden_states
def get_vllm_embedding(self, data):
vision_hidden_states = self.get_vision_embedding(data)
if hasattr(self.llm.config, "scale_emb"):
vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
else:
vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
vision_hidden_states = [
i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
]
bs = len(data["input_ids"])
for i in range(bs):
cur_vs_hs = vision_hidden_states[i]
if len(cur_vs_hs) > 0:
cur_vllm_emb = vllm_embedding[i]
cur_image_bound = data["image_bound"][i]
if len(cur_image_bound) > 0:
image_indices = torch.stack(
[torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
).to(vllm_embedding.device)
cur_vllm_emb.scatter_(
0,
image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
)
elif self.training:
cur_vllm_emb += cur_vs_hs[0].mean() * 0
return vllm_embedding, vision_hidden_states
def get_audio_embedding_streaming(
self,
data,
use_extra_context=False,
prefix_extra_frames=1,
suffix_extra_frames=1,
return_debug=False,
cnn_min_length=None,
):
"""Extract audio embeddings in a streaming manner using cached key-value pairs.
This method processes incoming audio features incrementally and stores/updates `past_key_values`
for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
for streaming scenarios.
Args:
data (dict):
- **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
- **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
use_extra_context (bool): If True, assumes input contains extra frames for CNN context.
prefix_extra_frames (int): Number of prefix extra frames.
suffix_extra_frames (int): Number of suffix extra frames.
return_debug (bool): Whether to return debug information.
cnn_min_length (int): Minimum length for CNN input padding.
Returns:
List[List[torch.Tensor]]: audio embeddings
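        Example (illustrative; a ~1 s chunk of 16 kHz audio gives roughly 100 mel frames):
            data = {
                "audio_features": mel_chunk,                  # shape (1, 80, 100)
                "audio_feature_lens": [torch.tensor([100])],  # one sample with one segment
            }
            embeds = model.get_audio_embedding_streaming(data)  # [[Tensor(num_tokens, hidden)]]
        With return_debug=True a (embeddings, debug_info) tuple is returned instead.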
"""
wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance
audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]]
# exist audio
if len(wavforms) > 0:
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape
assert batch_size == 1
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# whisper's past_key_values management (core)
if self.audio_past_key_values is not None:
cache_length = self.audio_past_key_values[0][0].shape[2]
apm_max_len = self.apm.embed_positions.weight.shape[0]
if cache_length + max_seq_len >= apm_max_len:
logger.warning(
f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset."
)
self.audio_past_key_values = None
# build attention mask (bidirectional attention, same as offline mode)
batch_size, _, max_mel_seq_len = wavforms.shape
current_seq_len = (max_mel_seq_len - 1) // 2 + 1
# if use extra context, need to adjust sequence length
if use_extra_context:
# calculate actual sequence length after removing redundancy
# conv2's stride=2, so the mapping from mel frames to output frames is ceil(x/2)
prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0
suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0
current_seq_len = current_seq_len - prefix_to_remove - suffix_to_remove
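                # e.g. prefix_extra_frames=3 -> prefix_to_remove = (3 + 1) // 2 = 2 (= ceil(3 / 2))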
# calculate history length (if there is KV cache)
if self.audio_past_key_values is not None:
past_len = self.audio_past_key_values[0][0].shape[2] # get history sequence length
total_seq_len = past_len + current_seq_len
else:
past_len = 0
total_seq_len = current_seq_len
# create bidirectional attention mask (full attention)
audio_attention_mask = torch.zeros(
(batch_size, 1, current_seq_len, total_seq_len),
dtype=self.apm.conv1.weight.dtype,
device=wavforms.device,
)
debug_info = {} if return_debug else None
            # DEBUG: print input and past_key_values state
if hasattr(self, "_debug_prefill") and self._debug_prefill:
wavform_sum = wavforms.sum().item()
past_kv_info = (
"None" if self.audio_past_key_values is None else f"len={self.audio_past_key_values[0][0].shape[2]}"
)
print(
f"[DEBUG audio_embed] wavforms sum={wavform_sum:.6f}, shape={wavforms.shape}, past_kv={past_kv_info}"
)
# Step 1: APM processing
ret_apm = self.apm(
wavforms,
past_key_values=self.audio_past_key_values,
use_cache=True,
output_hidden_states=True,
attention_mask=audio_attention_mask,
use_extra_context=use_extra_context,
prefix_extra_frames=prefix_extra_frames,
suffix_extra_frames=suffix_extra_frames,
return_debug=return_debug,
cnn_min_length=cnn_min_length,
)
if return_debug:
audio_outputs, debug_info["apm_debug"] = ret_apm
else:
audio_outputs = ret_apm
if hasattr(self, "audio_encoder_layer"):
audio_states = audio_outputs.hidden_states[self.audio_encoder_layer]
else:
audio_states = audio_outputs.last_hidden_state
            # DEBUG: print checksum of the apm output
if hasattr(self, "_debug_prefill") and self._debug_prefill:
apm_out_sum = audio_outputs.last_hidden_state.sum().item()
audio_states_sum = audio_states.sum().item()
print(f"[DEBUG audio_embed] apm_output sum={apm_out_sum:.6f}, audio_states sum={audio_states_sum:.6f}")
if return_debug:
debug_info["apm_output"] = audio_outputs.last_hidden_state.clone()
debug_info["audio_states_before_proj"] = audio_states.clone()
debug_info["attention_mask"] = audio_attention_mask.clone()
debug_info["use_extra_context"] = use_extra_context
debug_info["prefix_extra_frames"] = prefix_extra_frames
debug_info["suffix_extra_frames"] = suffix_extra_frames
self.audio_past_key_values = audio_outputs.past_key_values
# Step 2: Projection
audio_embeds = self.audio_projection_layer(audio_states)
            # DEBUG: print checksum after projection
if hasattr(self, "_debug_prefill") and self._debug_prefill:
proj_sum = audio_embeds.sum().item()
print(f"[DEBUG audio_embed] after_projection sum={proj_sum:.6f}, shape={audio_embeds.shape}")
if return_debug:
debug_info["after_projection"] = audio_embeds.clone()
# Step 3: Pooling
audio_embeds = audio_embeds.transpose(1, 2)
audio_embeds = self.audio_avg_pooler(audio_embeds)
audio_embeds = audio_embeds.transpose(1, 2)
            # DEBUG: print checksum after pooling
if hasattr(self, "_debug_prefill") and self._debug_prefill:
pool_sum = audio_embeds.sum().item()
print(f"[DEBUG audio_embed] after_pooling sum={pool_sum:.6f}, shape={audio_embeds.shape}")
if return_debug:
debug_info["after_pooling"] = audio_embeds.clone()
_, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens)
num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = []
idx = 0
for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = []
for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
idx += 1
final_audio_embeds.append(target_audio_embeds)
if return_debug:
debug_info["final_embeddings"] = final_audio_embeds
return final_audio_embeds, debug_info
else:
return final_audio_embeds
else:
return [] if not return_debug else ([], {})
def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
dtype = self.apm.embed_positions.weight.dtype
device = self.apm.embed_positions.weight.device
wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance
audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]]
if len(wavforms) > 0:
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (
torch.arange(
0,
max_seq_len,
dtype=audio_feature_lens.dtype,
device=audio_feature_lens.device,
)
.unsqueeze(0)
.expand(batch_size, max_seq_len)
)
lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand # 1 for padded values
audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
batch_size, 1, max_seq_len, max_seq_len
)
audio_attention_mask = audio_attention_mask_.to(
dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
)
if chunk_length > 0:
chunk_num_frame = int(chunk_length * 50)
chunk_mask = self.subsequent_chunk_mask(
size=max_seq_len,
chunk_size=chunk_num_frame,
num_left_chunks=-1,
device=audio_attention_mask_.device,
)
audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask))
audio_attention_mask[audio_attention_mask_] = float("-inf")
audio_states = self.apm(
wavforms, output_hidden_states=True, attention_mask=audio_attention_mask
).hidden_states[self.audio_encoder_layer]
audio_embeds = self.audio_projection_layer(audio_states)
audio_embeds = audio_embeds.transpose(1, 2)
audio_embeds = self.audio_avg_pooler(audio_embeds)
audio_embeds = audio_embeds.transpose(1, 2)
_, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens)
num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = []
idx = 0
for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = []
for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
idx += 1
final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds
elif self.training and dummy:
dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
audio_embeds = self.audio_projection_layer(audio_states)
audio_embeds = audio_embeds.transpose(1, 2)
audio_embeds = self.audio_avg_pooler(audio_embeds)
audio_embeds = audio_embeds.transpose(1, 2)
return [audio_embeds]
else:
return []
def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False):
"""
Args:
data:
input_embeddings:
chunk_length: whisper use full attention or chunk attention
stream_input: use streaming audio embedding or not
Returns:
final embeddings with audio feature
"""
if stream_input:
audio_embeddings = self.get_audio_embedding_streaming(data)
else:
audio_embeddings = self.get_audio_embedding(data, chunk_length)
bs = len(input_embeddings)
if len(data.get("audio_features", [])) > 0:
assert len(audio_embeddings) == len(input_embeddings)
if len(audio_embeddings) > 0:
audio_bounds = data["audio_bounds"]
if self.config.stream_input:
assert bs == 1, "audio stream_input mode only support batch size 1"
for i in range(bs):
audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
device=input_embeddings.device, dtype=input_embeddings.dtype
)
audio_start_pos = 0
for bound in audio_bounds[i]:
audio_len = bound[1] - bound[0]
input_embeddings[i, bound[0] : bound[1]] = audio_embs[
audio_start_pos : audio_start_pos + audio_len, :
]
audio_start_pos += audio_len
else:
for i in range(bs):
audio_embs = audio_embeddings[i]
bounds = audio_bounds[i]
for embs, bound in zip(audio_embs, bounds):
audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to(
input_embeddings.device
)
if embs.shape[0] != len(audio_indices):
logger.error(f"Sample {i}:")
logger.error(f" Bounds: {bound}, Indices Length: {len(audio_indices)}")
logger.error(f" Embeddings Shape: {embs.shape}")
logger.error(
f" Input Embedding Shape at Indices: {input_embeddings[i, audio_indices].shape}"
)
raise ValueError(
f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
f"to input indices of length {len(audio_indices)}"
)
input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype)
elif self.training:
for i in range(bs):
                # dummy audio embeddings
input_embeddings += audio_embeddings[0].mean() * 0
return input_embeddings
def forward(self, data, **kwargs):
vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
vllm_embedding = self.get_omni_embedding(
data,
input_embeddings=vllm_embedding,
chunk_length=self.config.audio_chunk_length,
)
position_ids = data["position_ids"]
if position_ids.dtype != torch.int64:
position_ids = position_ids.long()
return self.llm(
input_ids=None,
position_ids=position_ids,
inputs_embeds=vllm_embedding,
**kwargs,
)
def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
outputs = self.llm.generate(
inputs_embeds=inputs_embeds,
pad_token_id=0,
eos_token_id=terminators,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
return outputs
def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
streamer = TextIteratorStreamer(tokenizer=tokenizer)
generation_config = {
"inputs_embeds": inputs_embeds,
"pad_token_id": 0,
"eos_token_id": terminators,
"streamer": streamer,
}
generation_config.update(kwargs)
thread = Thread(target=self.llm.generate, kwargs=generation_config)
thread.start()
return streamer
def _decode_text(self, result_ids, tokenizer):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
result_text = []
for result in result_ids:
result = result[result != 0]
if result[0] == tokenizer.bos_id:
result = result[1:]
if result[-1] in terminators:
result = result[:-1]
result_text.append(tokenizer.decode(result))
return result_text
@torch.inference_mode()
def generate(
self,
input_ids=None,
pixel_values=None,
tgt_sizes=None,
audio_features=None,
audio_feature_lens=None,
image_bound=None,
audio_bounds=None,
spk_bounds=None,
attention_mask=None,
tokenizer=None,
vision_hidden_states=None,
stream=False,
**kwargs,
):
assert input_ids is not None
assert len(input_ids) == len(pixel_values)
model_inputs = {
"input_ids": input_ids,
"audio_features": audio_features,
"audio_feature_lens": audio_feature_lens,
"image_bound": image_bound,
"audio_bounds": audio_bounds,
"spk_bounds": spk_bounds,
}
if vision_hidden_states is None:
model_inputs["pixel_values"] = pixel_values
model_inputs["tgt_sizes"] = tgt_sizes
else:
model_inputs["vision_hidden_states"] = vision_hidden_states
with torch.inference_mode():
model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs)
model_inputs["inputs_embeds"] = self.get_omni_embedding(
model_inputs,
input_embeddings=model_inputs["inputs_embeds"],
chunk_length=self.config.audio_chunk_length,
)
if stream:
result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
            outputs = {}  # when streaming, result is a TextIteratorStreamer and outputs stays empty
else:
outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
result = self._decode_text(outputs.sequences, tokenizer)
return result, outputs
def _build_streaming_mask(self, tts_tokens_len):
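        # Illustrative layout (streaming_text_reserved_len=300 is an assumed value): with
        # tts_tokens_len=10 the mask has length 1 + 300 + 1 = 302; positions 0..12 and the
        # final position are set to 1, while the remaining reserved slots stay masked (0).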
tts_sequence_full_length = 1 + self.tts.streaming_text_reserved_len + 1
streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8)
streaming_attention_mask[0 : 1 + 1 + tts_tokens_len + 1] = 1
streaming_attention_mask[-1] = 1
return streaming_attention_mask
def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048):
spk_embeds = self._get_last_spk_embeds(inputs, outputs)
text = text.split("<|tts_bos|>")[-1]
gen_text = text.split("<|tts_eos|>")[0]
tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long)
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
logits_warpers, logits_processors = gen_logits(
num_code=626,
top_p=self.tts.top_p,
top_k=self.tts.top_k,
repetition_penalty=self.tts.repetition_penalty,
)
condition_length = 1 + self.tts.streaming_text_reserved_len + 1
dtype = self.tts.emb_text.weight.dtype
emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device)
past_key_values = [
(
torch.zeros(
1,
self.tts.config.num_attention_heads,
condition_length - 1,
self.tts.config.hidden_size // self.tts.config.num_attention_heads,
dtype=emb.dtype,
device=self.tts.device,
),
torch.zeros(
1,
self.tts.config.num_attention_heads,
condition_length - 1,
self.tts.config.hidden_size // self.tts.config.num_attention_heads,
dtype=emb.dtype,
device=self.tts.device,
),
)
for _ in range(self.tts.config.num_hidden_layers)
]
audio_input_ids = torch.zeros(
1,
condition_length,
self.tts.num_vq,
dtype=torch.long,
device=self.tts.device,
)
eos_lab = False
for chunk_idx in range(math.ceil(emb.shape[1] / self.tts.streaming_text_chunk_size)):
if chunk_idx == 0:
begin = chunk_idx * self.tts.streaming_text_chunk_size + 0
end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1
else:
begin = chunk_idx * self.tts.streaming_text_chunk_size + 1
end = min(
(chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1,
condition_length - 1,
)
if end - begin > 0:
text_input_ids = tts_input_ids[:, begin:end]
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
if begin == 0:
past_key_values = self.tts.prefill_text(
input_ids=text_input_ids,
position_ids=position_ids,
past_key_values=past_key_values,
lm_spk_emb_last_hidden_states=spk_embeds,
)
else:
past_key_values = self.tts.prefill_text(
input_ids=text_input_ids,
position_ids=position_ids,
past_key_values=past_key_values,
)
outputs = self.tts.generate(
input_ids=audio_input_ids,
past_key_values=past_key_values,
streaming_tts_text_mask=streaming_tts_text_mask,
max_new_token=output_chunk_size,
force_no_stop=self.force_no_stop,
temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
logits_warpers=logits_warpers,
logits_processors=logits_processors,
)
audio_input_ids = outputs.audio_input_ids
past_key_values = outputs.past_key_values
if outputs.finished:
logger.debug("Generation finished.")
eos_lab = True
break
if not eos_lab:
logger.debug("eos_lab False, Generation continue.")
while True:
outputs = self.tts.generate(
input_ids=audio_input_ids,
past_key_values=past_key_values,
streaming_tts_text_mask=streaming_tts_text_mask,
max_new_token=output_chunk_size,
force_no_stop=self.force_no_stop,
temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
logits_warpers=logits_warpers,
logits_processors=logits_processors,
)
audio_input_ids = outputs.audio_input_ids
past_key_values = outputs.past_key_values
if outputs.finished:
logger.debug("Generation finished.")
break
if outputs.new_ids.shape[1] > tts_max_new_tokens:
logger.debug(f"Generation length > {tts_max_new_tokens}, stopped.")
break
@staticmethod
def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs):
num_beams = kwargs.get("num_beams", 3)
generation_config = {
"num_beams": num_beams,
"top_p": 0.8,
"top_k": 100,
"temperature": 0.7,
"do_sample": True,
"repetition_penalty": 1.02,
}
if do_sample:
generation_config.update(
{
"top_p": 0.8,
"top_k": 100,
"temperature": 0.7,
"do_sample": True,
"repetition_penalty": 1.02,
}
)
elif num_beams > 1:
generation_config.update({"num_beams": num_beams, "repetition_penalty": 1.2, "do_sample": False})
else:
generation_config.update({"do_sample": False, "repetition_penalty": 1.02})
generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
generation_config["min_new_tokens"] = min_new_tokens
generation_config["max_new_tokens"] = max_new_tokens
generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
return generation_config
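    # Illustrative result (not executed): prepare_generation_config(do_sample=True, max_new_tokens=128)
    # -> {"num_beams": 3, "top_p": 0.8, "top_k": 100, "temperature": 0.7, "do_sample": True,
    #     "repetition_penalty": 1.02, "min_new_tokens": 0, "max_new_tokens": 128}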
@torch.inference_mode()
def chat(
self,
image=None,
msgs=None,
tokenizer=None, # deprecated
processor=None, # deprecated
vision_hidden_states=None,
max_new_tokens=4096,
min_new_tokens=0,
do_sample=True,
sampling=None, # deprecated, please use do_sample
max_inp_length=8192,
stream=False,
stream_input=False,
max_slice_nums=None,
use_image_id=None,
enable_thinking=False,
use_tts_template=False,
generate_audio=False,
output_audio_path=None,
output_tts_inputs_embeds_path=None,
# add
omni_mode=False,
omni_input=None, # deprecated, please use omni_mode
teacher_forcing=False,
return_prompt=False,
tts_proj_layer=-1,
tts_sampling_params: TTSSamplingParams = TTSSamplingParams(),
merge_audio_from_same_content=True,
**kwargs,
):
# todo: deprecated
if sampling is not None:
do_sample = sampling
if omni_input is not None:
omni_mode = omni_input
batched = isinstance(msgs[0], list)
msgs_list = msgs
images_list = image
if not batched:
images_list, msgs_list = [images_list], [msgs_list]
else:
            assert images_list is None, "Please include images in msgs when using batch inference."
images_list = [None] * len(msgs_list)
assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."
if not hasattr(self, "processor") or self.processor is None:
self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
prompts_lists = []
input_images_list = []
input_audios_list = []
audio_parts_list = []
for image, msgs in zip(images_list, msgs_list):
if isinstance(msgs, str):
msgs = json.loads(msgs)
copy_msgs = deepcopy(msgs)
assert len(msgs) > 0, "msgs is empty"
            assert do_sample or not stream, "when using stream mode, make sure do_sample=True"
if image is not None and isinstance(copy_msgs[0]["content"], str):
copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
images = []
audios = []
audio_parts = []
for i, msg in enumerate(copy_msgs):
role = msg["role"]
content = msg["content"]
assert role in ["system", "user", "assistant"]
if i == 0:
assert role in ["user", "system"], "The role of first msg should be user"
if isinstance(content, str):
content = [content]
cur_msgs = []
for c in content:
if isinstance(c, Image.Image):
images.append(c)
cur_msgs.append("./")
elif isinstance(c, np.ndarray): # audio
audios.append(c)
audio_parts.append(i)
cur_msgs.append("")
use_tts_template = True
elif isinstance(c, str):
cur_msgs.append(c)
if omni_mode or stream_input:
msg["content"] = "".join(cur_msgs)
else:
msg["content"] = "\n".join(cur_msgs)
prompts_lists.append(
self.processor.tokenizer.apply_chat_template(
copy_msgs,
tokenize=False,
add_generation_prompt=False if teacher_forcing else True,
use_tts_template=use_tts_template,
enable_thinking=enable_thinking,
)
)
input_images_list.append(images)
input_audios_list.append(audios)
audio_parts_list.append(audio_parts)
if not merge_audio_from_same_content:
audio_parts_list = None
inputs = self.processor(
prompts_lists,
input_images_list,
input_audios_list,
audio_parts_list,
max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
stream_input=stream_input,
return_tensors="pt",
max_length=max_inp_length,
).to(self.device)
generation_config = self.prepare_generation_config(
do_sample=do_sample, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, **kwargs
)
generation_config.pop("max_new_tokens", None)
inputs.pop("image_sizes")
# teacher_forcing = True => generate audio with given text
with torch.inference_mode():
res, outputs = self.generate(
**inputs,
tokenizer=self.processor.tokenizer,
max_new_tokens=1 if teacher_forcing else max_new_tokens,
vision_hidden_states=vision_hidden_states,
stream=stream,
**generation_config,
)
# spk bound and tts bound
tts_bos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
tts_eos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_eos|>")
# Combine input_ids and generated sequences to get complete sequence
input_ids = inputs["input_ids"][0]
generated_ids = outputs.sequences[0]
# Combine by concatenating input_ids with the new tokens from generated sequence
full_sequence = torch.cat([input_ids, generated_ids])
# Update the sequences in outputs
full_sequences = full_sequence.unsqueeze(0)
outputs["full_sequences"] = full_sequences
tts_bos_indices = []
tts_eos_indices = []
for i, x in enumerate(full_sequences[0]):
if x == tts_bos_token:
                tts_bos_indices.append(i + 1)  # the first TTS token is at tts_bos + 1, which makes it easy to slice hidden states for the TTS module
elif x == tts_eos_token:
if teacher_forcing and i == len(full_sequences[0]) - 1:
continue
tts_eos_indices.append(i)
tts_bos_idx = tts_bos_indices[-1] if tts_bos_indices else -1
# Use None instead of -1 when no EOS token found, so that slice [start:None]
# means "to the end" rather than [start:-1] which excludes the last element
tts_eos_idx = tts_eos_indices[-1] if tts_eos_indices else None
tts_bound = (tts_bos_idx, tts_eos_idx)
answer = res[0]
if answer is not None:
answer = answer.rstrip("<|tts_eos|>")
if use_tts_template and generate_audio and output_audio_path:
try:
generated_waveform = self._generate_speech_non_streaming(
outputs=outputs,
tts_bound=tts_bound,
tts_proj_layer=tts_proj_layer,
audio_prompt=(
input_audios_list[0][0]
if len(input_audios_list) > 0 and len(input_audios_list[0]) > 0
else None
),
output_tts_inputs_embeds_path=output_tts_inputs_embeds_path,
tts_sampling_params=tts_sampling_params,
)
if isinstance(generated_waveform, torch.Tensor):
sf.write(output_audio_path, generated_waveform.cpu().numpy(), samplerate=24000)
elif isinstance(generated_waveform, np.ndarray):
sf.write(output_audio_path, generated_waveform, samplerate=24000)
logger.debug(f"audio saved to {output_audio_path}")
            except Exception:
import traceback
traceback.print_exc()
if return_prompt:
return answer, prompts_lists[0]
else:
return answer
@torch.inference_mode()
def _generate_speech_non_streaming(
self,
outputs,
tts_bound,
tts_proj_layer,
audio_prompt,
output_tts_inputs_embeds_path=None,
tts_sampling_params: TTSSamplingParams = TTSSamplingParams(),
):
last_hidden_states = [hs[tts_proj_layer] for hs in outputs.hidden_states]
last_hidden_states = torch.vstack([i[0] for i in last_hidden_states])
spk_embeds = (
torch.ones([0, self.tts.config.hidden_size]).to(last_hidden_states.device).to(last_hidden_states.dtype)
)
if self.tts.condition_type == "hidden_text_merge":
llm_tokens = outputs["full_sequences"][0][tts_bound[0] : tts_bound[1]]
llm_tokens = torch.tensor(llm_tokens, device=self.tts.emb_text.weight.device, dtype=torch.long)
llm_embeds = self.tts.emb_text(llm_tokens) # make sure emb_text is compatible with llm vocab size
hidden_embeds = last_hidden_states[tts_bound[0] : tts_bound[1]]
hidden_embeds = self.tts.projector_semantic(hidden_embeds)
if self.tts.config.normalize_projected_hidden:
hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1)
tts_embeds = llm_embeds + hidden_embeds
if self.tts.interleaved:
chunks = []
cond_length = tts_embeds.shape[0]
for i in range(0, cond_length, 10):
chunks.append(tts_embeds[i : i + 10])
tts_embeds = chunks
else:
raise NotImplementedError
audio_bos = [self.tts.audio_bos_token_id]
audio_bos = torch.tensor(audio_bos, device=self.tts.emb_text.weight.device, dtype=torch.long)
audio_bos_embeds = self.tts.emb_text(audio_bos)
if self.tts.interleaved:
text_eos_embed = self.tts.emb_text(
torch.tensor(
[self.tts.config.text_eos_token_id],
device=self.tts.emb_text.weight.device,
dtype=torch.long,
)
)
tts_embeds[-1] = torch.cat([tts_embeds[-1], text_eos_embed], dim=0)
for i in range(len(tts_embeds)):
tts_embeds[i] = torch.cat([tts_embeds[i], audio_bos_embeds], dim=0).unsqueeze(0)
outputs = self.tts.interleaved_generate(
spk_embeds=spk_embeds,
conditions=tts_embeds,
temperature=0.8,
repetition_penalty=1.05,
eos_token=torch.tensor(
[self.tts.config.num_audio_tokens - 1],
dtype=torch.long,
device=self.tts.device,
),
)
else:
if self.tts.condition_type == "tts_token":
inputs_embeds = torch.cat([spk_embeds, tts_embeds, audio_bos_embeds], dim=0).unsqueeze(0)
elif self.tts.condition_type == "tts_token_streaming":
tts_embeds[1] = spk_embeds.squeeze(0) # apply speaker embedding
inputs_embeds = tts_embeds.unsqueeze(0)
else: # modern case
inputs_embeds = torch.cat([spk_embeds, tts_embeds, audio_bos_embeds], dim=0).unsqueeze(0)
# save inputs_embeds to file
if output_tts_inputs_embeds_path:
torch.save(inputs_embeds, output_tts_inputs_embeds_path)
outputs = self.tts.generate(
inputs_embeds=inputs_embeds,
sampling_params=tts_sampling_params,
eos_token=torch.tensor(
[self.tts.config.num_audio_tokens - 1],
dtype=torch.long,
device=self.tts.device,
),
)
if self.tts.config.audio_tokenizer_type == "s3tokenizer":
generated_tokens = outputs.new_ids.squeeze(-1)
reference_audio = audio_prompt
if reference_audio is not None:
logger.debug("use reference audio in data to generate waveform")
prompt_speech_16k = torch.tensor(reference_audio).unsqueeze(0)
if self.tts.config.s3_stream_generate:
waveform_pred = self.tts.audio_tokenizer.inference_token2wav(
speech_tokens=generated_tokens,
prompt_speech_16k=prompt_speech_16k,
prompt_speech=None,
stream=True,
n_timesteps=self.tts.config.s3_stream_n_timesteps,
code_chunk_size=self.tts.config.s3_stream_chunk_size,
chunk_prelook_size=self.tts.config.s3_stream_prelook_size,
use_attn_idx=False,
)
return waveform_pred[0]
else:
for i, j in enumerate(
self.tts.audio_tokenizer.token2wav(
speech_token=generated_tokens,
speech_token_len=torch.tensor([generated_tokens.shape[1]], device=generated_tokens.device),
prompt_speech_16k=prompt_speech_16k,
stream=False,
)
):
waveform_pred = j["tts_speech"]
                        waveform_sample_rate = self.tts.audio_tokenizer.sample_rate  # 24000 Hz here, not the 16000 Hz input rate
return waveform_pred[0]
else:
raise NotImplementedError
@torch.inference_mode()
def init_token2wav_cache(self, prompt_speech_16k):
if hasattr(self.tts.audio_tokenizer, "set_stream_cache"):
self.tts.audio_tokenizer.cache = None
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
prompt_wav_path = tmp_wav.name
sf.write(prompt_wav_path, prompt_speech_16k, 16000)
flow_cache_base, hift_cache_base = self.tts.audio_tokenizer.set_stream_cache(prompt_wav_path)
self.token2wav_cache = {
"flow_cache_base": torch_clone_recursive(flow_cache_base),
"hift_cache_base": torch_clone_recursive(hift_cache_base),
}
else:
model_input = self.tts.audio_tokenizer.frontend.frontend_token2wav(
speech_tokens=torch.zeros(1, 1, dtype=torch.long, device=self.tts.device),
speech_16k=None,
prompt_speech_16k=prompt_speech_16k,
resample_rate=self.tts.audio_tokenizer.sample_rate,
prompt_speech=None,
)
prompt_token = model_input["flow_prompt_speech_token"]
prompt_feat = model_input["prompt_speech_feat"]
embedding = model_input["flow_embedding"]
if self.tts.audio_tokenizer.fp16:
prompt_feat = prompt_feat.to(torch.half)
embedding = embedding.to(torch.half)
prepared_cache = self.tts.audio_tokenizer.model.prepare_cache_from_prompt(
prompt_token=prompt_token,
prompt_feat=prompt_feat,
embedding=embedding,
n_timesteps=self.tts.config.s3_stream_n_timesteps,
code_chunk_size=self.tts.config.s3_stream_chunk_size,
chunk_prelook_size=self.tts.config.s3_stream_prelook_size,
use_attn_idx=False,
)
self.token2wav_cache = prepared_cache
# for sliding window
def _ensure_dynamic_cache(self):
cache = self.llm_past_key_values
if cache is None:
return None
cache = as_dynamic_cache(cache)
if isinstance(cache, DynamicCache):
self.llm_past_key_values = cache
return cache
return None
def _get_kv_cache_length(self, cache=None):
cache = cache if cache is not None else self.llm_past_key_values
return get_kv_cache_length(cache)
    # TODO: unused, consider removing
def _rebuild_cache_from_history(self):
preserved_ids: List[torch.Tensor] = []
for entry in self._omni_chunk_history:
ids = entry.get("input_ids")
if ids is None or not isinstance(ids, torch.Tensor) or ids.numel() == 0:
continue
preserved_ids.append(ids.to(self.device))
if not preserved_ids:
self.llm_past_key_values = None
self.streaming_position_offset = 0
self._rope_inv_freq_cache.clear()
return
concat_ids = torch.cat(preserved_ids, dim=1)
attention_mask = torch.ones((1, concat_ids.shape[1]), dtype=torch.bool, device=self.device)
outputs = self.llm(
input_ids=concat_ids,
attention_mask=attention_mask,
use_cache=True,
return_dict=True,
)
self.llm_past_key_values = outputs.past_key_values
self.streaming_position_offset = 0
self._rope_inv_freq_cache.clear()
def _get_rope_theta(self) -> float:
return float(getattr(self.llm.config, "rope_theta", 10000.0))
def _realign_rotary_suffix(
self,
suffix_keys: torch.Tensor,
old_positions: torch.Tensor,
new_positions: torch.Tensor,
) -> torch.Tensor:
"""Wrapper for realign_rotary_suffix using instance's rope_theta and cache."""
return realign_rotary_suffix(
suffix_keys,
old_positions,
new_positions,
rope_theta=self._get_rope_theta(),
inv_freq_cache=self._rope_inv_freq_cache,
)
def _encode_text(self, tokenizer, text) -> Optional[torch.Tensor]:
if tokenizer is None or not text:
return None
ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
return ids.to(self.device)
@staticmethod
def _safe_decode(tokenizer, input_ids):
if tokenizer is None or input_ids is None:
return None
if isinstance(input_ids, torch.Tensor):
ids = input_ids.cpu().tolist()
if ids and isinstance(ids[0], list):
ids = ids[0]
else:
ids = input_ids
try:
return tokenizer.decode(ids, skip_special_tokens=False)
except Exception:
return None
def _finalize_round(
self, round_id: Optional[int], cache_before: int, assistant_input_ids: Optional[torch.Tensor] = None
):
if round_id is None:
self._pending_round_id = None
return
cache_after = self._get_kv_cache_length()
if assistant_input_ids is not None:
assistant_len = assistant_input_ids.shape[1]
else:
assistant_len = max(cache_after - cache_before, 0)
if assistant_len > 0:
self._register_chunk(
assistant_len,
"assistant",
round_id=round_id,
input_ids=assistant_input_ids,
tokenizer=self.processor.tokenizer if hasattr(self, "processor") else None,
)
logger.info(
"Finalized round=%s cache len before=%s after=%s assistant_len=%s",
round_id,
cache_before,
cache_after,
assistant_len,
)
self._pending_round_id = None
self._next_round_id += 1
def _register_chunk(
self,
seq_len: int,
chunk_type: str,
*,
round_id: int,
input_ids=None,
tokenizer=None,
) -> None:
if seq_len <= 0:
return
entry = {"length": int(seq_len), "type": chunk_type, "round": round_id}
if input_ids is not None:
entry["input_ids"] = input_ids.clone().detach()
entry["decoded"] = self._safe_decode(tokenizer, entry["input_ids"])
else:
entry["input_ids"] = None
entry["decoded"] = None
self._omni_chunk_history.append(entry)
logger.info(
"Registered chunk round=%s type=%s len=%s decoded=%s",
round_id,
chunk_type,
entry["length"],
entry["decoded"],
)
if chunk_type == "system":
self.streaming_text_preserve = max(self.streaming_text_preserve, entry["length"])
def _drop_tokens_from_cache(self, length: int, cache: DynamicCache) -> bool:
"""Drop tokens from cache using the utility function."""
_, new_offset, success = drop_tokens_from_cache(
cache=cache,
length=length,
preserve=self.streaming_text_preserve,
position_offset=self.streaming_position_offset,
rope_theta=self._get_rope_theta(),
inv_freq_cache=self._rope_inv_freq_cache,
)
if success:
self.streaming_position_offset = new_offset
return success
def _drop_next_round(self, cache: DynamicCache) -> bool:
seen_rounds = set()
for entry in self._omni_chunk_history:
round_id = entry.get("round")
if round_id is None or round_id in seen_rounds:
continue
seen_rounds.add(round_id)
round_entries = [e for e in self._omni_chunk_history if e.get("round") == round_id]
if any(e.get("type") == "system" for e in round_entries):
continue
if self._drop_round(round_id, cache):
return True
return False
def _drop_round(self, round_id: int, cache: DynamicCache) -> bool:
entries = [e for e in self._omni_chunk_history if e.get("round") == round_id]
if not entries:
return False
total_len = sum(e["length"] for e in entries)
if total_len <= 0:
for e in entries:
self._omni_chunk_history.remove(e)
return False
if not self._drop_tokens_from_cache(total_len, cache):
return False
for e in entries:
logger.info(
"Dropped round=%s chunk type=%s len=%s decoded=%s",
round_id,
e["type"],
e["length"],
e.get("decoded"),
)
self._omni_chunk_history.remove(e)
return True
def _enforce_text_window(self) -> None:
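        # High/low watermark eviction (values come from StreamingWindowConfig; numbers below are
        # illustrative only): once the KV cache exceeds text_window_high_tokens, whole non-system
        # rounds are dropped, oldest first, until the cache length is at or below
        # text_window_low_tokens or no droppable round remains.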
if not self.streaming_window_enabled:
return
cache = self._ensure_dynamic_cache()
if cache is None:
return
high_limit = max(0, int(self.streaming_window_config.text_window_high_tokens))
low_limit = max(0, int(self.streaming_window_config.text_window_low_tokens))
if high_limit <= 0:
return
target = max(0, low_limit)
total_len = self._get_kv_cache_length(cache)
if total_len <= high_limit:
return
dropped_any = False
while total_len > target:
if not self._drop_next_round(cache):
break
dropped_any = True
total_len = self._get_kv_cache_length(cache)
# for sliding window
    # ============== Speculative snapshot / restore interface ==============
def _save_speculative_snapshot(self) -> SpeculativeSnapshot:
"""Internal method: save speculative snapshot.
Called at the start of streaming_generate, saves to self._speculative_snapshot.
Save strategy:
- LLM KV Cache: only record length (restore by truncation, zero extra VRAM)
- Audio KV Cache: deep clone (as generate sets it to None)
- Mel processor: full state snapshot (including buffer)
"""
        # 1. Collect LLM cache info
llm_cache_length = self._get_kv_cache_length()
llm_cache_checksum = None
if self.llm_past_key_values is not None and hasattr(self.llm_past_key_values, "key_cache"):
if len(self.llm_past_key_values.key_cache) > 0:
llm_cache_checksum = self.llm_past_key_values.key_cache[0].sum().item()
        # 2. Get the audio cache length and clone audio_past_key_values
audio_cache_length = 0
audio_cache_checksum = None
audio_past_key_values_clone = None
if self.audio_past_key_values is not None:
            # Handle the DynamicCache format (the Whisper encoder may return this)
if isinstance(self.audio_past_key_values, DynamicCache):
if hasattr(self.audio_past_key_values, "key_cache") and len(self.audio_past_key_values.key_cache) > 0:
audio_cache_length = self.audio_past_key_values.key_cache[0].shape[2]
audio_cache_checksum = self.audio_past_key_values.key_cache[0].sum().item()
                # Deep-clone the DynamicCache
cloned_cache = DynamicCache()
for k, v in zip(self.audio_past_key_values.key_cache, self.audio_past_key_values.value_cache):
cloned_cache.update(k.clone(), v.clone(), layer_idx=len(cloned_cache.key_cache))
audio_past_key_values_clone = cloned_cache
logger.debug(f"[Speculative] Cloned DynamicCache with length {audio_cache_length}")
            # Handle the EncoderDecoderCache format
elif isinstance(self.audio_past_key_values, EncoderDecoderCache):
self_attn_cache = self.audio_past_key_values.self_attention_cache
if hasattr(self_attn_cache, "key_cache") and len(self_attn_cache.key_cache) > 0:
audio_cache_length = self_attn_cache.key_cache[0].shape[2]
audio_cache_checksum = self_attn_cache.key_cache[0].sum().item()
                # Deep-clone the EncoderDecoderCache
cloned_self_attn = DynamicCache()
if hasattr(self_attn_cache, "key_cache"):
for k, v in zip(self_attn_cache.key_cache, self_attn_cache.value_cache):
cloned_self_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_self_attn.key_cache))
cross_attn_cache = self.audio_past_key_values.cross_attention_cache
cloned_cross_attn = DynamicCache()
if hasattr(cross_attn_cache, "key_cache"):
for k, v in zip(cross_attn_cache.key_cache, cross_attn_cache.value_cache):
cloned_cross_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_cross_attn.key_cache))
audio_past_key_values_clone = EncoderDecoderCache(cloned_self_attn, cloned_cross_attn)
logger.debug(f"[Speculative] Cloned EncoderDecoderCache with length {audio_cache_length}")
# Handle the tuple format (legacy compatibility)
elif isinstance(self.audio_past_key_values, tuple) and len(self.audio_past_key_values) > 0:
audio_cache_length = self.audio_past_key_values[0][0].shape[2]
audio_cache_checksum = self.audio_past_key_values[0][0].sum().item()
# Deep-clone audio_past_key_values (tuple of tuples of tensors)
audio_past_key_values_clone = tuple(
tuple(t.clone() for t in layer_cache) for layer_cache in self.audio_past_key_values
)
# 3. Get the mel processor snapshot
mel_processor_snapshot = None
mel_buffer_checksum = None
if hasattr(self, "processor") and self.processor is not None:
mel_processor_snapshot = self.processor.get_streaming_snapshot()
if mel_processor_snapshot:
buf = mel_processor_snapshot.get("buffer")
if buf is not None and len(buf) > 0:
mel_buffer_checksum = float(buf.sum())
# 4. Save RNG state (critical: keeps random ops such as dithering deterministic after restore)
rng_state_cpu = torch.get_rng_state()
rng_state_cuda = None
if torch.cuda.is_available() and self.device.type == "cuda":
rng_state_cuda = torch.cuda.get_rng_state(self.device)
# 5. Build the snapshot
snapshot = SpeculativeSnapshot(
llm_cache_length=llm_cache_length,
audio_cache_length=audio_cache_length,
new_user_msg=self.new_user_msg,
llm_generated=self.llm_generated,
llm_generate_completed=self.llm_generate_completed,
next_round_id=self._next_round_id,
pending_round_id=self._pending_round_id,
omni_chunk_history_length=len(self._omni_chunk_history),
tts_last_turn_tokens=self.tts_last_turn_tokens.clone() if self.tts_last_turn_tokens is not None else None,
audio_chunk_idx=self.audio_chunk_idx,
mel_processor_snapshot=mel_processor_snapshot,
audio_past_key_values=audio_past_key_values_clone,
timestamp=time.time(),
# debug fields
llm_cache_checksum=llm_cache_checksum,
audio_cache_checksum=audio_cache_checksum,
mel_buffer_checksum=mel_buffer_checksum,
# RNG state
rng_state_cpu=rng_state_cpu,
rng_state_cuda=rng_state_cuda,
)
logger.info("[Speculative] Saved snapshot: %s", snapshot.summary())
logger.debug(
"[Speculative] Snapshot checksums: llm=%.6f, audio=%.6f, mel_buf=%.6f",
llm_cache_checksum or 0.0,
audio_cache_checksum or 0.0,
mel_buffer_checksum or 0.0,
)
return snapshot
def restore_speculative_snapshot(self) -> bool:
"""Restore speculative snapshot - called when VAD speculation fails.
Restores model state to before streaming_generate was called,
allowing continued streaming_prefill for newly arrived audio.
Notes:
- Snapshot is saved when streaming_generate is called with enable_speculative_snapshot=True
- This method uses the most recent snapshot for restoration
- Snapshot is cleared after restore, cannot be called repeatedly
Returns:
bool: Whether restoration was successful
"""
snapshot = getattr(self, "_speculative_snapshot", None)
if snapshot is None:
logger.warning("[Speculative] No snapshot to restore")
return False
try:
# Record the pre-restore state (for log comparison)
current_cache_length = self._get_kv_cache_length()
current_history_length = len(self._omni_chunk_history)
logger.info(
"[Speculative] Restoring snapshot: target=%s",
snapshot.summary(),
)
logger.info(
"[Speculative] Current state before restore: llm_cache=%d, history_len=%d, "
"audio_chunk_idx=%d, new_user_msg=%s, llm_generated=%s",
current_cache_length,
current_history_length,
self.audio_chunk_idx,
self.new_user_msg,
self.llm_generated,
)
# 1. Truncate the LLM KV cache
if current_cache_length > snapshot.llm_cache_length:
self._truncate_llm_cache(snapshot.llm_cache_length)
logger.debug(
"[Speculative] Truncated LLM cache: %d -> %d",
current_cache_length,
snapshot.llm_cache_length,
)
# 2. Restore the audio KV cache from the cloned copy (critical:
# streaming_generate sets audio_past_key_values to None)
self.audio_past_key_values = snapshot.audio_past_key_values
if snapshot.audio_past_key_values is not None:
logger.debug(
"[Speculative] Restored audio cache: length=%d, checksum=%.6f",
snapshot.audio_cache_length,
snapshot.audio_cache_checksum or 0.0,
)
else:
logger.debug("[Speculative] Audio cache restored to None")
# 3. Restore session state
self.new_user_msg = snapshot.new_user_msg
self.llm_generated = snapshot.llm_generated
self.llm_generate_completed = snapshot.llm_generate_completed
# 4. Restore round bookkeeping
self._next_round_id = snapshot.next_round_id
self._pending_round_id = snapshot.pending_round_id
# 5. Truncate the chunk history
if current_history_length > snapshot.omni_chunk_history_length:
self._omni_chunk_history = self._omni_chunk_history[: snapshot.omni_chunk_history_length]
logger.debug(
"[Speculative] Truncated chunk history: %d -> %d",
current_history_length,
snapshot.omni_chunk_history_length,
)
# 6. Restore TTS state
self.tts_last_turn_tokens = snapshot.tts_last_turn_tokens
# 7. Restore the streaming processor state
self.audio_chunk_idx = snapshot.audio_chunk_idx
# 8. Restore the mel processor state (critical: otherwise later prefills fail due to frame-count mismatch)
if (
snapshot.mel_processor_snapshot is not None
and hasattr(self, "processor")
and self.processor is not None
):
self.processor.restore_streaming_snapshot(snapshot.mel_processor_snapshot)
mel_snap = snapshot.mel_processor_snapshot
logger.info(
"[Speculative] Restored mel processor: buffer_len=%d, chunk_count=%d, "
"last_emitted_T=%d, total_samples=%d",
len(mel_snap.get("buffer", [])),
mel_snap.get("chunk_count", 0),
mel_snap.get("last_emitted_T", 0),
mel_snap.get("total_samples_processed", 0),
)
# 9. Restore RNG state (critical: keeps random ops such as dithering deterministic)
if snapshot.rng_state_cpu is not None:
torch.set_rng_state(snapshot.rng_state_cpu)
logger.debug("[Speculative] Restored CPU RNG state")
if snapshot.rng_state_cuda is not None and torch.cuda.is_available():
torch.cuda.set_rng_state(snapshot.rng_state_cuda, self.device)
logger.debug("[Speculative] Restored CUDA RNG state")
# 10. Clear temporary state produced during generation
if hasattr(self, "_streaming_generated_token_ids"):
del self._streaming_generated_token_ids
if hasattr(self, "_last_streaming_text"):
del self._last_streaming_text
# 11. Verify the restored state
restored_cache_length = self._get_kv_cache_length()
if restored_cache_length != snapshot.llm_cache_length:
logger.warning(
"[Speculative] LLM cache length mismatch after restore: expected=%d, actual=%d",
snapshot.llm_cache_length,
restored_cache_length,
)
# Verify the LLM cache checksum (if available)
if snapshot.llm_cache_checksum is not None and self.llm_past_key_values is not None:
if hasattr(self.llm_past_key_values, "key_cache") and len(self.llm_past_key_values.key_cache) > 0:
current_checksum = self.llm_past_key_values.key_cache[0].sum().item()
if abs(current_checksum - snapshot.llm_cache_checksum) > 1e-3:
logger.warning(
"[Speculative] LLM cache checksum mismatch: expected=%.6f, actual=%.6f",
snapshot.llm_cache_checksum,
current_checksum,
)
else:
logger.debug(
"[Speculative] LLM cache checksum verified: %.6f",
current_checksum,
)
# 12. Clear the snapshot (it can only be restored once)
self._speculative_snapshot = None
logger.info(
"[Speculative] Restore completed: llm_cache %d -> %d, elapsed=%.3fs",
current_cache_length,
snapshot.llm_cache_length,
time.time() - snapshot.timestamp,
)
return True
except Exception as e:
import traceback
logger.error("[Speculative] Failed to restore snapshot: %s", e)
logger.error("[Speculative] Traceback: %s", traceback.format_exc())
return False
def has_speculative_snapshot(self) -> bool:
return getattr(self, "_speculative_snapshot", None) is not None
def clear_speculative_snapshot(self) -> None:
if hasattr(self, "_speculative_snapshot"):
self._speculative_snapshot = None
def _truncate_llm_cache(self, target_length: int) -> None:
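"""Truncate the LLM KV cache to `target_length` tokens along the sequence dimension.
Used by restore_speculative_snapshot to roll the cache back to its pre-generate length.
"""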
if self.llm_past_key_values is None:
return
cache = self._ensure_dynamic_cache()
if cache is None:
return
current_length = self._get_kv_cache_length(cache)
if current_length <= target_length:
return
# Truncate the cache of every layer
for layer_idx in range(len(cache.key_cache)):
if cache.key_cache[layer_idx].numel() > 0:
cache.key_cache[layer_idx] = cache.key_cache[layer_idx][:, :, :target_length, :].contiguous()
cache.value_cache[layer_idx] = cache.value_cache[layer_idx][:, :, :target_length, :].contiguous()
# Update cache metadata
cache.crop(target_length)
cache._seen_tokens = target_length
logger.debug("[Speculative] Truncated LLM cache: %d -> %d", current_length, target_length)
# ============== End of speculative snapshot / restore interface ==============
@torch.inference_mode()
def streaming_prefill(
self,
session_id,
msgs,
tokenizer=None, # deprecated
omni_mode=True,
max_slice_nums=None,
use_tts_template=True,
enable_thinking=False,
is_last_chunk=False, # for audio chunk, if is the last chunk, set to True
**kwargs,
):
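"""Prefill one streaming chunk into the LLM KV cache.
`msgs` must contain exactly one message whose "content" is a list of str,
PIL.Image.Image, or np.ndarray (16 kHz audio) items. Returns the prompt
string that was actually prefilled for this chunk.
"""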
assert session_id is not None, "session_id cannot be None"
self.is_first = self.session_id is None or session_id != self.session_id
if not hasattr(self, "processor") or self.processor is None:
self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
images = []
audios = []
assert len(msgs) == 1
copy_msgs = deepcopy(msgs)
msg = copy_msgs[0]
assert msg["role"] in ["system", "user", "assistant"]
is_not_system_prefill = msg["role"] != "system"
content = msg["content"]
cur_msgs = []
for j, c in enumerate(content):
if isinstance(c, Image.Image):
images.append(c)
cur_msgs.append("./")
elif isinstance(c, np.ndarray):
audios.append(c)
cur_msgs.append("")
elif isinstance(c, str):
cur_msgs.append(c)
else:
logger.error(f"Invalid content type: {c}, ignore it.")
cur_contents = "".join(cur_msgs) if omni_mode else "\n".join(cur_msgs)
if msg["role"] in ["system", "assistant"]:
self.new_user_msg = True
self.audio_past_key_values = None
if self.is_first:
self.reset_session(reset_token2wav_cache=False)
self.session_id = session_id
self.init_streaming_processor()
if msg["role"] == "user":
# No system prefill: this is the first segment of the first user turn.
# Build the prompt manually instead of apply_chat_template, to avoid an automatic <|im_end|>.
prompt = "<|im_start|>user\n" + cur_contents
self.new_user_msg = False  # mark that subsequent segments do not need the user prefix again
else:
# system or assistant prefill: use apply_chat_template
msg["content"] = cur_contents
prompt = self.processor.tokenizer.apply_chat_template(
copy_msgs,
tokenize=False,
add_generation_prompt=False,
use_tts_template=use_tts_template,
enable_thinking=enable_thinking,
)
add_special_tokens = True # add bos
else:
# Not the first prefill
if self.new_user_msg and msg["role"] == "user":
# First segment of a new user turn
if self.llm_generated:
# todo: when to set llm_generate_completed?
if self.llm_generate_completed:
prompt = "<|im_end|>\n<|im_start|>user\n" + cur_contents
else:
prompt = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents
else:
prompt = "<|im_start|>user\n" + cur_contents
self.new_user_msg = False
else:
# Subsequent segments of the same turn: use the content directly
prompt = cur_contents
add_special_tokens = False
# when first user audio prefill, ensure audio length satisfies FIRST_CHUNK_MS requirements
if is_not_system_prefill and len(audios) > 0 and self.audio_chunk_idx == 0:
assert len(audios) == 1, f"streaming mode only supports single audio, currently {len(audios)}"
first_chunk_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000)
if len(audios[0]) < first_chunk_samples:
pad_len = first_chunk_samples - len(audios[0])
audios[0] = np.concatenate([np.zeros(pad_len, dtype=audios[0].dtype), audios[0]])
model_inputs = self.processor(
[prompt],
[images],
[audios],
max_slice_nums=1 if max_slice_nums is None else max_slice_nums,
use_image_id=False,
chunk_input=True,
return_tensors="pt",
max_length=None,
sampling_rate=16000,
add_special_tokens=add_special_tokens,
online_streaming=is_not_system_prefill,
audio_chunk_idx=self.audio_chunk_idx,
is_last_chunk=is_last_chunk,
).to(self.device)
# DEBUG: print the mel feature checksum (to diagnose rollback inconsistency)
if len(audios) > 0 and is_not_system_prefill and hasattr(self, "_debug_prefill") and self._debug_prefill:
audio_feats = model_inputs.get("audio_features", None)
if audio_feats is not None and hasattr(audio_feats, "sum"):
mel_sum = audio_feats.sum().item()
mel_shape = audio_feats.shape
print(
f"[DEBUG prefill] audio_chunk_idx={self.audio_chunk_idx}, mel_sum={mel_sum:.6f}, mel_shape={mel_shape}"
)
else:
print(f"[DEBUG prefill] audio_chunk_idx={self.audio_chunk_idx}, audio_feats type={type(audio_feats)}")
if len(audios) > 0 and is_not_system_prefill:
self.audio_chunk_idx += 1
# 1. prepare input embeddings
model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs)
# get audio embedding with audio_past_key_values
# todo: should pass chunk_length=self.config.audio_chunk_length ?
inputs_embeds = self.get_omni_embedding(
model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=is_not_system_prefill
)
# DEBUG: print the inputs_embeds checksum
if len(audios) > 0 and is_not_system_prefill and hasattr(self, "_debug_prefill") and self._debug_prefill:
embed_sum = inputs_embeds.sum().item()
embed_shape = inputs_embeds.shape
print(f"[DEBUG prefill] inputs_embeds sum={embed_sum:.6f}, shape={embed_shape}")
if self.is_first:
self.audio_past_key_values = None # clean audio_past_key_values after first prefill
round_id = self._next_round_id
self._pending_round_id = round_id
chunk_type = "system" if msg["role"] == "system" else ("user" if msg["role"] == "user" else "assistant")
seq_len = inputs_embeds.shape[1]
self._enforce_text_window()
cache_length = self._get_kv_cache_length()
attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device)
# 2. do prefill
outputs = self.llm(
past_key_values=self.llm_past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=None,
use_cache=True,
return_dict=True,
)
self.llm_past_key_values = as_dynamic_cache(outputs["past_key_values"])
self._register_chunk(
seq_len,
chunk_type,
round_id=round_id,
input_ids=model_inputs["input_ids"],
tokenizer=self.processor.tokenizer,
)
self._enforce_text_window()
if self.force_rope_reindex:
self._force_reindex_all_cache()
return prompt
@torch.inference_mode()
def streaming_generate(
self,
session_id,
tokenizer=None, # deprecated
bos_input=None,
generate_audio=True,
audio_token_chunk_size=25, # 25 token/s
tts_sampling_params: TTSSamplingParams = TTSSamplingParams(),
max_new_tokens=256,
fn="chunk_generate",
enable_thinking=False,
use_tts_template=True,
do_sample=True,
enable_speculative_snapshot=False,
**kwargs,
):
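"""Generate the assistant turn after streaming_prefill, chunk by chunk.
With generate_audio=True this yields (waveform_chunk, new_text) pairs; otherwise it
yields incrementally decoded text via streaming_token_decoder. Pass
enable_speculative_snapshot=True to allow rolling back with
restore_speculative_snapshot() if VAD speculation fails.
"""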
# Save the speculative snapshot (before modifying any state).
# Used for VAD speculative execution: if speculation fails, restore_speculative_snapshot() rolls back.
# Enabled when enable_speculative_snapshot=True; skipped when False (saves a little overhead).
if enable_speculative_snapshot:
self._speculative_snapshot = self._save_speculative_snapshot()
# reset buf
self.new_user_msg = True
self.llm_generated = True
self.llm_generate_completed = False
self.audio_past_key_values = None
if not hasattr(self, "processor") or self.processor is None:
self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
# reset current turn generated token IDs
if hasattr(self, "_streaming_generated_token_ids"):
del self._streaming_generated_token_ids
# reset full generated text
if hasattr(self, "_last_streaming_text"):
del self._last_streaming_text
cache = self._ensure_dynamic_cache()
cache_length = self._get_kv_cache_length(cache)
host_round_id = self._pending_round_id
logger.info("streaming_generate kv cache length before= %s", cache_length)
## In half-duplex mode, every streaming_generate call re-initializes the streaming_processor and enters the next turn
self.init_streaming_processor()
# 1) llm generate token and hidden states per chunk=10, 2) tts generate audio token chunk per chunk=25, 3) yield 1 chunk audio token
def audio_chunk_generator(
bos_input,
tokenizer,
generate_audio,
tts_sampling_params,
max_new_tokens,
do_sample,
**kwargs,
):
generate_chunk_size = 10
if bos_input is None:
bos_input = "".join(
[
"<|im_end|>\n<|im_start|>assistant\n",
"" if enable_thinking else self.think_str.replace("\\n", "\n"),
"<|tts_bos|>" if use_tts_template else "",
]
)
bos_input_ids = tokenizer.encode(bos_input)
bos_input_ids = torch.tensor(bos_input_ids, dtype=torch.long, device=self.device).unsqueeze(0)
# DEBUG: print the state at the start of generation
_cache_len = self._get_kv_cache_length()
_cache_sum = self.llm_past_key_values.key_cache[0].sum().item() if self.llm_past_key_values else 0
# Inspect the values at the last few KV cache positions
_k_last = (
self.llm_past_key_values.key_cache[0][0, 0, -5:, :3].flatten().tolist()
if self.llm_past_key_values
else []
)
print(f"[DEBUG streaming_generate] cache_len={_cache_len}, cache_sum={_cache_sum:.6f}, k_last={_k_last}")
bos_input_embeds = self.llm.get_input_embeddings()(bos_input_ids)
generation_inputs_embeds = bos_input_embeds
generated_ids = torch.empty((1, 0), dtype=torch.long, device=self.device)
num_chunks_decode = (max_new_tokens + generate_chunk_size - 1) // generate_chunk_size
conditions = []
# generate chunk by chunk, each chunk has 10 tokens, each chunk takes last hidden states, and pass tokens to tts
llm_streaming_generator = ChunkPrefillChunkGenerate(
model=self.llm,
tokenizer=tokenizer,
terminators=["<|tts_eos|>", "<|im_end|>", ""],
)
if generate_audio:
logits_warpers, logits_processors = gen_logits(
num_code=self.tts.config.num_audio_tokens,
repetition_penalty=tts_sampling_params.repetition_penalty,
top_p=tts_sampling_params.top_p,
top_k=tts_sampling_params.top_k,
)
tts_streaming_generator = TTSStreamingGenerator(
model=self.tts,
temperature=tts_sampling_params.temperature,
eos_token=torch.tensor(
[self.tts.config.num_audio_tokens - 1],
dtype=torch.long,
device=self.tts.device,
),
chunk_size=audio_token_chunk_size, # s3tokenizer 1s = 25token
tts_last_turn_tokens=self.tts_last_turn_tokens,
logits_processors=logits_processors,
logits_warpers=logits_warpers,
)
# LLM chunk generate outer loop
for chunk_idx in range(num_chunks_decode):
is_first_generate_chunk = chunk_idx == 0
if fn == "chunk_generate":
output = llm_streaming_generator.chunk_generate(
inputs_embeds=generation_inputs_embeds,
past_key_values=self.llm_past_key_values,
is_first_generate_chunk=is_first_generate_chunk,
return_hidden_states=True,
chunk_size=generate_chunk_size + 1 * is_first_generate_chunk,
do_sample=do_sample,
temperature=kwargs.get("temperature", 0.7),
top_p=kwargs.get("top_p", 0.8),
top_k=kwargs.get("top_k", 100),
repetition_penalty=kwargs.get("repetition_penalty", 1.02),
all_input_ids=generated_ids,
)
else:
output = llm_streaming_generator.chunk_generate_hf(
inputs_embeds=generation_inputs_embeds,
past_key_values=self.llm_past_key_values,
is_first_generate_chunk=is_first_generate_chunk,
return_hidden_states=True,
chunk_size=generate_chunk_size + 1 * is_first_generate_chunk,
do_sample=do_sample,
**kwargs,
)
if output.chunk_token_ids is None:
break
# DEBUG: print the tokens generated in the first chunk
if chunk_idx == 0:
print(f"[DEBUG streaming_generate] first_chunk_tokens={output.chunk_token_ids.tolist()}")
if is_first_generate_chunk:
if generate_audio:
spk_emb = torch.empty(
(bos_input_embeds.shape[0], 0, bos_input_embeds.shape[2]),
dtype=bos_input_embeds.dtype,
device=bos_input_embeds.device,
)
tts_streaming_generator.spk_emb = spk_emb
if output.finished:
yield_chunk_token_ids = output.chunk_token_ids
else:
# the first chunk generated chunk_size + 1 tokens, we only take the first chunk_size tokens,
# the last token is not prefilled, and last hidden states is not obtained
yield_chunk_token_ids = output.chunk_token_ids[:, :-1]
elif output.finished:
yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids], dim=1)
else:
# in the chunk that is not the first chunk, we need to add the token at the end of the previous chunk,
# it is not prefilled into the model to get last hidden states
# similarly, the last generated token of subsequent chunks is not prefilled, and last hidden states is not obtained,
# so it is not passed out
yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids[:, :-1]], dim=1)
if not generate_audio:
chunk_generated_text = tokenizer.decode(yield_chunk_token_ids[0])
yield yield_chunk_token_ids, output.finished
else:
# TTS inner loop
# dense connection here is hardcoded to use text-hidden merged as condition
llm_embeds = self.tts.emb_text(yield_chunk_token_ids)
hidden_embeds = output.last_hidden_states
hidden_embeds = self.tts.projector_semantic(hidden_embeds)
if self.tts.config.normalize_projected_hidden:  # should be enabled by default
hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1)
tts_embeds = llm_embeds + hidden_embeds
conditions.append(tts_embeds)
# Store token IDs instead of decoded text to avoid UTF-8 multi-byte character truncation
if not hasattr(self, "_streaming_generated_token_ids"):
self._streaming_generated_token_ids = []
self._streaming_generated_token_ids.extend(yield_chunk_token_ids[0].tolist())
# generation is buffered: each yield returns exactly 25 audio tokens,
# except the last audio chunk, which returns a variable number of tokens in [0, 25]
tts_generator = tts_streaming_generator.generate_with_buffer(
condition=tts_embeds, text_finished=output.finished
)
for audio_token_chunk, is_last_audio_chunk in tts_generator:
yield audio_token_chunk, is_last_audio_chunk
generated_ids = torch.cat([generated_ids, output.chunk_token_ids], dim=1)
generation_inputs_embeds = output.current_inputs_embeds
self.llm_past_key_values = output.past_key_values
if output.finished:
if generate_audio:
self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens
break
# IMPORTANT: Flush remaining TTS buffer when LLM generation ends
# This handles BOTH cases:
# 1. LLM finished with terminator (output.finished=True) - buffer may still have tokens
# 2. LLM hit max chunks limit (output.finished=False) - buffer definitely has tokens
if generate_audio:
if len(tts_streaming_generator._token_buffer) > 0:
batch = torch.cat(tts_streaming_generator._token_buffer, dim=1)
yield batch, True
tts_streaming_generator._token_buffer = []
if generate_audio:
if hasattr(self, "_streaming_generated_token_ids"):
try:
self._last_streaming_text = tokenizer.decode(self._streaming_generated_token_ids)
assistant_input_ids = self._encode_text(tokenizer=tokenizer, text=self._last_streaming_text)
self._finalize_round(
round_id=host_round_id, cache_before=cache_length, assistant_input_ids=assistant_input_ids
)
except Exception:
self._last_streaming_text = None
else:
self._last_streaming_text = None
yield None, None
else:
return
# iter for generating text chunk and audio chunk
audio_chunk_generator_iter = audio_chunk_generator(
bos_input=bos_input,
tokenizer=self.processor.tokenizer,
generate_audio=generate_audio,
tts_sampling_params=tts_sampling_params,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
**kwargs,
)
if generate_audio:
if self.tts.config.audio_tokenizer_type == "s3tokenizer_step_audio":
self.tts.audio_tokenizer.stream_cache = torch_clone_recursive(self.token2wav_cache["flow_cache_base"])
self.tts.audio_tokenizer.hift_cache_dict = torch_clone_recursive(
self.token2wav_cache["hift_cache_base"]
)
# pre-insert a few prefix 4218 (silence) tokens; each token corresponds to 0.04 s,
# so the 3 tokens used here introduce ~0.12 s of leading silence
buffer = [4218] * 3
pre_lookahead = 3
CHUNK_SIZE = 25
chunk_idx = 0
prev_text_len = 0 # track text position for streaming text output
for audio_token_chunk, is_last_audio_chunk in audio_chunk_generator_iter:
if audio_token_chunk is None:
break
buffer += audio_token_chunk.reshape(-1).tolist()
if len(buffer) >= CHUNK_SIZE + pre_lookahead:
waveform_chunk = self.tts.audio_tokenizer.stream(
buffer[: CHUNK_SIZE + pre_lookahead],
prompt_wav=None,
last_chunk=is_last_audio_chunk,
return_waveform=True,
)
waveform_chunk = torch.from_numpy(waveform_chunk)
# get new text chunk corresponding to this waveform
# Decode from accumulated token IDs to avoid UTF-8 multi-byte truncation
new_text = ""
if hasattr(self, "_streaming_generated_token_ids"):
current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids)
# Filter out trailing replacement characters (incomplete UTF-8 sequences)
safe_end = len(current_text)
while safe_end > 0 and current_text[safe_end - 1] == "\ufffd":
safe_end -= 1
safe_text = current_text[:safe_end]
new_text = safe_text[prev_text_len:]
prev_text_len = len(safe_text)
yield waveform_chunk, new_text
buffer = buffer[CHUNK_SIZE:]
chunk_idx += 1
# flush rest
if len(buffer) > 0:
waveform_chunk = self.tts.audio_tokenizer.stream(
buffer,
prompt_wav=None,
last_chunk=True,
return_waveform=True,
)
waveform_chunk = torch.from_numpy(waveform_chunk)
# get remaining new text for the final chunk
# Final chunk: decode all remaining text without filtering
new_text = ""
if hasattr(self, "_streaming_generated_token_ids"):
current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids)
new_text = current_text[prev_text_len:]
prev_text_len = len(current_text)
yield waveform_chunk, new_text
# TODO: the buffer may be empty while text is not; should we flush text without a waveform?
else:
raise NotImplementedError(f"not supported audio tokenizer: {self.tts.config.audio_tokenizer_type}")
else:
# For text-only generation, decode tokens and handle partial multi-byte characters
yield from streaming_token_decoder(
audio_chunk_generator_iter,
self.processor.tokenizer,
skip_special_tokens=False,
)
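# Half-duplex usage sketch (illustrative only; `audio_chunks` and `play` are placeholders
# for the caller's capture and playback, not part of this module):
#   model.streaming_prefill(session_id="s1", msgs=[{"role": "system", "content": ["<system prompt>"]}])
#   for i, chunk in enumerate(audio_chunks):  # np.ndarray chunks, 16 kHz mono
#       model.streaming_prefill(session_id="s1", msgs=[{"role": "user", "content": [chunk]}],
#                               is_last_chunk=(i == len(audio_chunks) - 1))
#   for waveform_chunk, new_text in model.streaming_generate(session_id="s1", generate_audio=True):
#       play(waveform_chunk)  # new_text is the incremental transcript of the spoken reply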
class MiniCPMODuplex:
def __init__(
self,
name_or_path: str,
duplex_config: Optional[MiniCPMODuplexConfig] = None,
generate_audio: bool = True,
ls_mode: str = "explicit",
device: str = "cuda",
pt_path: Optional[str] = None,
**kwargs,
):
"""Initialize MiniCPMODuplex.
Args:
name_or_path: Path to the pretrained model or model identifier.
duplex_config: Optional MiniCPMODuplexConfig. If provided, overrides other params.
generate_audio: Whether to generate audio output.
ls_mode: Listen/Speak mode, e.g., "explicit".
device: Device to load the model on.
pt_path: Optional path to additional checkpoint weights.
**kwargs: Additional generation config parameters (overrides duplex_config if provided).
"""
self.session_logs = []
self.session_start_time = None
self.log_file_path = None
self.name_or_path = name_or_path
if duplex_config is not None:
self.duplex_config = duplex_config
self.generate_audio = kwargs.get("generate_audio", duplex_config.generate_audio)
self.ls_mode = kwargs.get("ls_mode", duplex_config.ls_mode)
attn_implementation = kwargs.get("attn_implementation", duplex_config.attn_implementation)
else:
self.duplex_config = None
self.generate_audio = generate_audio
self.ls_mode = ls_mode
attn_implementation = kwargs.get("attn_implementation", "flash_attention_2")
self.device = device
from transformers import AutoConfig
from transformers import AutoTokenizer
from .processing_minicpmo import MiniCPMOProcessor
from .stream_decoder import StreamDecoder
self.processor = MiniCPMOProcessor.from_pretrained(name_or_path, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(name_or_path, trust_remote_code=True)
self.processor.tokenizer = self.tokenizer
config = AutoConfig.from_pretrained(name_or_path, trust_remote_code=True)
vision_batch_size = kwargs.pop("vision_batch_size", None)
audio_pool_step = kwargs.pop("audio_pool_step", None)
audio_chunk_length = kwargs.pop("audio_chunk_length", None)
max_slice_nums = kwargs.pop("max_slice_nums", None)
if vision_batch_size is not None and hasattr(config, "vision_batch_size"):
config.vision_batch_size = vision_batch_size
if audio_pool_step is not None and hasattr(config, "audio_pool_step"):
config.audio_pool_step = audio_pool_step
if audio_chunk_length is not None and hasattr(config, "audio_chunk_length"):
config.audio_chunk_length = audio_chunk_length
if max_slice_nums is not None and hasattr(config.slice_config, "max_slice_nums"):
config.slice_config.max_slice_nums = max_slice_nums
self.model = MiniCPMO.from_pretrained(
name_or_path, config=config, trust_remote_code=True, attn_implementation=attn_implementation
)
self.model.to(torch.bfloat16)
self.model.processor = self.processor
if pt_path is not None:
logger.info(f"Loading checkpoint from {pt_path}")
state_dict = torch.load(pt_path, map_location="cpu")
info = self.model.load_state_dict(state_dict, strict=False)
logger.warning(info)
del state_dict
self.model.eval().to(device=device)
self.model.init_tts(
streaming=True,
enable_float16=kwargs.get("enable_float16", False),
n_timesteps=kwargs.get("n_timesteps", 10),
)
self.break_event = threading.Event()
self.session_stop_event = threading.Event()
# llm generation_config - from duplex_config or defaults
self.max_new_speak_tokens_per_chunk = self._get_value(
"max_new_speak_tokens_per_chunk", duplex_config, kwargs, 20
)
self.text_repetition_penalty = self._get_value("text_repetition_penalty", duplex_config, kwargs, 1.05)
self.temperature = self._get_value("temperature", duplex_config, kwargs, 0.7)
self.top_k = self._get_value("top_k", duplex_config, kwargs, 20)
self.top_p = self._get_value("top_p", duplex_config, kwargs, 0.8)
self.text_repetition_window_size = self._get_value("text_repetition_window_size", duplex_config, kwargs, 512)
self.listen_prob_scale = self._get_value("listen_prob_scale", duplex_config, kwargs, 1.0)
self.force_listen_count = self._get_value("force_listen_count", duplex_config, kwargs, 0)
# tts generation_config
tts_temp_value = self._get_value("tts_temperature", duplex_config, kwargs, 0.8)
self.tts_temperature = torch.tensor([tts_temp_value], dtype=torch.float, device=self.device)
self.tts_repetition_penalty = self._get_value("tts_repetition_penalty", duplex_config, kwargs, 1.05)
# stream config
self.CHUNK_MS = self._get_value("chunk_ms", duplex_config, kwargs, 1000)
self.FIRST_CHUNK_MS = self._get_value("first_chunk_ms", duplex_config, kwargs, 1035)
self.CNN_REDUNDANCY_MS = self._get_value("cnn_redundancy_ms", duplex_config, kwargs, 20)
self.SAMPLE_RATE = self._get_value("sample_rate", duplex_config, kwargs, 16000)
self.model.CHUNK_MS = self.CHUNK_MS
self.model.FIRST_CHUNK_MS = self.FIRST_CHUNK_MS
self.model.CNN_REDUNDANCY_MS = self.CNN_REDUNDANCY_MS
self.model.SAMPLE_RATE = self.SAMPLE_RATE
# special tokens
self.unit_token_id = self.tokenizer.convert_tokens_to_ids("")
self.image_start_token_id = self.tokenizer.convert_tokens_to_ids("<image>")
self.image_end_token_id = self.tokenizer.convert_tokens_to_ids("</image>")
self.slice_start_token_id = self.tokenizer.convert_tokens_to_ids("<slice>")
self.slice_end_token_id = self.tokenizer.convert_tokens_to_ids("</slice>")
self.listen_token_id = self.tokenizer.convert_tokens_to_ids("<|listen|>")
self.speak_token_id = self.tokenizer.convert_tokens_to_ids("<|speak|>")
self.tts_bos_token_id = self.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
self.tts_eos_token_id = self.tokenizer.convert_tokens_to_ids("<|tts_eos|>")
self.chunk_eos_token_id = self.tokenizer.convert_tokens_to_ids("<|chunk_eos|>")
self.chunk_tts_eos_token_id = self.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>")
self.turn_eos_token_id = self.tokenizer.convert_tokens_to_ids("<|turn_eos|>")
self.chunk_terminator_token_ids = [self.listen_token_id, self.chunk_eos_token_id, self.chunk_tts_eos_token_id]
self.turn_terminator_token_ids = [self.turn_eos_token_id]
self.chunk_speak_token_ids = [self.speak_token_id]
self.tts_pad_id = self.tokenizer.convert_tokens_to_ids("<|tts_pad|>")
bad_token_ids = getattr(self.tokenizer, "bad_token_ids", [])
self.forbidden_token_ids = [self.tts_pad_id] + list(bad_token_ids)
self.decoder = StreamDecoder(
llm=self.model.llm, tokenizer=self.tokenizer, forbidden_token_ids=self.forbidden_token_ids
)
# Configure sliding-window parameters
from .stream_decoder import DuplexWindowConfig
# Sliding-window mode: "off" / "basic" / "context"
sliding_window_mode = self._get_value("sliding_window_mode", duplex_config, kwargs, "off")
# Sliding-window parameters without context preservation
basic_window_high_tokens = self._get_value("basic_window_high_tokens", duplex_config, kwargs, 4000)
basic_window_low_tokens = self._get_value("basic_window_low_tokens", duplex_config, kwargs, 3500)
# Sliding-window parameters with context preservation
context_previous_max_tokens = self._get_value("context_previous_max_tokens", duplex_config, kwargs, 500)
context_max_units = self._get_value("context_max_units", duplex_config, kwargs, 24)
self.decoder.set_window_config(
DuplexWindowConfig(
sliding_window_mode=sliding_window_mode,
basic_window_high_tokens=basic_window_high_tokens,
basic_window_low_tokens=basic_window_low_tokens,
context_previous_max_tokens=context_previous_max_tokens,
context_max_units=context_max_units,
)
)
# Enable/disable the sliding window according to the mode
window_enabled = sliding_window_mode != "off"
self.decoder.set_window_enabled(window_enabled)
logger.info(
"[Duplex] Sliding window: mode=%s, high=%d, low=%d, prev_max=%d, max_units=%d",
sliding_window_mode,
basic_window_high_tokens,
basic_window_low_tokens,
context_previous_max_tokens,
context_max_units,
)
self.tts_logits_processors = None
self.tts_eos_token = None
if self.generate_audio:
self.tts_logits_processors = gen_logits(
num_code=self.model.tts.config.num_audio_tokens,
repetition_penalty=self.tts_repetition_penalty,
)
self.tts_eos_token = torch.tensor(
[self.model.tts.config.num_audio_tokens - 1],
dtype=torch.long,
device=self.device,
)
# for offline_generate
self.step_map = {
"sp_text": ["", "", "", "", "<|listen|>"],
"image": [""],
"audio": ["<|audio|>"],
}
self.loop = []
self._reset_streaming_state()
import gc
gc.collect()
torch.cuda.empty_cache()
@staticmethod
def _get_value(key, config, kwargs, default):
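"""Resolve a config value with precedence: explicit kwargs > duplex_config attribute > default."""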
return kwargs.get(key, getattr(config, key, default))
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
device: str = "cuda",
pt_path: Optional[str] = None,
**kwargs,
) -> "MiniCPMODuplex":
"""Load MiniCPMODuplex from a pretrained model directory.
This method loads the duplex_config.json if available, and uses it to configure
the duplex model. Any kwargs provided will override the config values.
Args:
pretrained_model_name_or_path: Path to the model directory.
device: Device to load the model on.
pt_path: Optional path to additional checkpoint weights.
**kwargs: Override parameters for generation config.
Returns:
MiniCPMODuplex instance.
"""
try:
duplex_config = MiniCPMODuplexConfig.from_pretrained(pretrained_model_name_or_path)
except FileNotFoundError:
duplex_config = None
return cls(
name_or_path=pretrained_model_name_or_path,
duplex_config=duplex_config,
device=device,
pt_path=pt_path,
**kwargs,
)
def save_duplex_config(self, save_directory: str):
"""Save the current duplex configuration to a directory.
Args:
save_directory: Path to save the duplex_config.json.
"""
config = MiniCPMODuplexConfig(
generate_audio=self.generate_audio,
ls_mode=self.ls_mode,
max_new_speak_tokens_per_chunk=self.max_new_speak_tokens_per_chunk,
text_repetition_penalty=self.text_repetition_penalty,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
text_repetition_window_size=self.text_repetition_window_size,
listen_prob_scale=self.listen_prob_scale,
tts_temperature=self.tts_temperature.item(),
tts_repetition_penalty=self.tts_repetition_penalty,
chunk_ms=self.CHUNK_MS,
first_chunk_ms=self.FIRST_CHUNK_MS,
cnn_redundancy_ms=self.CNN_REDUNDANCY_MS,
sample_rate=self.SAMPLE_RATE,
)
config.save_pretrained(save_directory)
def set_break_event(self):
self.break_event.set()
def clear_break_event(self):
self.break_event.clear()
def set_session_stop(self):
self.session_stop_event.set()
self.break_event.set()
def clear_session_stop(self):
self.session_stop_event.clear()
def is_break_set(self) -> bool:
return self.break_event.is_set()
def is_session_stop_set(self) -> bool:
return self.session_stop_event.is_set()
def _init_token2wav_cache(self, prompt_wav_path: str):
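"""Build the base flow/hift caches from the prompt wav; they are re-cloned at every new turn."""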
self.model.tts.audio_tokenizer.cache = None
flow_cache, hift_cache = self.model.tts.audio_tokenizer.set_stream_cache(prompt_wav_path)
self.flow_cache_base = torch_clone_recursive(flow_cache)
self.hift_cache_base = torch_clone_recursive(hift_cache)
self.pre_lookahead = int(self.model.tts.audio_tokenizer.flow.pre_lookahead_len)
self.token2wav_initialized = True
def _reset_token2wav_for_new_turn(self):
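"""Reset the token2wav streaming caches to the base state and re-seed the silence-token prefix buffer."""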
if self.token2wav_initialized:
self.model.tts.audio_tokenizer.stream_cache = torch_clone_recursive(self.flow_cache_base)
self.model.tts.audio_tokenizer.hift_cache_dict = torch_clone_recursive(self.hift_cache_base)
self.token2wav_buffer = [4218] * 3 # silence token prefix
def _reset_streaming_state(self):
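"""Reset all per-session streaming state: prefill counters, decoded ids, TTS state, token2wav state, and the audio buffer."""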
self.audio_chunk_idx = 0
self.current_turn_ended = True
self.speak_count = 0
self.res_ids = []
self.total_ids = []
self.total_hidden = []
# TTS state
self.tts_text_start_pos = 0
self.tts_past_key_values = None
self.tts_current_turn_start_time = None
# token2wav state
self.token2wav_initialized = False
self.token2wav_buffer = []
self.flow_cache_base = None
self.hift_cache_base = None
# Audio prefill state
self.audio_buffer = np.array([], dtype=np.float32)
self.pending_logits: Optional[torch.Tensor] = None
self.current_mode: Optional[str] = None
# Force listen state
self._streaming_generate_count = 0
def prepare(
self,
prefix_system_prompt: Optional[str] = None,
suffix_system_prompt: Optional[str] = None,
ref_audio: Optional[np.ndarray] = None,
prompt_wav_path: Optional[str] = None,
mode: Literal["omni", "video", "audio"] = "omni",
context_previous_marker: str = "\n\nprevious: ",
):
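"""Prepare a new duplex session.
Resets streaming state, optionally primes token2wav from prompt_wav_path, prefills the
system prompt (prefix, optional reference audio, suffix) and registers its sliding-window
protection. Returns the full system prompt string that was prefilled ("" if none).
"""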
self.clear_break_event()
self.clear_session_stop()
self.session_start_time = time.time()
self._reset_streaming_state()
self.decoder.reset()
self.model.init_streaming_processor()
if prompt_wav_path is not None and prompt_wav_path and self.generate_audio:
self._init_token2wav_cache(prompt_wav_path)
self._reset_token2wav_for_new_turn()
self._update_loop(mode)
# Prefill system prompt prefix
if prefix_system_prompt:
tokens = self.tokenizer.encode(prefix_system_prompt, add_special_tokens=False)
for token_id in tokens:
self.decoder.feed(self.decoder.embed_token(token_id))
# Prefill reference audio
if ref_audio is not None:
data = self.processor.process_audio([ref_audio])
embeds_nested = self.model.get_audio_embedding(data, chunk_length=self.model.config.audio_chunk_length)
embeds = torch.cat([t for g in embeds_nested for t in g], dim=0) if embeds_nested else None
if embeds is not None:
self.decoder.feed(embeds)
# Register the protected system prompt length (keeps this part from being evicted by the sliding window)
if prefix_system_prompt or suffix_system_prompt or ref_audio is not None:
logger.info("[Duplex] prepare: registering system prompt protection")
if self.decoder._window_config.sliding_window_mode == "context":
# Context-preserve mode:
# Layout at init:            [prefix] [suffix] [units...]
# Layout after first slide:  [prefix] [context_previous_marker + content] [suffix] [units...]
# Register the prefix length first, then feed the suffix
self._prefix_system_prompt = prefix_system_prompt
self._suffix_system_prompt = suffix_system_prompt
self._ref_audio = ref_audio
# Get the suffix token ids
suffix_token_ids = []
if suffix_system_prompt:
suffix_token_ids = self.tokenizer.encode(suffix_system_prompt, add_special_tokens=False)
# Register (at this point the cache contains only the prefix, no suffix and no previous context yet)
self.decoder.register_system_prompt_with_context(
suffix_token_ids=suffix_token_ids,
context_previous_marker=context_previous_marker,  # added dynamically at the first slide
)
# Now feed the suffix
for token_id in suffix_token_ids:
self.decoder.feed(self.decoder.embed_token(token_id))
logger.info(
"[Duplex] prepare: context-preserve mode, prefix=%d, suffix=%d tokens, marker='%s'",
self.decoder._preserve_prefix_length,
len(suffix_token_ids),
context_previous_marker.replace("\n", "\\n"),
)
else:
# Non-context-preserve mode: feed the suffix first, then register the total length
if suffix_system_prompt:
tokens = self.tokenizer.encode(suffix_system_prompt, add_special_tokens=False)
for token_id in tokens:
self.decoder.feed(self.decoder.embed_token(token_id))
self.decoder.register_system_prompt()
if prefix_system_prompt or suffix_system_prompt:
if ref_audio is not None:
full_prompt = (prefix_system_prompt or "") + "[音频嵌入]" + (suffix_system_prompt or "")
else:
full_prompt = (prefix_system_prompt or "") + (suffix_system_prompt or "")
return full_prompt
return ""
@torch.no_grad()
def streaming_prefill(
self,
audio_waveform: Optional[np.ndarray] = None,
frame_list: Optional[list] = None,
max_slice_nums: Union[int, List[int]] = 1,
):
"""Streaming prefill - called once per second, processing audio/video data
Args:
audio_waveform: audio waveform data
frame_list: image frame list
max_slice_nums: maximum number of slices for HD image encoding (default 1, no slicing)
Can be an int (same for all images) or a list matching frame_list length
Process:
0. determine mode based on input: AUDIO / VISION / OMNI
1. feed token
2. get and feed image embed (if frame_list) - return pending logits in VISION MODE
3. get and feed audio embed (if audio_waveform) - return pending logits in AUDIO/OMNI MODE
Returns:
dict with keys:
- success: bool
- cost_vision_process: float (image processing time)
- cost_vision_embed: float (vision embedding time)
- cost_vision_feed: float (vision feed time)
- cost_audio_process: float (audio processing time)
- cost_audio_embed: float (audio embedding time)
- cost_audio_feed: float (audio feed time)
- cost_all: float (total time)
"""
start_time = time.time()
cost_vision_process = 0.0
cost_vision_embed = 0.0
cost_vision_feed = 0.0
cost_audio_process = 0.0
cost_audio_embed = 0.0
cost_audio_feed = 0.0
def _make_result(success, reasons=""):
reason = reasons
if isinstance(reasons, list):
reason = "; ".join(reasons)
return {
"success": success,
"reason": reason,
"cost_vision_process": cost_vision_process,
"cost_vision_embed": cost_vision_embed,
"cost_vision_feed": cost_vision_feed,
"cost_audio_process": cost_audio_process,
"cost_audio_embed": cost_audio_embed,
"cost_audio_feed": cost_audio_feed,
"cost_all": time.time() - start_time,
}
if self.is_session_stop_set() or self.is_break_set():
return _make_result(False)
has_frames = frame_list is not None and len(frame_list) > 0
has_audio = audio_waveform is not None and len(audio_waveform) > 0
if has_frames and has_audio:
mode = "OMNI"
elif has_frames:
mode = "VISION"
elif has_audio:
mode = "AUDIO"
else:
return _make_result(False)
self.pending_logits = None
# Sliding window: record the unit start position
logger.info(
"[Duplex] streaming_prefill: mode=%s, has_frames=%s, has_audio=%s, starting unit",
mode,
has_frames,
has_audio,
)
self.decoder.register_unit_start()
# Step 1: Feed token
self.decoder.feed(self.decoder.embed_token(self.unit_token_id))
# Step 2: process image
if has_frames:
t0 = time.time()
# Normalize max_slice_nums to a list matching frame_list length
if isinstance(max_slice_nums, int):
max_slice_nums_list = [max_slice_nums] * len(frame_list)
else:
max_slice_nums_list = list(max_slice_nums)
if len(max_slice_nums_list) != len(frame_list):
raise ValueError(
f"max_slice_nums list length ({len(max_slice_nums_list)}) "
f"must match frame_list length ({len(frame_list)})"
)
# Check if all max_slice_nums are the same (can use batch processing)
all_same = len(set(max_slice_nums_list)) == 1
if all_same:
# All images use the same max_slice_nums, use batch processing
processed_frames = self.processor.process_image(frame_list, max_slice_nums=max_slice_nums_list[0])
if self.device:
processed_frames = processed_frames.to(self.device)
else:
# Different max_slice_nums per image, process individually and merge
all_pixel_values = []
all_tgt_sizes = []
for frame, max_slices in zip(frame_list, max_slice_nums_list):
pf = self.processor.process_image([frame], max_slice_nums=max_slices)
if self.device:
pf = pf.to(self.device)
# pf["pixel_values"][0] is the list of slices for this image
all_pixel_values.extend(pf["pixel_values"][0])
# pf["tgt_sizes"][0] is the array of target sizes for this image's slices
if hasattr(pf["tgt_sizes"][0], "tolist"):
all_tgt_sizes.extend(pf["tgt_sizes"][0].tolist())
else:
all_tgt_sizes.extend(list(pf["tgt_sizes"][0]))
# Reconstruct processed_frames with merged data
processed_frames = {
"pixel_values": [all_pixel_values],
"tgt_sizes": [torch.tensor(all_tgt_sizes) if all_tgt_sizes else []],
}
cost_vision_process = time.time() - t0
t0 = time.time()
# Get vision embeddings for all images (each may have multiple slices)
# vision_hidden_states is a list, one entry per input image
# Each entry contains embeddings for [source_image, slice_1, slice_2, ...]
vision_hidden_states = self.model.get_vision_embedding(processed_frames)
cost_vision_embed = time.time() - t0
if vision_hidden_states is not None and len(vision_hidden_states) > 0:
t0 = time.time()
# vision_hidden_states[0] contains ALL slices from ALL images (flattened)
# Shape: [total_slices, 64, D] where total_slices = sum of slices across all images
# We need to know how many slices each image has to correctly group them
# Calculate slice counts for each image using get_sliced_grid (lightweight, no actual slicing)
slice_counts = [] # e.g., [5, 9] means img1 has 5 slices (1 source + 4 HD), img2 has 9
for frame_idx, frame in enumerate(frame_list):
max_slices = max_slice_nums_list[frame_idx]
if hasattr(frame, "size"):
# get_sliced_grid returns [M, N] grid or None if no slicing needed
# Total images = 1 (source) + M * N (HD slices)
grid = self.processor.image_processor.get_sliced_grid(
frame.size, max_slices, nerver_split=False
)
if grid is not None:
slice_counts.append(1 + grid[0] * grid[1]) # 1 source + M*N slices
else:
slice_counts.append(1) # No slicing, only source image
else:
slice_counts.append(1) # Default: single image, no slicing
# Get the flattened embeddings tensor
# vision_hidden_states is a list with one element (the batch)
# vision_hidden_states[0] shape: [total_slices, 64, D]
all_embeds = vision_hidden_states[0]
# Collect all feed operations first, then execute
# This allows us to identify the last token for VISION mode logits
feed_operations = [] # List of (embed, is_last_for_vision_mode)
embed_idx = 0 # Current index in all_embeds
for img_idx, num_slices in enumerate(slice_counts):
if num_slices == 0:
continue
# First embedding is always the source image (downsampled overview)
# Feed token
feed_operations.append((self.decoder.embed_token(self.image_start_token_id), False))
# Feed source image embedding (shape: [64, D])
feed_operations.append((all_embeds[embed_idx], False))
# Feed token
feed_operations.append((self.decoder.embed_token(self.image_end_token_id), False))
embed_idx += 1
# Remaining embeddings are HD slices (if num_slices > 1)
if num_slices > 1:
for slice_i in range(1, num_slices):
# Feed token
feed_operations.append((self.decoder.embed_token(self.slice_start_token_id), False))
# Feed slice embedding (shape: [64, D])
feed_operations.append((all_embeds[embed_idx], False))
# Feed token
feed_operations.append((self.decoder.embed_token(self.slice_end_token_id), False))
embed_idx += 1
# Mark the last operation for VISION mode logits
if feed_operations:
feed_operations[-1] = (feed_operations[-1][0], True)
# Execute all feed operations
for embed, is_last in feed_operations:
if mode == "VISION" and is_last:
# Get logits from the last token
self.pending_logits, _ = self.decoder.feed(embed, return_logits=True)
else:
self.decoder.feed(embed)
# For OMNI MODE, no pending logits needed here (wait for audio)
cost_vision_feed = time.time() - t0
# Step 3: process audio (if any)
if has_audio:
# accumulate audio to buffer
self.audio_buffer = np.concatenate([self.audio_buffer, audio_waveform])
# calculate required audio length
if self.audio_chunk_idx == 0:
required_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000)
if len(self.audio_buffer) < required_samples:
padding_samples = required_samples - len(self.audio_buffer)
padding = np.zeros(padding_samples, dtype=np.float32)
self.audio_buffer = np.concatenate([padding, self.audio_buffer])
else:
required_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000)
need_samples = self.processor.get_streaming_chunk_size()
if len(self.audio_buffer) < need_samples:
return _make_result(False, f"音频不足: 需要 {need_samples} 样本, 只有 {len(self.audio_buffer)}")
audio_chunk = self.audio_buffer[:need_samples]
t0 = time.time()
batch_feature = self.processor.process_audio_streaming(
audio_chunk,
reset=False,
return_batch_feature=True,
)
if batch_feature is None or batch_feature.audio_features.shape[-1] == 0:
return _make_result(False, "流式音频处理返回空")
# metadata
batch_feature.chunk_idx = self.audio_chunk_idx
batch_feature.use_extra_context = True
batch_feature.prefix_extra_frames = 0 if self.audio_chunk_idx == 0 else 2
batch_feature.suffix_extra_frames = 2
batch_feature = batch_feature.to(self.device)
cost_audio_process = time.time() - t0
t0 = time.time()
embeds_nested = self.model.get_audio_embedding_streaming(
batch_feature,
use_extra_context=batch_feature.use_extra_context,
prefix_extra_frames=batch_feature.prefix_extra_frames,
suffix_extra_frames=batch_feature.suffix_extra_frames,
)
audio_embeds = torch.cat([t for g in embeds_nested for t in g], dim=0)
cost_audio_embed = time.time() - t0
t0 = time.time()
self.pending_logits, _ = self.decoder.feed(audio_embeds, return_logits=True)
cost_audio_feed = time.time() - t0
if self.audio_chunk_idx == 0:
cfg = self.processor._streaming_mel_processor.get_config()
consumed_ms = int(cfg.get("effective_first_chunk_ms", self.FIRST_CHUNK_MS))
consumed_samples = int(consumed_ms * self.SAMPLE_RATE / 1000)
else:
consumed_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000)
self.audio_buffer = self.audio_buffer[consumed_samples:]
self.audio_chunk_idx += 1
self.current_mode = mode
# for VISION mode, advance the chunk counter manually (AUDIO and OMNI modes already advanced it in the audio branch above)
if mode == "VISION":
self.audio_chunk_idx += 1
return _make_result(True)
@torch.no_grad()
def streaming_generate(
self,
prompt_wav_path=None,
max_new_speak_tokens_per_chunk=20,
decode_mode: str = "sampling",
temperature=0.7,
top_k=20,
top_p=0.8,
listen_prob_scale=1.0,
listen_top_k=None,
text_repetition_penalty=1.05,
text_repetition_window_size=512,
):
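"""Decode one unit from the pending logits produced by streaming_prefill.
Either listens (returns silence) or speaks (runs the LLM token loop, TTS, and token2wav).
Returns a dict with is_listen, text, audio_waveform, end_of_turn, current_time, per-stage
cost_* timings, n_tokens, and n_tts_tokens.
"""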
start_time = time.time()
if self.is_session_stop_set() or self.is_break_set():
return {
"is_listen": True,
"text": "",
"audio_waveform": self._generate_silence_waveform(),
"end_of_turn": True,
"current_time": self.audio_chunk_idx,
"cost_llm": 0.0,
"cost_tts_prep": 0.0,
"cost_tts": 0.0,
"cost_token2wav": 0.0,
"cost_all": time.time() - start_time,
"n_tokens": 0,
"n_tts_tokens": 0,
}
# check if there are pending logits to process
if not hasattr(self, "pending_logits") or self.pending_logits is None:
return {
"is_listen": True,
"text": "",
"audio_waveform": self._generate_silence_waveform(),
"end_of_turn": False,
"current_time": self.audio_chunk_idx,
"cost_llm": 0.0,
"cost_tts_prep": 0.0,
"cost_tts": 0.0,
"cost_token2wav": 0.0,
"cost_all": time.time() - start_time,
"n_tokens": 0,
"n_tts_tokens": 0,
}
# use pending logits generated in streaming_prefill
logits = self.pending_logits
self.pending_logits = None
# Force listen: check if we should force listen for first N calls
force_listen = self._streaming_generate_count < self.force_listen_count
self._streaming_generate_count += 1
if force_listen:
print(f"[Duplex] streaming_generate: force_listen=True (call #{self._streaming_generate_count})")
total_hidden_in_unit = []
total_ids_in_unit = []
current_time = self.audio_chunk_idx
is_listen = False
end_of_turn = False
llm_start_time = time.time()
for j in range(max_new_speak_tokens_per_chunk):
if j == max_new_speak_tokens_per_chunk - 1:
if self.ls_mode == "explicit":
self.decoder.feed(self.decoder.embed_token(self.chunk_eos_token_id))
self.total_ids.append(self.chunk_eos_token_id)
break
if force_listen:
last_id = torch.tensor([self.listen_token_id], dtype=torch.long, device=self.device)
else:
last_id = self.decoder.decode(
logits=logits,
mode=decode_mode,
temperature=temperature,
top_k=top_k,
top_p=top_p,
listen_top_k=listen_top_k,
listen_prob_scale=listen_prob_scale,
text_repetition_penalty=text_repetition_penalty,
text_repetition_window_size=text_repetition_window_size,
debug_print_top5=False,
)
# if current turn not ended, not allowed to listen (only check when not force_listen)
if last_id.item() == self.listen_token_id and (not self.current_turn_ended):
last_id = torch.tensor([self.tts_bos_token_id], dtype=torch.long, device=self.device)
self.total_ids.append(last_id.item())
is_listen = last_id.item() == self.listen_token_id
# termination condition detection
if last_id.item() in self.chunk_terminator_token_ids:
if self.ls_mode == "explicit":
logits, _ = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True)
break
else:
# normal speak
self.current_turn_ended = False
if last_id.item() in self.chunk_speak_token_ids:
pass
else:
self.res_ids.append(last_id.item())
self.speak_count += 1
logits, hidden = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True)
assert len(hidden.shape) == 3
assert hidden.shape[0] == 1
assert hidden.shape[1] == 1
end_of_turn = last_id.item() in self.turn_terminator_token_ids
if end_of_turn:
self.current_turn_ended = True
if j != 0:
total_hidden_in_unit.append([last_id.item(), hidden, end_of_turn])
total_ids_in_unit.append(last_id.item())
# Prefill token
unit_end_id = self.tokenizer.convert_tokens_to_ids("")
self.decoder.feed(self.decoder.embed_token(unit_end_id))
self.total_ids.append(unit_end_id)
# Compute the generated text (used for sliding-window context preservation; special tokens filtered out)
generated_text = self.tokenizer.decode(total_ids_in_unit, skip_special_tokens=True) if total_ids_in_unit else ""
# Sliding window: register the unit end and check whether sliding is needed
input_type = self.current_mode.lower() if self.current_mode else "audio"
logger.info(
"[Duplex] streaming_generate: completing unit, mode=%s, is_listen=%s, generated=%d tokens",
input_type,
is_listen,
len(total_ids_in_unit),
)
self.decoder.register_unit_end(
input_type=input_type,
generated_tokens=total_ids_in_unit,
is_listen=is_listen,
generated_text=generated_text,
)
# Choose the sliding method according to the sliding-window mode
if self.decoder._window_config.sliding_window_mode == "context":
self.decoder.enforce_window_with_context()
elif self.decoder._window_config.sliding_window_mode == "basic":
self.decoder.enforce_window()
llm_end_time = time.time()
if is_listen:
self.total_hidden.append([])
return {
"is_listen": True,
"text": "",
"audio_waveform": self._generate_silence_waveform(),
"end_of_turn": False,
"current_time": current_time,
"cost_llm": llm_end_time - llm_start_time,
"cost_tts_prep": 0.0,
"cost_tts": 0.0,
"cost_token2wav": 0.0,
"cost_all": time.time() - start_time,
"n_tokens": len(total_ids_in_unit),
"n_tts_tokens": 0,
}
self.total_hidden.append(total_hidden_in_unit)
text = generated_text  # reuse the text computed above
print(f"> speak: {text}")
if not self.generate_audio:
return {
"is_listen": False,
"text": text,
"audio_waveform": None,
"end_of_turn": end_of_turn,
"current_time": current_time,
"cost_llm": llm_end_time - llm_start_time,
"cost_tts_prep": 0.0,
"cost_tts": 0.0,
"cost_token2wav": 0.0,
"cost_all": time.time() - start_time,
"n_tokens": len(total_ids_in_unit),
"n_tts_tokens": 0,
}
# TTS generate
tts_start_time = time.time()
tts_prep_start_time = time.time()
tts_condition = self._convert_results_to_tts_input(total_hidden_in_unit)
tts_prep_end_time = time.time()
max_token_per_chunk = 25 + 1
min_token_per_chunk = 25 + 1
if end_of_turn:
min_token_per_chunk = 0
force_flush = False
if self.tts_text_start_pos == 0:  # this is the start of the turn
min_token_per_chunk = 0  # allow decoding less than 1 s of audio
force_flush = True
if self.tts_current_turn_start_time is None:
self.tts_current_turn_start_time = current_time
new_tokens, old_kv = self.model.tts.generate_chunk(
inputs_embeds=tts_condition,
temperature=self.tts_temperature,
repetition_penalty=self.tts_repetition_penalty,
eos_token=self.tts_eos_token,
force_no_stop=False,
max_new_token=max_token_per_chunk,
min_new_tokens=min_token_per_chunk,
past_key_values=self.tts_past_key_values,
logits_processors=self.tts_logits_processors,
text_start_pos=self.tts_text_start_pos,
)
tts_end_time = time.time()
# Update TTS state (note: token2wav must be reset after audio generation, otherwise tokens remaining in the buffer are lost)
if end_of_turn:
self.tts_text_start_pos = 0
self.tts_past_key_values = None
self.tts_current_turn_start_time = None
# Note: _reset_token2wav_for_new_turn() is moved below, after audio generation
else:
self.tts_past_key_values = old_kv
self.tts_text_start_pos += tts_condition.shape[1] + new_tokens.shape[1]
# Token2Wav generation (must run before the reset, otherwise the tokens of the second-to-last chunk still in the buffer are lost)
token2wav_start_time = time.time()
audio_waveform = self._generate_waveform_from_tokens(new_tokens, prompt_wav_path, end_of_turn, force_flush=force_flush)
token2wav_end_time = time.time()
# Reset the token2wav state only after audio generation completes, so all tokens in the buffer are processed
if end_of_turn:
self._reset_token2wav_for_new_turn()
end_time = time.time()
return {
"is_listen": False,
"text": text,
"audio_waveform": audio_waveform,
"end_of_turn": end_of_turn,
"current_time": current_time,
"cost_llm": llm_end_time - llm_start_time,
"cost_tts_prep": tts_prep_end_time - tts_prep_start_time,
"cost_tts": tts_end_time - tts_start_time,
"cost_token2wav": token2wav_end_time - token2wav_start_time,
"cost_all": end_time - start_time,
"n_tokens": len(total_ids_in_unit),
"n_tts_tokens": new_tokens.numel(),
}
def _convert_results_to_tts_input(self, results):
"""convert LLM hidden states to TTS input"""
if len(results) == 0:
audio_bos = self.model.tts.emb_text(
torch.tensor(
[self.model.tts.audio_bos_token_id],
device=self.model.tts.emb_text.weight.device,
dtype=torch.long,
)
)
return audio_bos.unsqueeze(0)
llm_tokens = []
llm_hidden = []
for hidden in results:
llm_tokens.append(hidden[0])
llm_hidden.append(hidden[1].squeeze(0))
llm_tokens_tensor = torch.Tensor(llm_tokens).to(self.device, dtype=torch.long)
llm_embeds = self.model.tts.emb_text(llm_tokens_tensor)
llm_hidden_tensor = torch.cat(llm_hidden, dim=0)
llm_hidden_tensor = self.model.tts.projector_semantic(llm_hidden_tensor)
llm_hidden_tensor = torch.nn.functional.normalize(llm_hidden_tensor, p=2, dim=-1)
tts_embeds = llm_embeds + llm_hidden_tensor
audio_bos = self.model.tts.emb_text(
torch.tensor(
[self.model.tts.audio_bos_token_id],
device=self.model.tts.emb_text.weight.device,
dtype=torch.long,
)
)
tts_embeds = torch.cat([tts_embeds, audio_bos], dim=0)
return tts_embeds.unsqueeze(0)
def _generate_waveform_from_tokens(
self, new_tokens: torch.Tensor, prompt_wav_path: Optional[str], is_last_chunk: bool = False, force_flush: bool = False
) -> Optional[np.ndarray]:
"""从 audio tokens 生成波形"""
if not self.token2wav_initialized:
print("⚠️ token2wav 未初始化")
return None
CHUNK_SIZE = 25
# append the new tokens to the buffer
token_ids = torch.reshape(new_tokens, (-1,)).tolist()
self.token2wav_buffer += token_ids
# check whether a chunk_eos token is present
has_chunk_eos = any(tid in self.chunk_terminator_token_ids for tid in token_ids)
self._log(
"AUDIO",
f"Token2Wav buffer size: {len(self.token2wav_buffer)}, new tokens: {len(token_ids)}, has_chunk_eos: {has_chunk_eos}",
)
pcm_bytes_list = []
# once enough tokens are buffered, synthesize them in chunks;
# if a chunk_eos was seen (or a flush is forced), drain as much of the buffer as possible
if has_chunk_eos or force_flush:
while len(self.token2wav_buffer) >= self.pre_lookahead + 5:  # always keep some lookahead
chunk_to_process = min(CHUNK_SIZE + self.pre_lookahead, len(self.token2wav_buffer))
pcm_bytes = self.model.tts.audio_tokenizer.stream(
self.token2wav_buffer[:chunk_to_process],
prompt_wav=prompt_wav_path,
)
pcm_bytes_list.append(pcm_bytes)
self.token2wav_buffer = self.token2wav_buffer[min(CHUNK_SIZE, chunk_to_process - self.pre_lookahead) :]
else:
while len(self.token2wav_buffer) >= CHUNK_SIZE + self.pre_lookahead:
pcm_bytes = self.model.tts.audio_tokenizer.stream(
self.token2wav_buffer[: CHUNK_SIZE + self.pre_lookahead],
prompt_wav=prompt_wav_path,
)
pcm_bytes_list.append(pcm_bytes)
self.token2wav_buffer = self.token2wav_buffer[CHUNK_SIZE:]
# if this is the last chunk, flush the remaining tokens
if is_last_chunk and len(self.token2wav_buffer) > 0:
self._log("AUDIO", f"Flushing final {len(self.token2wav_buffer)} tokens")
pcm_bytes = self.model.tts.audio_tokenizer.stream(
self.token2wav_buffer,
prompt_wav=prompt_wav_path,
last_chunk=True,
)
pcm_bytes_list.append(pcm_bytes)
self.token2wav_buffer = []
if not pcm_bytes_list:
return None
# merge PCM and convert to numpy array (24kHz, int16 -> float32)
all_pcm = b"".join(pcm_bytes_list)
if len(all_pcm) == 0:
self._log("AUDIO", "No audio bytes generated")
return None
pcm_np = np.frombuffer(all_pcm, dtype="<i2")
return pcm_np.astype(np.float32) / 32768.0
def _generate_silence_waveform(self, duration_sec: float = 1.0) -> np.ndarray:  # default duration is an assumed value
"""generate silence waveform (24kHz)"""
sample_rate = 24000
num_samples = int(duration_sec * sample_rate)
return np.zeros(num_samples, dtype=np.float32)
def get_generated_text(self) -> str:
return self.tokenizer.decode(self.res_ids)
def get_current_time(self) -> int:
return self.audio_chunk_idx
def _log(self, event_type: str, message: str, data: Optional[dict] = None):
if self.session_start_time is None:
self.session_start_time = time.time()
timestamp = time.time() - self.session_start_time
log_entry = {"timestamp": timestamp, "type": event_type, "message": message}
if data:
log_entry["data"] = data
self.session_logs.append(log_entry)
prefix = f"[{timestamp:6.3f}s]"
if event_type == "FEED":
print(f"{prefix} [FEED] {message}")
elif event_type == "DECODE":
print(f"{prefix} [DECODE] {message}")
elif event_type == "SYS":
print(f"{prefix} [SYS] {message}")
elif event_type == "AUDIO":
print(f"{prefix} [AUDIO] {message}")
else:
print(f"{prefix} [{event_type}] {message}")
def save_session_log(self, output_dir: str = "tmp_session_logs"):
from datetime import datetime  # json and os are already imported at module level
os.makedirs(output_dir, exist_ok=True)
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = os.path.join(output_dir, f"session_log_{timestamp_str}.json")
with open(log_file, "w", encoding="utf-8") as f:
json.dump(self.session_logs, f, ensure_ascii=False, indent=2)
print(f"session log saved to: {log_file}")
return log_file
def _update_loop(self, mode):
modes = {
"omni": [
"",
"",
"",
"",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>:decode",
"",
],
"video": ["", "", "", ":decode", ""],
"audio": [
"",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>",
"<|audio|>:decode",
"",
],
}
self.loop = modes.get(mode, None)
@torch.no_grad()
def offline_generate(
self,
frame_path: Optional[str] = None,
audio_path: Optional[str] = None,
max_new_speak_tokens=600,
max_new_speak_tokens_per_chunk=20,
max_loop_time=300,
decode_mode: str = "sampling",
temperature=0.7,
top_k=20,
top_p=0.8,
listen_top_k=None,
listen_prob_scale=1.0,
text_repetition_penalty=1.05,
text_repetition_window_size=512,
) -> Dict:
from .stream_providers import LoopPlanner
from .stream_providers import MockStreamProvider
from .stream_providers import StreamProvider
res_ids = []
total_ids = []
total_hidden = []
per_second_results = []
total_res_with_time = []
iter_data = MockStreamProvider(frame_path, audio_path, self.processor)
audio_data, frame_data = iter_data.all_data(device=self.device)
frame_emb_list = None
if frame_data is not None:
frame_emb_list = self.model.get_vision_embedding(frame_data)
if frame_emb_list and len(frame_emb_list) > 0:
frame_emb_list = frame_emb_list[0]
else:
frame_emb_list = None
audio_emb_list = None
if audio_data is not None:
audio_emb_list = self.model.get_audio_embedding(
audio_data, chunk_length=self.model.config.audio_chunk_length
)
if audio_emb_list and len(audio_emb_list) > 0:
audio_emb_list = torch.cat(audio_emb_list[0], dim=0)
else:
audio_emb_list = None
if audio_emb_list is None and frame_emb_list is None:
logger.warning("Warning: No audio or frame data available, returning empty result")
return {"total_ids": [], "res_ids": []}
loop = LoopPlanner(self.step_map, self.loop, self.tokenizer)
provider = StreamProvider(audio_emb_list, frame_emb_list, audio_chunk_len=1, image_step=1, audio_chunk_ms=100)
speak_count = 0
while (
speak_count < max_new_speak_tokens
and not provider.finished()
and provider.get_current_time() < max_loop_time
):
modal, tid, action = loop.next_plan()
total_ids.append(tid)
if action == "sliding":
self.decoder.sliding_embeds()
continue
if modal == "audio":
embeds = provider.next_audio()
elif modal == "image":
embeds = provider.next_image()
if embeds is not None:
embeds = embeds.squeeze(0)
elif modal == "sp_text":
embeds = self.decoder.embed_token(tid)
else:
raise ValueError(f"Unknown modal: {modal}")
if embeds is None:
break
if action == "prefill":
self.decoder.feed(embeds)
elif action == "decode":
total_hidden_in_unit = []
current_second_ids = []
logits, _ = self.decoder.feed(embeds, return_logits=True)
current_time = provider.get_current_time()
end_of_turn = False
for j in range(max_new_speak_tokens_per_chunk):
if j == max_new_speak_tokens_per_chunk - 1:
if self.ls_mode == "explicit":
self.decoder.feed(self.decoder.embed_token(self.chunk_eos_token_id))
total_ids.append(self.chunk_eos_token_id)
break
last_id = self.decoder.decode(
logits=logits,
mode=decode_mode,
temperature=temperature,
top_k=top_k,
top_p=top_p,
listen_top_k=listen_top_k,
listen_prob_scale=listen_prob_scale,
text_repetition_penalty=text_repetition_penalty,
text_repetition_window_size=text_repetition_window_size,
debug_print_top5=False,
)
is_listen = last_id.item() == self.listen_token_id
end_of_turn = last_id.item() in self.turn_terminator_token_ids
total_res_with_time.append(
{
"text": self.tokenizer.decode([last_id.item()]),
"timestamp": provider.get_current_time(),
}
)
total_ids.append(last_id.item())
if last_id.item() in self.chunk_terminator_token_ids:
if self.ls_mode == "explicit":
self.decoder.feed(self.decoder.embed_token(last_id.item()))
break
else:
if last_id.item() in self.chunk_speak_token_ids:
pass
else:
res_ids.append(last_id.item())
current_second_ids.append(last_id.item())
speak_count += 1
logits, hidden = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True)
assert len(hidden.shape) == 3
assert hidden.shape[0] == 1
assert hidden.shape[1] == 1
if j != 0:
total_hidden_in_unit.append([last_id.item(), hidden, end_of_turn])
per_second_results.append(
{
"time": current_time,
"ids": current_second_ids,
"text": self.tokenizer.decode(current_second_ids) if current_second_ids else "",
}
)
if is_listen:
pass
else:
total_hidden.append(total_hidden_in_unit)
if last_id.item() in self.turn_terminator_token_ids:
break
return {
"total_ids": total_ids,
"res_ids": res_ids,
"total_res_with_time": total_res_with_time,
"total_hidden": total_hidden,
"per_second_results": per_second_results,
}
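# A minimal standalone sketch (not used by the model) of the lookahead-buffer
# flushing policy implemented in _generate_waveform_from_tokens above: every
# vocoder call receives CHUNK_SIZE tokens plus `pre_lookahead` future tokens,
# but the buffer only advances by CHUNK_SIZE, so the lookahead tokens are
# re-synthesized later with more right context. `stream_fn` and the sizes are
# illustrative placeholders, not part of the model API.
def _demo_token2wav_buffering(token_ids, stream_fn, chunk_size=25, pre_lookahead=3):
    buffer = list(token_ids)
    pcm_chunks = []
    # flush only when a full chunk plus its lookahead is available
    while len(buffer) >= chunk_size + pre_lookahead:
        pcm_chunks.append(stream_fn(buffer[: chunk_size + pre_lookahead]))
        buffer = buffer[chunk_size:]
    # whatever remains stays buffered until an end-of-turn flush
    return pcm_chunks, buffer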
class MiniCPMWhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
try:
# compatibility with older transformers versions that still expose WHISPER_ATTENTION_CLASSES
from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
layer_idx=layer_idx,
)
except Exception:  # newer transformers versions removed WHISPER_ATTENTION_CLASSES
from transformers.models.whisper.modeling_whisper import WhisperAttention
self.self_attn = WhisperAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
layer_idx=layer_idx,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
output_attentions: bool = False,
past_key_values: Optional[EncoderDecoderCache] = None,
use_cache: Optional[bool] = False,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, past_key_values = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
past_key_value=past_key_values,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
if use_cache:
outputs += (past_key_values,)
return outputs
# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder, with use_cache added for streaming inference
class MiniCPMWhisperEncoder(WhisperEncoder):
def __init__(self, config: WhisperConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]
)
def forward(
self,
input_features,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
past_key_values: Optional[EncoderDecoderCache] = None,
use_cache: Optional[bool] = None,
use_extra_context: Optional[bool] = False,
prefix_extra_frames: Optional[int] = 1,
suffix_extra_frames: Optional[int] = 1,
return_debug: Optional[bool] = False,
cnn_min_length: Optional[int] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# initialize the debug dict
debug_info = {} if return_debug else None
# Ignore copy
input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
# optional: pad short inputs to a minimum length so the CNN computation stays consistent
original_length = input_features.shape[2]
padded_for_cnn = False
if cnn_min_length is not None and original_length < cnn_min_length:
padded_features = torch.zeros(
input_features.shape[0],
input_features.shape[1],
cnn_min_length,
dtype=input_features.dtype,
device=input_features.device,
)
padded_features[:, :, :original_length] = input_features
input_features = padded_features
padded_for_cnn = True
if return_debug:
debug_info["padded_for_cnn"] = True
debug_info["original_length"] = original_length
debug_info["padded_length"] = cnn_min_length
if return_debug:
debug_info["input_features"] = input_features.clone()
conv1_output = self.conv1(input_features)
if return_debug:
debug_info["after_conv1"] = conv1_output.clone()
inputs_embeds = nn.functional.gelu(conv1_output)
if return_debug:
debug_info["after_gelu1"] = inputs_embeds.clone()
conv2_output = self.conv2(inputs_embeds)
if return_debug:
debug_info["after_conv2"] = conv2_output.clone()
inputs_embeds = nn.functional.gelu(conv2_output)
if return_debug:
debug_info["after_gelu2"] = inputs_embeds.clone()
# if padding was applied above, remove its effect now
if padded_for_cnn:
# conv1: stride 1, output length = input length
# conv2: stride 2, output length = (input length + 1) // 2
actual_cnn_output_length = (original_length + 1) // 2
inputs_embeds = inputs_embeds[:, :, :actual_cnn_output_length]
if return_debug:
debug_info["after_unpad"] = inputs_embeds.clone()
# when extra context is used, the redundant frames must be removed after the CNN
# conv2 has stride 2, so redundant input frames shrink to half (rounded up) in the output
if use_extra_context:
# the input carries prefix_extra_frames leading frames and suffix_extra_frames trailing frames
# conv2 stride 2: output length = ceil(input length / 2)
# e.g. 2 redundant frames become 1 output frame (ceil(2/2) = 1)
prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0  # round up
suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0  # round up
# remove the leading/trailing redundant frames (batch, channels, time)
if prefix_to_remove > 0:
inputs_embeds = inputs_embeds[:, :, prefix_to_remove:]
if 0 < suffix_to_remove < inputs_embeds.shape[2]:
inputs_embeds = inputs_embeds[:, :, :-suffix_to_remove]
if return_debug and use_extra_context:
debug_info["after_redundancy_removal"] = inputs_embeds.clone()
inputs_embeds = inputs_embeds.permute(0, 2, 1)
if return_debug:
debug_info["after_permute"] = inputs_embeds.clone()
embed_pos = self.embed_positions.weight
past_key_values_length = 0
if use_cache:
if past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
elif isinstance(past_key_values, list):
past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache())
elif isinstance(past_key_values, DynamicCache):
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
else:
pass
past_key_values_length = past_key_values.self_attention_cache.get_usable_length(inputs_embeds.shape[1])
if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]:
logger.warning("seems the audio is longer than 30s. repeating the last part of the audio")
embed_pos_front = embed_pos[past_key_values_length:, :]
embed_pos = torch.cat(
(
embed_pos_front,
torch.repeat_interleave(
embed_pos[-1, :].unsqueeze(0),
inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length,
dim=0,
),
)
)
else:
embed_pos = embed_pos[past_key_values_length : inputs_embeds.shape[1] + past_key_values_length, :]
else:
embed_pos = embed_pos[: inputs_embeds.shape[1], :]
if return_debug:
debug_info["positional_embedding"] = embed_pos.clone()
debug_info["past_key_values_length"] = past_key_values_length
hidden_states = inputs_embeds + embed_pos
if return_debug:
debug_info["after_add_pos_embed"] = hidden_states.clone()
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if return_debug:
debug_info["after_dropout"] = hidden_states.clone()
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
to_drop = False
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop: # skip the layer
to_drop = True
# Ignore copy
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
past_key_values,
use_cache,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
past_key_values=past_key_values,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if return_debug and idx < 3:  # only record the first 3 layers to save memory
debug_info[f"after_layer_{idx}"] = hidden_states.clone()
if use_cache:
next_encoder_cache = layer_outputs[2 if output_attentions else 1]
else:
next_encoder_cache = None
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
if return_debug:
debug_info["after_layer_norm"] = hidden_states.clone()
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
result = tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
if return_debug:
return result, debug_info
return result
result = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
past_key_values=next_encoder_cache,
)
if return_debug:
return result, debug_info
return result
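# A small helper sketch (not called anywhere) of the frame-count arithmetic the
# encoder above relies on when padding short inputs and trimming extra-context
# frames: conv1 keeps the length (stride 1), conv2 halves it rounding up
# (stride 2), and prefix/suffix redundancy shrinks by the same ceil-division.
def _demo_whisper_cnn_lengths(mel_frames: int, prefix_extra: int = 1, suffix_extra: int = 1) -> int:
    after_conv2 = (mel_frames + 1) // 2  # conv2 stride 2 -> ceil(mel_frames / 2)
    prefix_to_remove = (prefix_extra + 1) // 2 if prefix_extra > 0 else 0
    suffix_to_remove = (suffix_extra + 1) // 2 if suffix_extra > 0 else 0
    return after_conv2 - prefix_to_remove - suffix_to_remove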
class MultiModalProjector(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)
def forward(self, audio_features):
hidden_states = self.relu(self.linear1(audio_features))
hidden_states = self.linear2(hidden_states)
return hidden_states
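# A minimal shape sketch for MultiModalProjector, wrapped in a function so it
# never runs at import time. The dimensions used here are illustrative
# assumptions, not values taken from a released config.
def _demo_multimodal_projector_shapes():
    proj = MultiModalProjector(in_dim=1024, out_dim=3584)
    audio_features = torch.randn(2, 50, 1024)  # (batch, frames, encoder dim)
    projected = proj(audio_features)  # (batch, frames, llm dim)
    return projected.shape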
class MiniCPMMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.in_dim = config.llm_hidden_size
self.out_dim = config.hidden_size
self.intermediate_size = config.llm_intermediate_size
self.gate_proj = nn.Linear(self.in_dim, self.intermediate_size, bias=True)
self.up_proj = nn.Linear(self.in_dim, self.intermediate_size, bias=True)
self.down_proj = nn.Linear(self.intermediate_size, self.out_dim, bias=True)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
@dataclass
class MiniCPMTTSGenerationOutput(ModelOutput):
"""
Output class for MiniCPMTTS generation.
Args:
new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq).
audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
finished (bool): Boolean indicating whether generation is complete.
"""
new_ids: torch.LongTensor = None
audio_input_ids: torch.LongTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_input_ids: Optional[torch.LongTensor] = None
finished: bool = None
def make_streaming_chunk_mask_inference(
tts_text_scope: List[int],
tts_text_mask: torch.Tensor,
streaming_audio_chunk_size: int = 50,
dtype: torch.dtype = torch.bfloat16,
device: torch.device = torch.device("cuda"),
max_sequence_length: int = 4096,
):
"""
Example:
Input sequence:
[t1, t2, t3, t4, t5, [Ptts], a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, ...]
Output 4D causal mask:
------- text positions -------
[0] <- here is [Stts]
[0, 0] <- here is [spk_emb] * N
[0, 0, 0]
[0, 0, 0, 0]
[0, 0, 0, 0, 0]
------- audio positions --------
[0, 0, -inf, -inf, -inf, 0] <- here is [Ptts], [Ptts]'s last hidden state should predict the first audio token
v- here is [Ptts]
[0, 0, -inf, -inf, -inf, 0, 0]
[0, 0, -inf, -inf, -inf, 0, 0, 0]
[0, 0, -inf, -inf, -inf, 0, 0, 0, 0]
[0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0]
[0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0, 0] # end of first 1s audio chunk
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
"""
# Create a complete attention mask for input embeds [batch_size, seq_len], without considering audio mask as audio is always at the end
assert tts_text_mask.dtype == torch.int8
padding_mask = torch.ones(max_sequence_length, dtype=torch.int8, device=device)
padding_mask[tts_text_scope[0] : tts_text_scope[1]] = tts_text_mask
# Initialize a standard upper triangular causal mask
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(max_sequence_length, max_sequence_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if max_sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
raise ValueError("max_sequence_length of tts could not be 1.")
# For each data sample
audio_token_start = tts_text_scope[1]
audio_duration = max_sequence_length - tts_text_scope[1]
# Record which text chunk the current audio chunk can see up to
text_pivot = 0
num_valid_text_tokens = torch.sum(tts_text_mask).item() - 1 # [Ptts] excluded
# number of new text tokens made visible per audio chunk; keep the number of visibility buckets as small as possible
num_text_tokens_per_audio_chunk = 10
# For each chunk of audio
for chunk_idx in range(math.ceil(audio_duration / streaming_audio_chunk_size)):
audio_chunk_start = audio_token_start + chunk_idx * streaming_audio_chunk_size
audio_chunk_end = audio_token_start + (chunk_idx + 1) * streaming_audio_chunk_size
# New text seen by this new audio chunk
new_text_this_chunk = num_text_tokens_per_audio_chunk
# The right bound of visible text tokens
text_pivot = min(new_text_this_chunk + text_pivot, num_valid_text_tokens)
# Mask all text chunks after the visible ones
# -> [text_pivot, len(tts_text_scope)-1] excluding [Ptts]
# print("audio_chunk_start-1", audio_chunk_start-1, "audio_chunk_end-1", audio_chunk_end-1, "tts_text_scope[0] + text_pivot", tts_text_scope[0] + text_pivot, "tts_text_scope[1] - 1", tts_text_scope[1] - 1)
causal_mask[
audio_chunk_start - 1 : audio_chunk_end - 1,
# tts_text_scope[0] + text_pivot: tts_text_scope[1],
tts_text_scope[0] + text_pivot : tts_text_scope[1] - 1,
] = min_dtype
# Mask the padding parts in tts_text_masks (no position will attend to it)
causal_mask[:, padding_mask == 0] = min_dtype
# Add extra dimensions, [batch_size, seq_len, seq_len] -> [batch_size, 1, seq_len, seq_len]
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
return causal_mask
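# A tiny CPU-only example (not called anywhere) of
# make_streaming_chunk_mask_inference, using much smaller, purely illustrative
# sizes than the model defaults, just to show how the 4D streaming mask is built.
def _demo_streaming_chunk_mask():
    text_len = 6  # e.g. 5 text tokens followed by [Ptts]
    text_mask = torch.ones(text_len, dtype=torch.int8)
    mask = make_streaming_chunk_mask_inference(
        tts_text_scope=[0, text_len],
        tts_text_mask=text_mask,
        streaming_audio_chunk_size=5,
        dtype=torch.float32,
        device=torch.device("cpu"),
        max_sequence_length=16,
    )
    return mask.shape  # (1, 1, 16, 16)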
class MiniCPMTTS(PreTrainedModel):
config_class = MiniCPMTTSConfig
def __init__(self, config: MiniCPMTTSConfig, audio_tokenizer=None):
super().__init__(config)
self.use_llm_hidden_state = config.use_llm_hidden_state
self.use_text = config.use_text
self.streaming = config.streaming
self.streaming_text_chunk_min = config.streaming_text_chunk_min
self.streaming_text_chunk_max = config.streaming_text_chunk_max
self.streaming_audio_chunk_size = config.streaming_audio_chunk_size
self.streaming_text_reserved_len = config.streaming_text_reserved_len
# streaming tts
self.streaming_text_chunk_size = config.streaming_text_chunk_max
self.audio_bos_token_id = config.audio_bos_token_id
self.num_mel_bins = config.num_mel_bins
self.num_vq = config.num_vq
self.num_audio_tokens = config.num_audio_tokens
self.top_p = config.top_p
self.top_k = config.top_k
self.repetition_penalty = config.repetition_penalty
self.interleaved = config.interleaved
self.attention_type = config.attention_type
self.recomputed_chunks = config.recomputed_chunks
self.window_size = config.window_size
if self.attention_type == "sliding_recompute" and self.window_size <= self.recomputed_chunks:
raise ValueError(
f"Sliding recompute requires window_size > recomputed_chunks, but got window_size={self.window_size} and recomputed_chunks={self.recomputed_chunks}"
)
if config.backbone_model == "llama":
model_config = LlamaConfig(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
num_attention_heads=config.num_attention_heads,
num_hidden_layers=config.num_hidden_layers,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
attn_implementation=config.attn_implementation,
)
self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
model = LlamaModel(model_config)
self.model = model
else:
raise ValueError(f"Unsupported backbone model: {config.backbone_model}")
self.projector_spk = self.create_projector(config)
self.projector_semantic = self.create_projector(config)
self.audio_tokenizer = audio_tokenizer
self.emb_code = nn.ModuleList(
[nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq)]
)
self.head_code = nn.ModuleList(
[
weight_norm(
nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
name="weight",
)
for _ in range(config.num_vq)
]
)
self.condition_type = config.condition_type
return
@staticmethod
def create_projector(config):
if config.projector_type == "mlp":
return MultiModalProjector(config.llm_dim, config.hidden_size)
elif config.projector_type == "minicpm":
return MiniCPMMLP(config)
elif config.projector_type == "default":
return nn.Linear(config.llm_dim, config.hidden_size, bias=False)
else:
raise ValueError(f"Unsupported projector type: {config.projector_type}")
# non-streaming
@torch.inference_mode()
def generate(
self,
inputs_embeds: torch.Tensor,
eos_token: Union[int, torch.Tensor],
force_no_stop=False,
min_new_token=50,
max_new_token=2048,
show_tqdm=True,
streaming=False,
text_lengths=None,
sampling_params: TTSSamplingParams = TTSSamplingParams(),
):
temperature = torch.tensor(
[sampling_params.temperature] * self.config.num_vq,
dtype=torch.float,
device=self.device,
)
temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to(
inputs_embeds.device
)
logits_warpers, logits_processors = gen_logits(
num_code=self.config.num_audio_tokens,
repetition_penalty=sampling_params.repetition_penalty,
top_p=sampling_params.top_p,
top_k=sampling_params.top_k,
)
# We only support batch size `1` for now
assert inputs_embeds.shape[0] == 1
eos_token = eos_token.to(inputs_embeds.device)
finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool()
condition_length = inputs_embeds.shape[1]
pbar: Optional[tqdm] = None
if show_tqdm:
pbar = tqdm(
total=max_new_token,
desc="code",
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
)
if streaming:
raise NotImplementedError("this kind of streaming is not supported yet")
new_tokens = torch.zeros(
inputs_embeds.shape[0],
max_new_token,
self.num_vq,
device=inputs_embeds.device,
dtype=torch.long,
)
past_key_values = None
for t in range(max_new_token):
audio_bos = False
# If this is the first audio token, the case is special
if t == 0:
audio_bos = True
# the first step forwards the provided condition embeddings unchanged
position_ids = torch.tensor(
list(range(0, condition_length)),
dtype=torch.long,
device=self.device,
).unsqueeze(0)
if streaming:
raise NotImplementedError("this kind of streaming is not supported yet")
else:
causal_mask_4d = None
else:
code_emb = []
for q in range(self.num_vq):
x = self.emb_code[q](new_tokens[:, t - 1 : t, q])
code_emb.append(x)
inputs_embeds = torch.stack(code_emb, 3).sum(3)
position_ids = torch.tensor([condition_length + t - 1], dtype=torch.long, device=self.device).unsqueeze(
0
)
if streaming:
raise NotImplementedError("this kind of streaming is not supported yet")
else:
causal_mask_4d = None
if self.config.backbone_model == "llama":
outputs: BaseModelOutputWithPast = self.model(
position_ids=position_ids,
cache_position=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=causal_mask_4d,
use_cache=True,
output_attentions=False,
# return_dict=True, # Add this to ensure returns dict with past_key_values
)
else:
raise ValueError(f"Unsupported backbone model: {self.config.backbone_model}")
del position_ids
del inputs_embeds
hidden_states = outputs.last_hidden_state
past_key_values = outputs.past_key_values
with P.cached():
logits = torch.empty(
hidden_states.size(0),
hidden_states.size(1),
self.num_audio_tokens,
self.num_vq,
dtype=torch.float,
device=self.device,
)
for num_vq_iter in range(self.num_vq):
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
logits[..., num_vq_iter] = x
del x
del hidden_states
logits = logits[:, -1].float()
logits = logits.permute(0, 2, 1)
logits = logits.reshape(-1, logits.size(2))
logits /= temperature
if not audio_bos:
input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens
logits_token = input_ids_sliced.reshape(
input_ids_sliced.size(0) * input_ids_sliced.size(1),
-1,
).to(self.device)
del input_ids_sliced
for logitsProcessors in logits_processors:
logits = logitsProcessors(logits_token, logits)
for logitsWarpers in logits_warpers:
logits = logitsWarpers(logits_token, logits)
del logits_token
if t < min_new_token:
logits[:, eos_token] = -torch.inf
if force_no_stop:
logits[:, eos_token] = -torch.inf
scores = F.softmax(logits, dim=-1)
del logits
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
del scores
idx_next = idx_next.view(-1, self.num_vq)
finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
new_tokens[:, t] = idx_next
if t == 0 and finish.any():
break
del idx_next
if finish.all():
break
if pbar is not None:
pbar.update(1)
if pbar is not None:
pbar.close()
if not finish.all():
logger.warning(f"incomplete result. hit max_new_token: {max_new_token}")
generated_input_ids = new_tokens[:, 0:t, :]
return MiniCPMTTSGenerationOutput(
new_ids=generated_input_ids,
audio_input_ids=None, # for update purpose
past_key_values=None, # for update purpose
past_input_ids=None, # for update purpose
finished=finish.all(),
)
# fake streaming
@torch.inference_mode()
def generate_mock_legacy_streaming(
self,
inputs_embeds: torch.Tensor,
eos_token: Union[int, torch.Tensor],
force_no_stop=False,
min_new_token=50,
max_new_token=2048,
show_tqdm=True,
streaming=False,
text_lengths=None,
sampling_params: TTSSamplingParams = TTSSamplingParams(),
valid_text_length=None,
):
assert valid_text_length is not None, "valid_text_length should be not None"
tts_text_scope = [0, inputs_embeds.shape[1]]
tts_text_mask = torch.zeros(inputs_embeds.shape[1], dtype=torch.int8, device=inputs_embeds.device)
tts_text_mask[0:valid_text_length] = 1
tts_text_mask[-1] = 1 # [Ptts]
streaming_mask_4d_full = make_streaming_chunk_mask_inference(
tts_text_scope=tts_text_scope,
tts_text_mask=tts_text_mask,
dtype=torch.bfloat16,
device=self.device,
streaming_audio_chunk_size=50,
max_sequence_length=4096,
)
temperature = torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.device)
temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to(
inputs_embeds.device
)
logits_warpers, logits_processors = gen_logits(
num_code=self.config.num_audio_tokens,
repetition_penalty=sampling_params.repetition_penalty,
top_p=sampling_params.top_p,
top_k=sampling_params.top_k,
)
# We only support batch size `1` for now
assert inputs_embeds.shape[0] == 1
eos_token = eos_token.to(inputs_embeds.device)
finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool()
condition_length = inputs_embeds.shape[1]
pbar: Optional[tqdm] = None
if show_tqdm:
pbar = tqdm(
total=max_new_token,
desc="code",
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
)
new_tokens = torch.zeros(
inputs_embeds.shape[0],
max_new_token,
self.num_vq,
device=inputs_embeds.device,
dtype=torch.long,
)
past_key_values = None
for t in range(max_new_token):
audio_bos = False
if t == 0:
audio_bos = True
# the first step forwards the provided condition embeddings unchanged
position_ids = torch.tensor(
list(range(0, condition_length)),
dtype=torch.long,
device=self.device,
).unsqueeze(0)
causal_mask_4d = streaming_mask_4d_full[:, :, :condition_length, :condition_length]
else:
code_emb = []
for q in range(self.num_vq):
x = self.emb_code[q](new_tokens[:, t - 1 : t, q])
code_emb.append(x)
inputs_embeds = torch.stack(code_emb, 3).sum(3)
position_ids = torch.tensor([condition_length + t - 1], dtype=torch.long, device=self.device).unsqueeze(
0
)
causal_mask_4d = streaming_mask_4d_full[
:,
:,
condition_length + t : condition_length + t + 1,
: condition_length + t,
]
# get length of past_key_values
past_key_values_length = past_key_values[0][0].shape[2]
assert causal_mask_4d.shape[-1] == (past_key_values_length + 1)
if self.config.backbone_model == "llama":
outputs: BaseModelOutputWithPast = self.model(
position_ids=position_ids,
cache_position=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=causal_mask_4d,
use_cache=True,
output_attentions=False,
# return_dict=True, # Add this to ensure returns dict with past_key_values
)
else:
raise ValueError(f"Unsupported backbone model: {self.config.backbone_model}")
del position_ids
del inputs_embeds
hidden_states = outputs.last_hidden_state
past_key_values = outputs.past_key_values
with P.cached():
logits = torch.empty(
hidden_states.size(0),
hidden_states.size(1),
self.num_audio_tokens,
self.num_vq,
dtype=torch.float,
device=self.device,
)
for num_vq_iter in range(self.num_vq):
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
logits[..., num_vq_iter] = x
del x
del hidden_states
logits = logits[:, -1].float()
logits = logits.permute(0, 2, 1)
logits = logits.reshape(-1, logits.size(2))
logits /= temperature
if not audio_bos:
input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens
logits_token = input_ids_sliced.reshape(
input_ids_sliced.size(0) * input_ids_sliced.size(1),
-1,
).to(self.device)
del input_ids_sliced
for logitsProcessors in logits_processors:
logits = logitsProcessors(logits_token, logits)
for logitsWarpers in logits_warpers:
logits = logitsWarpers(logits_token, logits)
del logits_token
if t < min_new_token:
logits[:, eos_token] = -torch.inf
if force_no_stop:
logits[:, eos_token] = -torch.inf
scores = F.softmax(logits, dim=-1)
del logits
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
del scores
idx_next = idx_next.view(-1, self.num_vq)
finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
new_tokens[:, t] = idx_next
if t == 0 and finish.any():
break
del idx_next
if finish.all():
break
if pbar is not None:
pbar.update(1)
if pbar is not None:
pbar.close()
if not finish.all():
logger.warning(f"incomplete result. hit max_new_token: {max_new_token}")
generated_input_ids = new_tokens[:, 0:t, :]
return MiniCPMTTSGenerationOutput(
new_ids=generated_input_ids,
audio_input_ids=None, # for update purpose
past_key_values=None, # for update purpose
past_input_ids=None, # for update purpose
finished=finish.all(),
)
# non-streaming, interleave
@torch.inference_mode()
def generate_chunk(
self,
inputs_embeds: torch.Tensor,
temperature: torch.Tensor,
repetition_penalty: float,
eos_token: Union[int, torch.Tensor],
force_no_stop=False,
max_new_token=500,
min_new_tokens=0,
past_key_values=None,
logits_processors=None,
text_start_pos=None,
):
"""For inputs_embeds, it should be like [bs=1, seq_len, hidden_dim], its content is like:
|Text BOS|Spk embeds|Text-Hidden states Interleave (if applicable)|Audio BOS|
where the last position is the audio BOS token.
So, the first iteration in generation directly forward the model with inputs_embeds, and
the last hidden states of the last position (Audio BOS) will be decoded to get the first audio token.
"""
logits_warpers, logits_processors = gen_logits(
num_code=self.config.num_audio_tokens, repetition_penalty=repetition_penalty
)
# We only support batch size `1` for now
assert inputs_embeds.shape[0] == 1
eos_token = eos_token.to(inputs_embeds.device)
finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool()
temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to(
inputs_embeds.device
)
condition_length = inputs_embeds.shape[1]
new_tokens = torch.zeros(
inputs_embeds.shape[0],
max_new_token,
self.num_vq,
device=inputs_embeds.device,
dtype=torch.long,
)
for t in range(max_new_token):
audio_bos = False
# If this is the first audio token, the case is special
if t == 0:
audio_bos = True
inputs_embeds_ = inputs_embeds
position_ids = torch.tensor(
list(range(text_start_pos, text_start_pos + condition_length)),
dtype=torch.long,
device=self.device,
).unsqueeze(0)
else:
# Generate the subsequent audio tokens; this branch applies to all later steps, including the second and later calls of `generate`
inputs_embeds_ = self.emb_code[0](new_tokens[:, t - 1 : t, 0])
position_ids = torch.tensor(
[text_start_pos + condition_length + t - 1],  # prefill the previous token
dtype=torch.long,
device=self.device,
).unsqueeze(0)
outputs: BaseModelOutputWithPast = self.model(
position_ids=position_ids,
# cache_position=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds_,
use_cache=True,
output_attentions=False,
# return_dict=True, # Add this to ensure returns dict with past_key_values
)
del position_ids
del inputs_embeds_
hidden_states = outputs.last_hidden_state
past_key_values = outputs.past_key_values
with P.cached():
logits = torch.empty(
hidden_states.size(0),
hidden_states.size(1),
self.num_audio_tokens,
self.num_vq,
dtype=torch.float,
device=self.device,
)
for num_vq_iter in range(self.num_vq):
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
logits[..., num_vq_iter] = x
del x
del hidden_states
logits = logits[:, -1].float()
logits = logits.permute(0, 2, 1)
logits = logits.reshape(-1, logits.size(2))
logits /= temperature
if not audio_bos:
input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens
logits_token = input_ids_sliced.reshape(
input_ids_sliced.size(0) * input_ids_sliced.size(1),
-1,
).to(self.device)
del input_ids_sliced
for logitsProcessors in logits_processors:
logits = logitsProcessors(logits_token, logits)
del logits_token
if force_no_stop or t < min_new_tokens:
logits[:, eos_token] = -torch.inf
scores = F.softmax(logits, dim=-1)
del logits
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
del scores
idx_next = idx_next.view(-1, self.num_vq)
finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
new_tokens[:, t] = idx_next
if t == 0 and finish.any():
break
del idx_next
if finish.all():
break
if not finish.all():
logger.warning(f"incomplete result. hit max_new_token: {max_new_token}")
# The most recently generated token (index t) is intentionally excluded from this return:
# whether it is an eos token or a normal token, it is not part of this chunk. This is expected.
generated_input_ids = new_tokens[:, 0:t, :]
return generated_input_ids, past_key_values
@torch.inference_mode()
def interleaved_generate(
self,
spk_embeds: torch.Tensor,
conditions: List[torch.Tensor],
temperature: torch.Tensor,
repetition_penalty: float,
eos_token: Union[int, torch.Tensor],
**kwargs,
):
"""
`inputs_embeds` is expected to have shape [bs=1, seq_len, hidden_dim] and is laid out as:
|Text BOS|Spk embeds|Text-Hidden states interleave (if applicable)|Audio BOS|
where the last position is the audio BOS token.
The first generation step therefore forwards the model with `inputs_embeds` directly, and the last hidden state of the final position (Audio BOS) is decoded into the first audio token.
"""
temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)
logits_warpers, logits_processors = gen_logits(
num_code=self.config.num_audio_tokens,
repetition_penalty=repetition_penalty,
)
eos_token = eos_token.to(conditions[0].device)
num_chunks = len(conditions)
text_start_pos = 0
last_window_size = 0
past_key_values = None
for idx in range(num_chunks):
condition = conditions[idx].to(conditions[0].device)
if self.attention_type == "sliding_recompute":
recomputed_conditions = []
if (
idx >= self.window_size
and (idx - self.recomputed_chunks) % (self.window_size - self.recomputed_chunks) == 0
):
for i in range(self.recomputed_chunks):
recomputed_conditions.append(conditions[idx - self.recomputed_chunks + i])
recomputed_conditions.append(
self.emb_code[0](generated_tokens[-self.recomputed_chunks + i][:, :, 0])
)
recomputed_conditions.append(condition)
condition = torch.cat(recomputed_conditions, dim=1)
text_start_pos = 0
new_tokens, old_kv = self.generate_chunk(
inputs_embeds=condition,
temperature=temperature,
repetition_penalty=repetition_penalty,
eos_token=eos_token,
force_no_stop=False,
max_new_token=500,
past_key_values=None,
logits_processors=logits_processors,
text_start_pos=text_start_pos,
)
else:
new_tokens, old_kv = self.generate_chunk(
inputs_embeds=condition,
temperature=temperature,
repetition_penalty=repetition_penalty,
eos_token=eos_token,
force_no_stop=False,
max_new_token=500,
past_key_values=past_key_values,
logits_processors=logits_processors,
text_start_pos=text_start_pos,
)
else:
new_tokens, old_kv = self.generate_chunk(
inputs_embeds=condition,
temperature=temperature,
repetition_penalty=repetition_penalty,
eos_token=eos_token,
force_no_stop=False,
max_new_token=500,
past_key_values=past_key_values,
logits_processors=logits_processors,
text_start_pos=text_start_pos,
)
past_key_values = []
if self.attention_type == "sliding_window" and idx >= 1:
for layer_idx in range(len(old_kv)):
past_key_values.append(
(
old_kv[layer_idx][0][:, :, last_window_size:, :],
old_kv[layer_idx][1][:, :, last_window_size:, :],
)
)
else:
past_key_values = old_kv
last_window_size = condition.shape[1] + new_tokens.shape[1]
text_start_pos += last_window_size
if idx == 0:
generated_tokens = [new_tokens]
else:
generated_tokens.append(new_tokens)
return MiniCPMTTSGenerationOutput(new_ids=torch.cat(generated_tokens, dim=1), finished=True)
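# A standalone sketch (not used by the model) of the legacy-format KV-cache
# truncation performed in interleaved_generate for
# attention_type == "sliding_window": the tokens contributed by the previous
# window (its condition plus the audio tokens it generated) are dropped from
# the front of every layer's key/value tensors. Tensor sizes are illustrative.
def _demo_sliding_window_kv_truncation(last_window_size: int = 4):
    num_layers, heads, seq_len, head_dim = 2, 4, 10, 8
    old_kv = [
        (torch.randn(1, heads, seq_len, head_dim), torch.randn(1, heads, seq_len, head_dim))
        for _ in range(num_layers)
    ]
    truncated = [(k[:, :, last_window_size:, :], v[:, :, last_window_size:, :]) for k, v in old_kv]
    return truncated[0][0].shape  # (1, heads, seq_len - last_window_size, head_dim)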
class CustomRepetitionPenaltyLogitsProcessorRepeat:
def __init__(self, penalty: float, max_input_ids: int, past_window: int):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
self.max_input_ids = max_input_ids
self.past_window = past_window
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.size(1) > self.past_window:
input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
freq = F.one_hot(input_ids, scores.size(1)).sum(1)
if freq.size(0) > self.max_input_ids:
freq.narrow(0, self.max_input_ids, freq.size(0) - self.max_input_ids).zero_()
alpha = torch.pow(self.penalty, freq)
scores = scores.contiguous()
inp = scores.multiply(alpha)
oth = scores.divide(alpha)
con = scores < 0
out = torch.where(con, inp, oth)
del inp, oth, scores, con, alpha
return out
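# A minimal usage sketch (not called anywhere) of
# CustomRepetitionPenaltyLogitsProcessorRepeat: scores of codes that already
# appear in the recent window are scaled so they become less likely to be
# sampled again. The vocabulary size and penalty are illustrative values.
def _demo_repetition_penalty():
    processor = CustomRepetitionPenaltyLogitsProcessorRepeat(penalty=1.2, max_input_ids=8, past_window=16)
    previous_codes = torch.tensor([[1, 1, 3]])  # code 1 already repeated
    scores = torch.randn(1, 8)
    return processor(previous_codes, scores)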
def gen_logits(num_code: int, top_p=0.7, top_k=20, repetition_penalty=1.0):
logits_warpers = []
if top_p is not None:
logits_warpers.append(TopPLogitsWarper(top_p, min_tokens_to_keep=3))
if top_k is not None:
logits_warpers.append(TopKLogitsWarper(top_k, min_tokens_to_keep=3))
logits_processors = []
if repetition_penalty is not None and repetition_penalty != 1:
logits_processors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16))
return logits_warpers, logits_processors
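# A short sketch (not called anywhere) of how the warpers and processors
# returned by gen_logits are chained before sampling, mirroring the loops in
# generate / generate_chunk above. The vocabulary size is illustrative.
def _demo_gen_logits_sampling():
    logits_warpers, logits_processors = gen_logits(num_code=128, top_p=0.7, top_k=20, repetition_penalty=1.05)
    previous_codes = torch.randint(0, 128, (1, 12))
    logits = torch.randn(1, 128)
    for processor in logits_processors:
        logits = processor(previous_codes, logits)
    for warper in logits_warpers:
        logits = warper(previous_codes, logits)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)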
# Copy and modified from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
**kwargs,
):
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various strides during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with varying stride, which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
from transformers.models.paligemma.modeling_paligemma import (
_prepare_4d_causal_attention_mask_with_cache_position,
)
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update(
{
"position_ids": position_ids,
# "cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs