#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import List, Optional, Union

import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as P


class TTSStreamingGenerator:
    def __init__(
        self,
        model,
        temperature: float,
        eos_token: Union[int, torch.Tensor],
        chunk_size: int = 25,  # s3tokenizer: 1 s of audio = 25 tokens
        tts_last_turn_tokens: Optional[torch.Tensor] = None,
        logits_processors: Optional[list] = None,
        logits_warpers: Optional[list] = None,
    ):
        self.tts = model
        self.device = model.device
        self.temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)
        self.eos_token = (
            torch.tensor(eos_token, device=self.device)
            if isinstance(eos_token, int)
            else eos_token.to(self.device)
        )
        self.num_vq = model.num_vq
        self.num_audio_tokens = model.num_audio_tokens
        self.window_size = model.window_size
        self.recomputed_chunks = model.recomputed_chunks
        self.emb_code = model.emb_code
        self.head_code = model.head_code

        # Logits processors (default to empty lists; avoids mutable default arguments)
        self.logits_processors = logits_processors if logits_processors is not None else []
        # Logits warpers (e.g. TopP/TopK), applied after the processors
        self.logits_warpers = logits_warpers if logits_warpers is not None else []

        # Initialize generation state
        self.past_key_values = None
        self.text_start_pos = 0
        self.idx = -1  # starts at -1; becomes 0 on the first call
        self.all_conditions = []
        self.all_generated_tokens = []
        self.tts_last_turn_tokens = tts_last_turn_tokens
        self.spk_emb = None

        audio_bos = torch.tensor(
            [self.tts.audio_bos_token_id],
            dtype=torch.long,
            device=self.tts.emb_text.weight.device,
        )
        self.audio_bos_embeds = self.tts.emb_text(audio_bos).unsqueeze(0)
        self.text_eos_embed = self.tts.emb_text(
            torch.tensor(
                [self.tts.config.text_eos_token_id],
                device=self.tts.emb_text.weight.device,
                dtype=torch.long,
            )
        ).unsqueeze(0)

        # Buffer used to accumulate tokens up to chunk_size before yielding them
        self.chunk_size = chunk_size
        self._token_buffer: List[torch.Tensor] = []

    @torch.inference_mode()
    def generate_with_buffer(
        self,
        condition: torch.Tensor,
        text_finished: bool = False,
        max_new_token: int = 500,
    ):
        """Consume one chunk of condition embeddings, generating audio tokens
        one at a time and accumulating them in an internal buffer; tokens are
        yielded only once the buffer holds at least ``chunk_size`` of them
        (see the usage sketch at the bottom of this file).

        Yields:
            Tuple ``(tokens, is_last)``: ``tokens`` is a ``torch.Tensor`` of
            shape ``[1, chunk_size]`` (the final chunk may be shorter), and
            ``is_last`` is ``True`` only for the last chunk of the utterance.
        """
        self.idx += 1
        self.device = self.tts.device

        # If the text stream has finished, first append the text-EOS embedding
        if text_finished:
            condition = torch.cat([condition, self.text_eos_embed], dim=1)
        # Always append the audio-BOS embedding
        condition = torch.cat([condition, self.audio_bos_embeds], dim=1).to(self.device)
        self.all_conditions.append(condition)

        current_condition = condition
        condition_length = current_condition.shape[1]
        finished = torch.zeros(1, dtype=torch.bool, device=self.device)
        chunk_generated_tokens = []

        for t in range(max_new_token):
            if t == 0:
                # Prefill: feed the whole condition chunk
                inputs_embeds = current_condition
                pos_ids = torch.arange(
                    self.text_start_pos,
                    self.text_start_pos + condition_length,
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)
            else:
                last = self.all_generated_tokens[-1]  # [1, 1], used directly as a code id
                inputs_embeds = self.emb_code[0](last)
                pos_ids = torch.tensor(
                    [self.text_start_pos + condition_length + t - 1],
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)

            outputs = self.tts.model(
                position_ids=pos_ids,
                past_key_values=self.past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=True,
            )
            hidden_states = outputs.last_hidden_state
            self.past_key_values = outputs.past_key_values

            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for num_vq_iter in range(self.num_vq):
                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                    logits[..., num_vq_iter] = x
                    del x
            del hidden_states

            logits = logits[:, -1].float()
            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))  # [(B * num_vq), num_audio_tokens]
            logits /= self.temperature

            audio_bos = len(self.all_generated_tokens) == 0 and t == 0
            if not audio_bos:
                # Feed all tokens generated so far to the processors/warpers
                # (aligned with modeling_minicpmo)
                all_generated_tokens = torch.cat(self.all_generated_tokens, dim=1).to(self.device)  # [1, T]
                for processor in self.logits_processors:
                    logits = processor(all_generated_tokens, logits)
                for warper in self.logits_warpers:
                    logits = warper(all_generated_tokens, logits)
                del all_generated_tokens

            # Sample the next token (only the first codebook is used, same as generate)
            scores = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(scores, num_samples=1)  # [(B * num_vq), 1]
            next_id = idx_next.view(-1, self.num_vq)[:, 0:1]  # keep only the first codebook → [B, 1]
            del scores

            if next_id.eq(self.eos_token).any():
                # Audio EOS generated: this utterance is finished, stop producing new tokens
                finished[:] = True
            else:
                # The EOS token is never added to the buffer: it carries no audio.
                # Normalize next_id to shape [1, 1] (no num_vq dimension)
                if next_id.dim() == 0:  # scalar
                    next_tok = next_id.unsqueeze(0).unsqueeze(0)  # [1, 1]
                elif next_id.dim() == 1:  # 1-D [1]
                    next_tok = next_id.unsqueeze(0)  # [1, 1]
                else:
                    next_tok = next_id
                self.all_generated_tokens.append(next_tok)
                chunk_generated_tokens.append(next_tok)
                self._token_buffer.append(next_tok)

            if len(self._token_buffer) == 0:
                # Case 1: last text chunk → yield an empty tensor so the caller
                # still receives the final-chunk flag
                if text_finished:
                    yield torch.empty(1, 0, dtype=torch.long, device=self.device), True
                    break
                # Case 2: not the last text chunk → just end this call
                else:
                    break
            else:  # buffer is non-empty
                # Case 1: buffer holds at least chunk_size tokens → yield a full chunk
                if len(self._token_buffer) >= self.chunk_size:
                    batch = torch.cat(self._token_buffer[: self.chunk_size], dim=1)  # [1, chunk_size]
                    yield batch, False
                    # Drop the part that was just yielded
                    self._token_buffer = self._token_buffer[self.chunk_size :]
                # Case 2: buffer holds fewer than chunk_size tokens
                else:
                    if finished.all():
                        # Generation finished: on the last text chunk, flush all
                        # remaining tokens, then stop
                        if text_finished:
                            batch = torch.cat(self._token_buffer, dim=1)  # [1, n_remaining], may be < chunk_size
                            yield batch, True
                            self._token_buffer = []
                            break
                        else:
                            # Not the last text chunk: keep the partial buffer and
                            # wait for the next text chunk to fill it, so this call ends here
                            break
                    else:
                        # This audio chunk is not finished yet; keep generating
                        continue

        self.text_start_pos += condition_length + len(chunk_generated_tokens)
        # Note: tokens left in the buffer are kept and carried over to the next call
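

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal driver loop showing the intended calling convention. The names
# `load_tts_model`, `text_chunks`, and `play_or_decode` are hypothetical and
# do not exist in this module; `generate_with_buffer` yields `(tokens, is_last)`
# pairs, and any leftover tokens smaller than `chunk_size` stay buffered until
# the next call.
#
#   model = load_tts_model()  # hypothetical loader for the TTS model
#   gen = TTSStreamingGenerator(
#       model,
#       temperature=0.8,
#       eos_token=model.audio_eos_token_id,  # assumes the model exposes this id
#   )
#   for i, cond in enumerate(text_chunks):  # cond: [1, T_i, hidden] condition embeddings
#       is_last_text = i == len(text_chunks) - 1
#       for tokens, is_last in gen.generate_with_buffer(cond, text_finished=is_last_text):
#           play_or_decode(tokens)  # hypothetical consumer of [1, chunk_size] token batches
#           if is_last:
#               break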