#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


@dataclass
class GenerateChunkOutput:
    chunk_token_ids: torch.Tensor
    current_inputs_embeds: torch.Tensor
    input_last_hidden_states: Optional[torch.Tensor]  # for tts use_speaker_embedding
    last_hidden_states: Optional[torch.Tensor]  # for tts input feature (projector_semantic)
    past_key_values: Optional[object]  # transformers Cache or legacy tuple of key/value tensors
    finished: bool


class ChunkPrefillChunkGenerate:
    def __init__(self, model, tokenizer, terminators):
        self.tokenizer = tokenizer
        self.model = model
        self.terminators = terminators
        self.terminators_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        self.embedding_layer = self.model.get_input_embeddings()
        # Punctuation and symbols that must never be sampled, in both ASCII and
        # full-width variants where they differ.
        self.forbidden_tokens = [
            ":", "：", ";", "#", "“", "”", "‘", "’", "@", "*",
            "【", "】", "「", "」", "(", ")", "（", "）", "[", "]",
            "&", "/", "$",
        ]
        # Tokens missing from the vocabulary map to None with some tokenizers;
        # drop them so the ids can be used for logits indexing.
        self.forbidden_token_ids = [
            tid for tid in (tokenizer.convert_tokens_to_ids(t) for t in self.forbidden_tokens) if tid is not None
        ]
        bad_token_ids = getattr(tokenizer, "bad_token_ids", [])
        if bad_token_ids:
            self.forbidden_token_ids.extend(bad_token_ids)

    @staticmethod
    def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs):
        num_beams = kwargs.get("num_beams", 1)
        # Sampling defaults; adjusted below depending on the decoding mode.
        generation_config = {
            "num_beams": num_beams,
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.05,
        }
        if num_beams > 1:
            generation_config.update({"num_beams": 3, "repetition_penalty": 1.2, "do_sample": False})
        elif do_sample:
            generation_config.update(
                {
                    "top_p": 0.8,
                    "top_k": 100,
                    "temperature": 0.7,
                    "do_sample": True,
                    "repetition_penalty": 1.05,
                }
            )
        else:
            generation_config.update({"do_sample": False, "repetition_penalty": 1.05})
        # Caller-supplied kwargs take precedence over the branch defaults above.
        generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
        generation_config["min_new_tokens"] = min_new_tokens
        generation_config["max_new_tokens"] = max_new_tokens
        return generation_config

    @staticmethod
    def _get_cache_length(past_key_values) -> int:
        if past_key_values is None:
            return 0
        # transformers Cache objects expose get_seq_length().
        if hasattr(past_key_values, "get_seq_length"):
            return past_key_values.get_seq_length()
        # Legacy tuple caches: one (key, value) pair per layer, with keys
        # shaped (batch, num_heads, seq_len, head_dim).
        if isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0:
            first_layer = past_key_values[0]
            if isinstance(first_layer, (tuple, list)) and len(first_layer) > 0:
                return first_layer[0].shape[2]
        return 0

    def non_chunk_generate(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        max_new_tokens=30,
        min_new_tokens=0,
        do_sample=True,
        **kwargs,
    ):
        # Exactly one of input_ids / inputs_embeds must be provided.
        assert (input_ids is None) != (inputs_embeds is None)
        generation_config = self.prepare_generation_config(
            do_sample=do_sample, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, **kwargs
        )
        input_ids = input_ids.to(self.model.device) if input_ids is not None else None
        inputs_embeds = inputs_embeds.to(self.model.device) if inputs_embeds is not None else None
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.model.device)
        model_inputs = {
            "attention_mask": attention_mask,
            "pad_token_id": self.tokenizer.eos_token_id,
            "suppress_tokens": self.forbidden_token_ids,
            "eos_token_id": self.terminators_ids,
            "output_hidden_states": True,
            "return_dict_in_generate": True,
        }
        if input_ids is not None:
            model_inputs["input_ids"] = input_ids
        if inputs_embeds is not None:
            model_inputs["inputs_embeds"] = inputs_embeds
        with torch.no_grad():
            outputs = self.model.generate(**model_inputs, **generation_config)
        return outputs
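    # Minimal illustrative sketch (not called anywhere in this class): the
    # manual counterpart of `generate`'s `suppress_tokens` argument used above,
    # matching the in-place logits masking that `chunk_generate` below performs.
    @staticmethod
    def _suppress_tokens_sketch(logits: torch.Tensor, token_ids) -> torch.Tensor:
        """Mask the given vocabulary ids so they can never be sampled."""
        if token_ids:
            logits[:, token_ids] = float("-inf")
        return logits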
    def chunk_generate(
        self,
        inputs_embeds: torch.Tensor,
        past_key_values,
        is_first_generate_chunk: bool,
        chunk_size: int,
        return_hidden_states: bool,
        do_sample: bool,
        temperature: float,
        top_p: float,
        top_k: int,
        repetition_penalty: float = 1.05,
        all_input_ids: Optional[torch.Tensor] = None,
    ) -> GenerateChunkOutput:
        finished = False
        current_inputs_embeds = inputs_embeds.clone()
        input_last_hidden_states = []
        last_hidden_states = []
        generated_tokens = []
        # Only the most recent PENALTY_WINDOW_SIZE input positions contribute
        # to the repetition penalty.
        PENALTY_WINDOW_SIZE = 128

        for token_idx in range(chunk_size):
            if is_first_generate_chunk and token_idx == 0:
                # First generate chunk: prefill the full inputs_embeds.
                model_inputs = {
                    "inputs_embeds": current_inputs_embeds,
                    "past_key_values": past_key_values,
                    "use_cache": True,
                    "output_hidden_states": return_hidden_states,
                }
            else:
                # All other cases: feed only the latest generated token.
                model_inputs = {
                    "inputs_embeds": current_inputs_embeds[:, -1:, :],
                    "past_key_values": past_key_values,
                    "use_cache": True,
                    "output_hidden_states": return_hidden_states,
                }

            with torch.no_grad():
                outputs = self.model(**model_inputs)

            # Last position's logits.
            logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=inputs_embeds.device)
            # Forbid specific tokens during decoding (manual counterpart of
            # `model.generate`'s `suppress_tokens`).
            if self.forbidden_token_ids:
                logits[:, self.forbidden_token_ids] = float("-inf")

            past_key_values = outputs.past_key_values

            # Apply repetition penalty.
            if repetition_penalty != 1.0:
                # Collect the token ids the penalty should consider.
                if all_input_ids is not None:
                    # Use the global ids (original input plus generated part),
                    # limited to a sliding window over the input.
                    if len(generated_tokens) > 0:
                        generated_token_ids = torch.cat(generated_tokens, dim=1)
                        current_sequence = torch.cat(
                            [
                                all_input_ids[:, -PENALTY_WINDOW_SIZE:],
                                generated_token_ids,
                            ],
                            dim=1,
                        )
                    else:
                        current_sequence = all_input_ids[:, -PENALTY_WINDOW_SIZE:]
                    unique_token_ids = torch.unique(current_sequence.squeeze(0))
                elif len(generated_tokens) > 0:
                    # Fall back to the original logic: penalize only the
                    # tokens generated so far.
                    generated_token_ids = torch.cat(generated_tokens, dim=1).squeeze(0)
                    unique_token_ids = torch.unique(generated_token_ids)
                else:
                    unique_token_ids = torch.tensor([], dtype=torch.long, device=logits.device)

                # Penalize: divide positive logits, multiply negative ones.
                for token_id in unique_token_ids:
                    if logits[0, token_id] > 0:
                        logits[0, token_id] = logits[0, token_id] / repetition_penalty
                    else:
                        logits[0, token_id] = logits[0, token_id] * repetition_penalty

            # Apply temperature.
            if temperature != 1.0:
                logits = logits / temperature

            if do_sample:
                # Top-k filtering.
                if top_k > 0:
                    top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits_filtered = torch.full_like(logits, float("-inf"))
                    logits_filtered.scatter_(1, top_k_indices, top_k_logits)
                    logits = logits_filtered
                # Top-p (nucleus) filtering.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    # Remove tokens whose cumulative probability exceeds top_p,
                    # shifted right so the first token above the threshold is kept.
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float("-inf")
                # Sampling.
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            if return_hidden_states:
                if is_first_generate_chunk and token_idx == 0:
                    input_last_hidden_states.append(outputs.hidden_states[-1])
                else:
                    last_hidden_states.append(outputs.hidden_states[-1])

            # Stop on a terminator token.
            if next_token.item() in self.terminators_ids:
                finished = True
                break

            generated_tokens.append(next_token)
            # Embed the new token and append it to the running embeddings.
            next_token_embed = self.embedding_layer(next_token)
            current_inputs_embeds = torch.cat([current_inputs_embeds, next_token_embed], dim=1)

        if len(generated_tokens) > 0:
            chunk_token_ids = torch.cat(generated_tokens, dim=1)
        else:
            # Special case: the chunk's very first prediction was already a
            # terminator. Return an empty tensor of shape (1, 0).
            if finished:
                chunk_token_ids = torch.zeros((1, 0), dtype=torch.long, device=current_inputs_embeds.device)
            else:
                raise RuntimeError("this should not happen")

        if len(last_hidden_states) > 0:
            last_hidden_states = torch.cat(last_hidden_states, dim=1)
        elif finished or not return_hidden_states:
            # Special case: no decode-step hidden states were collected (the
            # chunk finished immediately, or hidden states were not requested).
            # Return an empty (1, 0, hidden) tensor; torch.cat on an empty
            # list would raise.
            hidden_dim = self.model.config.hidden_size
            last_hidden_states = torch.empty(
                (1, 0, hidden_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device
            )
        else:
            raise RuntimeError("this should not happen")

        if len(input_last_hidden_states) > 0:
            input_last_hidden_states = torch.cat(input_last_hidden_states, dim=1)
        else:
            input_last_hidden_states = None

        return GenerateChunkOutput(
            chunk_token_ids=chunk_token_ids,
            current_inputs_embeds=current_inputs_embeds,
            input_last_hidden_states=input_last_hidden_states,
            last_hidden_states=last_hidden_states,
            past_key_values=past_key_values,
            finished=finished,
        )
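    # Minimal sketch (assumes batch size 1, as `chunk_prefill_and_generate`
    # asserts): a vectorized equivalent of the per-token repetition-penalty
    # loop in `chunk_generate` above. Illustrative only; not wired into the
    # decoding path.
    @staticmethod
    def _apply_repetition_penalty_sketch(
        logits: torch.Tensor, unique_token_ids: torch.Tensor, penalty: float
    ) -> torch.Tensor:
        if penalty == 1.0 or unique_token_ids.numel() == 0:
            return logits
        selected = logits[0, unique_token_ids]
        # Divide positive logits, multiply negative ones, in one shot.
        logits[0, unique_token_ids] = torch.where(selected > 0, selected / penalty, selected * penalty)
        return logits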
    def chunk_generate_hf(
        self,
        inputs_embeds,
        past_key_values,
        is_first_generate_chunk,
        chunk_size=30,
        return_hidden_states=True,
        do_sample=False,
        **kwargs,
    ) -> GenerateChunkOutput:
        if not do_sample and kwargs.get("num_beams", 1) > 1:
            logger.warning("chunk generate does not support beam search; falling back to greedy search")
            kwargs["num_beams"] = 1

        finished = False
        current_inputs_embeds = inputs_embeds.clone()
        input_last_hidden_states = None
        last_hidden_states_list = []
        generated_tokens = []
        cache_length = self._get_cache_length(past_key_values)
        # Each generate call below produces exactly one token.
        gen_config = self.prepare_generation_config(do_sample=do_sample, max_new_tokens=1, **kwargs)

        for token_idx in range(chunk_size):
            if is_first_generate_chunk and token_idx == 0:
                gen_inputs_embeds = current_inputs_embeds
                input_seq_len = current_inputs_embeds.shape[1]
            else:
                gen_inputs_embeds = current_inputs_embeds[:, -1:, :]
                input_seq_len = 1

            # Construct attention_mask and cache_position covering the cached
            # prefix plus the new input positions.
            total_length = cache_length + input_seq_len
            attention_mask = torch.ones((1, total_length), dtype=torch.long, device=self.model.device)
            cache_position = torch.arange(cache_length, total_length, dtype=torch.long, device=self.model.device)

            outputs = self.model.generate(
                inputs_embeds=gen_inputs_embeds,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                cache_position=cache_position,
                use_cache=True,
                pad_token_id=self.tokenizer.eos_token_id,
                suppress_tokens=self.forbidden_token_ids,
                eos_token_id=self.terminators_ids,
                output_hidden_states=return_hidden_states,
                return_dict_in_generate=True,
                **gen_config,
            )

            next_token = outputs.sequences[:, -1:]
            # Update past_key_values and cache_length.
            past_key_values = outputs.past_key_values
            cache_length = self._get_cache_length(past_key_values)

            # Collect hidden states: one per-layer tuple per decoding step.
            if return_hidden_states and hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
                if is_first_generate_chunk and token_idx == 0:
                    if len(outputs.hidden_states) > 0:
                        input_last_hidden_states = outputs.hidden_states[0][-1]
                    if len(outputs.hidden_states) > 1:
                        last_hidden_states_list.append(outputs.hidden_states[1][-1])
                else:
                    if len(outputs.hidden_states) > 0:
                        if len(outputs.hidden_states) > 1:
                            last_hidden_states_list.append(outputs.hidden_states[1][-1])
                        else:
                            last_hidden_states_list.append(outputs.hidden_states[0][-1])

            if next_token.item() in self.terminators_ids:
                finished = True
                break

            generated_tokens.append(next_token)
            next_token_embed = self.embedding_layer(next_token)
            current_inputs_embeds = torch.cat([current_inputs_embeds, next_token_embed], dim=1)

        if len(generated_tokens) > 0:
            chunk_token_ids = torch.cat(generated_tokens, dim=1)
        else:
            chunk_token_ids = torch.zeros((1, 0), dtype=torch.long, device=self.model.device)

        if len(last_hidden_states_list) > 0:
            last_hidden_states = torch.cat(last_hidden_states_list, dim=1)
        else:
            hidden_dim = self.model.config.hidden_size
            last_hidden_states = torch.empty(
                (1, 0, hidden_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device
            )

        return GenerateChunkOutput(
            chunk_token_ids=chunk_token_ids,
            current_inputs_embeds=current_inputs_embeds,
            input_last_hidden_states=input_last_hidden_states,
            last_hidden_states=last_hidden_states,
            past_key_values=past_key_values,
            finished=finished,
        )
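    # Hedged sketch of the hidden-state layout `chunk_generate_hf` relies on:
    # with return_dict_in_generate=True and output_hidden_states=True,
    # `generate(...).hidden_states` is a tuple with one entry per decoding
    # step, each entry itself a tuple of per-layer tensors. Illustrative
    # helper only; the method above indexes this structure inline.
    @staticmethod
    def _last_layer_hidden_sketch(gen_outputs, step: int = 0) -> torch.Tensor:
        """Return the last layer's hidden states for the given decoding step."""
        return gen_outputs.hidden_states[step][-1]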
    def chunk_prefill_and_generate(
        self,
        inputs_embeds,
        prefill_chunk_size=5,
        generate_chunk_size=10,
        return_hidden_states=True,
        max_new_tokens=30,
        min_new_tokens=0,
        do_sample=True,
        chunk_fn="chunk_generate",
        **kwargs,
    ):
        assert inputs_embeds is not None
        generation_config = self.prepare_generation_config(
            do_sample=do_sample, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, **kwargs
        )
        logger.info(f"chunk_prefill_and_generate - generation config: {generation_config}")
        inputs_embeds = inputs_embeds.to(self.model.device)
        bs, seq_len = inputs_embeds.shape[:2]
        assert bs == 1, "batch size should be 1"
        past_key_values = None

        with torch.no_grad():
            last_prefill_chunk_embeds = None
            # Prefill all but the last chunk.
            for start_idx in range(0, seq_len, prefill_chunk_size):
                end_idx = min(start_idx + prefill_chunk_size, seq_len)
                chunk_embeds = inputs_embeds[:, start_idx:end_idx, :]
                is_last_prefill_chunk = end_idx == seq_len
                if not is_last_prefill_chunk:
                    model_inputs = {
                        "inputs_embeds": chunk_embeds,
                        "past_key_values": past_key_values,
                        "use_cache": True,
                        "output_hidden_states": return_hidden_states,
                    }
                    outputs = self.model(**model_inputs)
                    past_key_values = outputs.past_key_values
                else:
                    last_prefill_chunk_embeds = chunk_embeds
                    break
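            # Note: the last prefill chunk is intentionally not forwarded here.
            # It is handed to the first decode chunk instead, so that the
            # forward pass which fills its cache entries also yields the logits
            # for the first generated token.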
            # Decode.
            if last_prefill_chunk_embeds is None:
                raise ValueError("last prefill chunk not found")

            generation_inputs_embeds = last_prefill_chunk_embeds.clone()
            generated_ids = torch.empty((bs, 0), dtype=torch.long, device=self.model.device)
            all_hidden_states = []
            num_chunks_decode = (max_new_tokens + generate_chunk_size - 1) // generate_chunk_size

            for chunk_idx in range(num_chunks_decode):
                is_first_generate_chunk = chunk_idx == 0
                if chunk_fn == "chunk_generate":
                    output = self.chunk_generate(
                        inputs_embeds=generation_inputs_embeds,
                        past_key_values=past_key_values,
                        is_first_generate_chunk=is_first_generate_chunk,
                        chunk_size=generate_chunk_size + (1 if is_first_generate_chunk else 0),
                        return_hidden_states=return_hidden_states,
                        do_sample=do_sample,
                        temperature=generation_config.get("temperature", 0.7),
                        top_p=generation_config.get("top_p", 0.8),
                        top_k=generation_config.get("top_k", 20),
                        repetition_penalty=generation_config.get("repetition_penalty", 1.05),
                        all_input_ids=None,
                    )
                elif chunk_fn == "chunk_generate_hf":
                    output = self.chunk_generate_hf(
                        inputs_embeds=generation_inputs_embeds,
                        past_key_values=past_key_values,
                        is_first_generate_chunk=is_first_generate_chunk,
                        chunk_size=generate_chunk_size + (1 if is_first_generate_chunk else 0),
                        return_hidden_states=return_hidden_states,
                        min_new_tokens=min_new_tokens,
                        do_sample=do_sample,
                        **kwargs,
                    )
                else:
                    raise NotImplementedError(f"unsupported chunk_fn: {chunk_fn}")

                generated_ids = torch.cat([generated_ids, output.chunk_token_ids], dim=1)
                generation_inputs_embeds = output.current_inputs_embeds
                past_key_values = output.past_key_values
                if return_hidden_states and output.last_hidden_states is not None:
                    all_hidden_states.append(output.last_hidden_states)
                if output.finished:
                    break

        return generated_ids, all_hidden_states
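
# Minimal usage sketch under assumptions: "Qwen/Qwen2-0.5B-Instruct" is a
# placeholder checkpoint and "<|im_end|>" a placeholder terminator token;
# substitute the model, tokenizer, and terminators your project actually uses.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Qwen/Qwen2-0.5B-Instruct"  # assumption: any causal LM works here
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    generator = ChunkPrefillChunkGenerate(model, tokenizer, terminators=["<|im_end|>"])
    input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids
    inputs_embeds = model.get_input_embeddings()(input_ids)

    generated_ids, hidden_states = generator.chunk_prefill_and_generate(
        inputs_embeds=inputs_embeds,
        prefill_chunk_size=5,
        generate_chunk_size=10,
        max_new_tokens=30,
        do_sample=True,
    )
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))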