#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Streaming Mel feature processor.

Extracts Mel-spectrogram features from a real-time audio stream using
chunk-based processing. Supports configurable CNN redundancy so that the
streamed output stays consistent with offline (full-utterance) processing.
"""

import logging
from typing import Dict
from typing import Optional
from typing import Tuple

import numpy as np
import torch

from .processing_audio_minicpma import MiniCPMAAudioProcessor

logger = logging.getLogger(__name__)


class StreamingMelProcessorExact:
    """Streaming Mel processor that is strictly equivalent to offline extraction.

    Approach:
    - Accumulate the full audio history in a buffer; after each append, run the
      same ``feature_extractor`` over the whole buffer.
    - Emit only "stable" frames: frames whose center does not depend on future
      (right-side) context, i.e. ``center + n_fft // 2 <= current buffer length``.
    - At end of stream (``flush``) emit the remaining frames, guaranteeing the
      concatenated output matches a single offline pass exactly.

    Cost: every call re-extracts features over the accumulated buffer (this
    could be optimized into an incremental computation if needed).
    """

    def __init__(
        self,
        feature_extractor: MiniCPMAAudioProcessor,
        chunk_ms: int = 100,
        first_chunk_ms: Optional[int] = None,
        sample_rate: int = 16000,
        n_fft: int = 400,
        hop_length: int = 160,
        n_mels: int = 80,
        verbose: bool = False,
        cnn_redundancy_ms: int = 10,  # given in ms; typically 10 ms == 1 frame
        # --- sliding-window parameters (trigger mode) ---
        enable_sliding_window: bool = False,  # whether the sliding window is enabled
        slide_trigger_seconds: float = 30.0,  # buffer length (s) that triggers a slide
        slide_stride_seconds: float = 10.0,  # seconds dropped per slide
    ):
        """Configure the processor; actual stream state is set by :meth:`reset`.

        Args:
            feature_extractor: Mel feature extractor shared with offline processing.
            chunk_ms: Length of each steady-state input chunk in milliseconds.
            first_chunk_ms: Length of the first chunk; defaults to ``chunk_ms``.
            sample_rate: Audio sample rate in Hz.
            n_fft: STFT window size in samples.
            hop_length: STFT hop size in samples (one Mel frame per hop).
            n_mels: Number of Mel bins produced by the extractor.
            verbose: Print per-chunk debug information when True.
            cnn_redundancy_ms: Extra context emitted around each core span so a
                downstream CNN sees the same neighborhood as in offline mode.
            enable_sliding_window: Enable trigger-mode buffer sliding.
            slide_trigger_seconds: Buffer duration that triggers a slide.
            slide_stride_seconds: Duration dropped from the left per slide.
        """
        self.feature_extractor = feature_extractor
        self.chunk_ms = chunk_ms
        self.first_chunk_ms = first_chunk_ms if first_chunk_ms is not None else chunk_ms
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.verbose = verbose
        self.chunk_samples = int(round(chunk_ms * sample_rate / 1000))
        self.chunk_frames = self.chunk_samples // hop_length
        # Align the first chunk to a multiple of hop_length so that frame
        # boundaries line up with the steady-state chunks.
        hop = self.hop_length
        raw_first_samples = int(round(self.first_chunk_ms * sample_rate / 1000))
        aligned_first = max(hop, (raw_first_samples // hop) * hop)
        self.first_chunk_samples = aligned_first
        self.half_window = n_fft // 2  # right-side context a frame center needs
        # Redundancy expressed in frames; e.g. 10 ms -> 1 frame at hop 160/16 kHz.
        self.cnn_redundancy_ms = cnn_redundancy_ms
        self.cnn_redundancy_samples = int(cnn_redundancy_ms * sample_rate / 1000)
        self.cnn_redundancy_frames = max(0, self.cnn_redundancy_samples // hop_length)
        # --- sliding-window configuration (trigger mode) ---
        self.enable_sliding_window = enable_sliding_window
        self.trigger_seconds = slide_trigger_seconds
        self.slide_seconds = slide_stride_seconds
        # --- offsets / base reference (global frame coordinates) ---
        self.left_samples_dropped = 0  # samples already dropped from the left
        self.base_T = 0  # global frame index corresponding to mel_full[:, :, 0]
        self.reset()

    def reset(self):
        """Reset all per-stream state so the processor can handle a new stream."""
        self.buffer = np.zeros(0, dtype=np.float32)
        self.last_emitted_T = 0
        self.total_samples_processed = 0
        self.chunk_count = 0
        self.is_first = True
        self.left_samples_dropped = 0
        self.base_T = 0

    def get_chunk_size(self) -> int:
        """Return the expected size (in samples) of the next input chunk."""
        return self.first_chunk_samples if self.is_first else self.chunk_samples

    def get_expected_output_frames(self) -> int:
        raise NotImplementedError("get_expected_output_frames is not implemented")

    def _extract_full(self) -> torch.Tensor:
        """Run the feature extractor over the entire current buffer.

        Returns:
            Mel features of shape [1, n_mels, T] for the whole buffer.

        Raises:
            ValueError: If the buffer is shorter than ``n_fft`` — Whisper's
                internal STFT (center=True) would fail because the pad exceeds
                the input length, and no stable frame could be emitted anyway.
        """
        if len(self.buffer) < self.n_fft:
            raise ValueError(f"buffer length is shorter than n_fft {len(self.buffer)} < {self.n_fft}")
        # For buffers shorter than 5 s, use a fixed log floor; otherwise use a
        # dynamic-range-based normalization. The switch point and values were
        # chosen empirically.
        if len(self.buffer) < 5 * self.sample_rate:
            # TODO: ideally pick these values via experiments; the current ones
            # are empirical — see the main of MiniCPMAAudioProcessor.
            self.feature_extractor.set_spac_log_norm(log_floor_db=-10)
        else:
            self.feature_extractor.set_spac_log_norm(dynamic_range_db=8)
        feats = self.feature_extractor(
            self.buffer,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding=False,
        )
        return feats.input_features  # [1, 80, T]

    def _stable_frames_count(self) -> int:
        """Number of frames whose centers have full right context in the buffer.

        stable frames = floor((len(buffer) - half_window) / hop) + 1, min 0.
        """
        L = int(self.buffer.shape[0])
        if L <= 0:
            return 0
        if L < self.half_window:
            return 0
        return max(0, (L - self.half_window) // self.hop_length + 1)

    def _maybe_slide_buffer(self):
        """Trigger-mode sliding: when the buffer reaches the trigger threshold,
        drop a fixed stride of audio from the left and advance the frame base.
        """
        if not self.enable_sliding_window:
            return
        sr = self.sample_rate
        hop = self.hop_length
        L = len(self.buffer)
        # Convert seconds to samples.
        trigger_samples = int(self.trigger_seconds * sr)
        stride_samples = int(self.slide_seconds * sr)
        # Check whether the trigger threshold has been reached.
        if L < trigger_samples:
            return
        # Samples to drop (a fixed slide of stride_samples).
        drop = stride_samples
        # Do not drop left context that future emissions still need. In trigger
        # mode only a minimal amount of data must be protected: keep at least
        # the most recent second to preserve processing continuity.
        min_keep_seconds = 1.0
        min_keep_samples = int(min_keep_seconds * sr)
        # guard_samples is the minimum number of samples that must remain.
        guard_samples = min(min_keep_samples, L - drop)
        # Clamp to the safety boundary and align the drop to hop_length so the
        # frame grid stays intact.
        max_allowed_drop = max(0, L - guard_samples)
        drop = min(drop, max_allowed_drop)
        drop = (drop // hop) * hop
        if drop <= 0:
            return
        # Actually drop the samples and update the global-frame base.
        self.buffer = self.buffer[drop:]
        self.left_samples_dropped += drop
        self.base_T += drop // hop
        if self.verbose:
            print(
                f"[Slide] Trigger模式: drop={drop/sr:.2f}s samples, base_T={self.base_T}, buffer_after={len(self.buffer)/sr:.2f}s"
            )

    def process(self, audio_chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[torch.Tensor, Dict]:
        """Append one audio chunk and emit any frames that have become stable.

        Args:
            audio_chunk: 1-D float audio samples at ``sample_rate``.
            is_last_chunk: When True, emit even frames lacking full redundancy.

        Returns:
            (mel_output, info): Mel features [1, n_mels, F] (F may be 0) and a
            dict of per-chunk diagnostics.
        """
        self.chunk_count += 1
        # Append to the buffer.
        if len(self.buffer) == 0:
            self.buffer = audio_chunk.astype(np.float32, copy=True)
        else:
            self.buffer = np.concatenate([self.buffer, audio_chunk.astype(np.float32, copy=True)])
        # --- sliding-window maintenance ---
        self._maybe_slide_buffer()
        # Full extraction over the current window.
        mel_full = self._extract_full()
        T_full = mel_full.shape[-1]  # local frame count of the current window
        stable_T = min(T_full, self._stable_frames_count())  # locally stable frames
        stable_T_global = self.base_T + stable_T  # mapped to global frame coords
        # Core frames planned for this emission (global coordinates).
        core_start_g = self.last_emitted_T
        core_end_g = core_start_g + self.chunk_frames
        required_stable_g = core_end_g + self.cnn_redundancy_frames
        if self.verbose:
            print(
                f"[Exact] buffer_len={len(self.buffer)} samples, T_full(local)={T_full}, "
                f"stable_T(local)={stable_T}, base_T={self.base_T}, "
                f"stable_T(global)={stable_T_global}, last_emitted={self.last_emitted_T}"
            )
        if stable_T_global >= required_stable_g or is_last_chunk:
            emit_start_g = max(0, core_start_g - self.cnn_redundancy_frames)
            emit_end_g = core_end_g + self.cnn_redundancy_frames
            # Global -> local indices, clamped to the current window.
            emit_start = max(0, emit_start_g - self.base_T)
            emit_end = emit_end_g - self.base_T
            emit_start = max(0, min(emit_start, T_full))
            emit_end = max(emit_start, min(emit_end, T_full))
            mel_output = mel_full[:, :, emit_start:emit_end]
            self.last_emitted_T = core_end_g  # advance only the core pointer (global)
        else:
            mel_output = mel_full[:, :, 0:0]
        self.total_samples_processed += len(audio_chunk)
        self.is_first = False
        info = {
            "type": "exact_chunk",
            "chunk_number": self.chunk_count,
            "emitted_frames": mel_output.shape[-1],
            "stable_T": stable_T,
            "T_full": T_full,
            "base_T": self.base_T,
            "stable_T_global": stable_T_global,
            "buffer_len_samples": int(self.buffer.shape[0]),
            "left_samples_dropped": self.left_samples_dropped,
            "core_start": core_start_g,  # original field names kept; values are global
            "core_end": core_end_g,  # same as above
        }
        return mel_output, info

    def flush(self) -> torch.Tensor:
        """Emit all remaining frames at end of stream (global coordinates),
        guaranteeing the overall output matches offline extraction."""
        if len(self.buffer) == 0:
            # Fixed: respect the configured n_mels instead of hard-coding 80.
            return torch.zeros(1, self.n_mels, 0)
        mel_full = self._extract_full()
        T_local = mel_full.shape[-1]
        T_global = self.base_T + T_local
        if self.last_emitted_T < T_global:
            start_l = max(0, self.last_emitted_T - self.base_T)
            tail = mel_full[:, :, start_l:]
            self.last_emitted_T = T_global
            if self.verbose:
                print(f"[Exact] flush {tail.shape[-1]} frames (T_global={T_global})")
            return tail
        return mel_full[:, :, 0:0]

    def get_config(self) -> Dict:
        """Return the static configuration of this processor."""
        return {
            "chunk_ms": self.chunk_ms,
            "first_chunk_ms": self.first_chunk_ms,
            "effective_first_chunk_ms": self.first_chunk_samples / self.sample_rate * 1000.0,
            "sample_rate": self.sample_rate,
            "n_fft": self.n_fft,
            "hop_length": self.hop_length,
            "cnn_redundancy_ms": self.cnn_redundancy_ms,
            "cnn_redundancy_frames": self.cnn_redundancy_frames,
            "enable_sliding_window": self.enable_sliding_window,
            "trigger_seconds": self.trigger_seconds,
            "slide_seconds": self.slide_seconds,
        }

    def get_state(self) -> Dict:
        """Return a lightweight summary of the current mutable state."""
        return {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer_len": int(self.buffer.shape[0]),
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
        }

    def get_snapshot(self) -> Dict:
        """Capture a full state snapshot (including the buffer) for rollback.

        Returns:
            A dict containing the complete state, usable with
            :meth:`restore_snapshot`.
        """
        buffer_copy = self.buffer.copy()
        snapshot = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer": buffer_copy,
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
            "is_first": self.is_first,
            # Save feature_extractor state (key to deterministic mel extraction).
            "fe_dynamic_log_norm": getattr(self.feature_extractor, "dynamic_log_norm", None),
            "fe_dynamic_range_db": getattr(self.feature_extractor, "dynamic_range_db", None),
            "fe_log_floor_db": getattr(self.feature_extractor, "log_floor_db", None),
        }
        logger.debug(
            "[MelProcessor] Created snapshot: chunk_count=%d, last_emitted_T=%d, "
            "buffer_len=%d, buffer_sum=%.6f, total_samples=%d",
            self.chunk_count,
            self.last_emitted_T,
            len(buffer_copy),
            float(buffer_copy.sum()) if len(buffer_copy) > 0 else 0.0,
            self.total_samples_processed,
        )
        return snapshot

    def restore_snapshot(self, snapshot: Dict) -> None:
        """Restore state from a snapshot.

        Args:
            snapshot: A snapshot dict produced by :meth:`get_snapshot`.
        """
        # Record the state before restoring (for the log message below).
        prev_state = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "buffer_len": len(self.buffer),
        }
        # Restore state.
        self.chunk_count = snapshot["chunk_count"]
        self.last_emitted_T = snapshot["last_emitted_T"]
        self.total_samples_processed = snapshot["total_samples_processed"]
        self.buffer = snapshot["buffer"].copy()  # copy the buffer
        self.base_T = snapshot["base_T"]
        self.left_samples_dropped = snapshot["left_samples_dropped"]
        self.is_first = snapshot["is_first"]
        # Restore feature_extractor state (key to deterministic mel extraction).
        if snapshot.get("fe_dynamic_log_norm") is not None:
            self.feature_extractor.dynamic_log_norm = snapshot["fe_dynamic_log_norm"]
        if snapshot.get("fe_dynamic_range_db") is not None:
            self.feature_extractor.dynamic_range_db = snapshot["fe_dynamic_range_db"]
        if snapshot.get("fe_log_floor_db") is not None:
            self.feature_extractor.log_floor_db = snapshot["fe_log_floor_db"]
        logger.info(
            "[MelProcessor] Restored snapshot: chunk_count %d->%d, last_emitted_T %d->%d, "
            "buffer_len %d->%d, total_samples=%d",
            prev_state["chunk_count"],
            self.chunk_count,
            prev_state["last_emitted_T"],
            self.last_emitted_T,
            prev_state["buffer_len"],
            len(self.buffer),
            self.total_samples_processed,
        )