| |
|
| | """Ernie VL model""" |
| | import re |
| | import math |
| | import itertools |
| | from dataclasses import dataclass |
| | from collections import defaultdict |
| | from copy import deepcopy |
| | from functools import partial |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.nn.attention import SDPBackend, sdpa_kernel |
| |
|
| | from transformers.activations import ACT2FN |
| | from transformers.generation import GenerationMixin |
| | from transformers.modeling_outputs import ModelOutput |
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.utils import logging |
| | from .configuration_ernie4_5_vl import ( |
| | DFNRopeVisionTransformerConfig, |
| | Ernie4_5_MoEConfig, |
| | Ernie4_5_VLMoEConfig, |
| | ) |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
| | __all__ = [ |
| | "Ernie4_5_VLMoeForConditionalGeneration", |
| | "DFNRopeVisionTransformerPreTrainedModel", |
| | "VariableResolutionResamplerModel", |
| | ] |
| |
|
| |
|
| | class TokenType: |
| | """token type definition""" |
| |
|
| | text = 0 |
| | image = 1 |
| | video = 2 |
| |
|
| |
|
| | class UniqueNameGuard: |
| | """name guard""" |
| |
|
| | def __init__(self, prefix=""): |
| | self.prefix = prefix |
| | self.counter = {} |
| |
|
| | def __enter__(self): |
| | return self |
| |
|
| | def __exit__(self, exc_type, exc_val, exc_tb): |
| | pass |
| |
|
| | def get_unique_name(self, name): |
| | """get unique name""" |
| | if name not in self.counter: |
| | self.counter[name] = 0 |
| | else: |
| | self.counter[name] += 1 |
| | return f"{self.prefix}{name}_{self.counter[name]}" |
| |
|
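| | # Illustrative sketch (not part of the model): how the guard derives unique |
| | # parameter-name suffixes; the prefix and names below are made up for the example. |
| | # guard = UniqueNameGuard(prefix="mm_gate_0_") |
| | # guard.get_unique_name("weight")   # -> "mm_gate_0_weight_0" |
| | # guard.get_unique_name("weight")   # -> "mm_gate_0_weight_1" |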
| |
|
| | class RopeEmbedding(nn.Module): |
| | """ |
| | Rotary Position Embedding (RoPE) implementation for transformer models. |
| | |
| | RoPE encodes absolute positional information with rotation matrices and |
| | naturally incorporates relative position information in self-attention. |
| | |
| | Args: |
| | head_dim (int): Dimension size of each attention head |
| | compression_ratio (float, optional): Sequence length compression ratio. Defaults to 1.0. |
| | base (int, optional): Base value for frequency calculation. Defaults to 10000. |
| | |
| | Attributes: |
| | head_dim (int): Dimension size of each attention head |
| | compression_ratio (float): Sequence length compression factor |
| | base (int): Base value for frequency calculation |
| | """ |
| |
|
| | def __init__(self, head_dim, compression_ratio=1.0, base=10000, freq_allocation=0): |
| | """ |
| | Initialize RoPE embedding layer. |
| | |
| | Args: |
| | head_dim: Dimension of each attention head |
| | compression_ratio: Scaling factor for position indices |
| | base: Base value for frequency calculation |
| | freq_allocation: Number of rotary frequencies reserved for the temporal axis in 3D (time/height/width) RoPE |
| | """ |
| | super().__init__() |
| | self.head_dim = head_dim |
| | self.compression_ratio = compression_ratio |
| | self.base = base |
| |
|
| | |
| | self.freq_allocation = freq_allocation |
| |
|
| | def forward(self, seq_length, position_ids=None): |
| | """ |
| | Compute rotary position embeddings for given sequence length. |
| | |
| | Args: |
| | seq_length (int): Maximum sequence length |
| | position_ids (Tensor, optional): Custom position indices. Defaults to None. |
| | |
| | Returns: |
| | Tensor: Rotary position embeddings of shape [1, 1, seq_length, head_dim] |
| | """ |
| | indices = torch.arange(0, self.head_dim, 2, dtype=torch.float32) |
| | indices = 1 / self.base ** (indices / self.head_dim) |
| | if position_ids is None: |
| | position_ids = torch.arange( |
| | 0, seq_length, 1, dtype=torch.float32 |
| | ).unsqueeze(1) |
| | position_ids = position_ids / self.compression_ratio |
| | sinusoid_inp = position_ids * indices.unsqueeze(0) |
| | else: |
| | position_ids = position_ids / self.compression_ratio |
| | seq_length = position_ids.shape[-1] |
| | sinusoid_inp = position_ids.unsqueeze(-1).to( |
| | torch.float32 |
| | ) * indices.unsqueeze(0) |
| | pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1) |
| | pos_emb = pos_emb.view(-1, 1, seq_length, self.head_dim) |
| | pos_emb = pos_emb.detach() |
| | return pos_emb |
| |
|
| | def apply_rotary(self, rp, q, k): |
| | """ |
| | Apply rotary position embeddings to queries and keys. |
| | |
| | Args: |
| | rp (Tensor): Rotary position embeddings |
| | q (Tensor): Query tensor [batch, heads, seq_len, dim] |
| | k (Tensor): Key tensor [batch, heads, seq_len, dim] |
| | |
| | Returns: |
| | Tuple[Tensor, Tensor]: Rotated queries and keys |
| | """ |
| | sin, cos = torch.chunk(rp, 2, dim=-1) |
| | |
| | sin_pos = torch.stack([sin, sin], dim=-1).reshape(rp.shape) |
| | |
| | cos_pos = torch.stack([cos, cos], dim=-1).reshape(rp.shape) |
| | |
| | rotate_half_q = torch.stack( |
| | [-q[:, :, :, 1::2], q[:, :, :, 0::2]], dim=-1 |
| | ).reshape(q.shape) |
| | query = (q.to(torch.float32) * cos_pos) + ( |
| | rotate_half_q.to(torch.float32) * sin_pos |
| | ) |
| | |
| | rotate_half_k = torch.stack( |
| | [-k[:, :, :, 1::2], k[:, :, :, 0::2]], dim=-1 |
| | ).reshape(k.shape) |
| | key = (k.to(torch.float32) * cos_pos) + ( |
| | rotate_half_k.to(torch.float32) * sin_pos |
| | ) |
| | return query, key |
| |
|
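| | # Illustrative sketch (not part of the model): applying 1D RoPE with the |
| | # interleaved rotate-half scheme above; all tensor sizes are made up. |
| | # rope = RopeEmbedding(head_dim=64) |
| | # rp = rope(seq_length=8)                     # [1, 1, 8, 64], sin|cos halves |
| | # q = torch.randn(1, 2, 8, 64)                # [batch, heads, seq, head_dim] |
| | # k = torch.randn(1, 2, 8, 64) |
| | # q_rot, k_rot = rope.apply_rotary(rp, q, k)  # same shapes, rotated in float32 |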
| | def apply_rotary_3d(self, rp, q, k, position_ids): |
| | """ |
| | rope 3d rotary |
| | |
| | args: |
| | rp: [1, max_seqlen, 1, head_dim] |
| | q: [bsz, seqlen, head, head_dim] |
| | k: [bsz, seqlen, head, head_dim] |
| | position_ids: [bsz, seqlen, 3] |
| | """ |
| | current_device = q.device |
| | sin, cos = torch.chunk(rp, 2, dim=-1) |
| | assert position_ids.shape[:1] == q.shape[:1] |
| | batch_indices = torch.arange(end=position_ids.shape[0]) |
| | batch_indices = batch_indices[..., None] |
| | sin = sin.tile(position_ids.shape[0], 1, 1, 1).to(device=position_ids.device) |
| | cos = cos.tile(position_ids.shape[0], 1, 1, 1).to(device=position_ids.device) |
| |
|
| | assert self.freq_allocation != 0 |
| | sin_t = sin[batch_indices, position_ids[..., 0], :, -self.freq_allocation :] |
| | sin_h = sin[ |
| | batch_indices, |
| | position_ids[..., 1], |
| | :, |
| | : self.head_dim // 2 - self.freq_allocation : 2, |
| | ] |
| | sin_w = sin[ |
| | batch_indices, |
| | position_ids[..., 2], |
| | :, |
| | 1 : self.head_dim // 2 - self.freq_allocation : 2, |
| | ] |
| | sin_hw = torch.stack([sin_h, sin_w], dim=-1).reshape( |
| | sin_h.shape[:-1] + (sin_h.shape[-1] * 2,) |
| | ) |
| | sin_thw = torch.cat([sin_hw, sin_t], dim=-1) |
| |
|
| | cos_t = cos[batch_indices, position_ids[..., 0], :, -self.freq_allocation :] |
| | cos_h = cos[ |
| | batch_indices, |
| | position_ids[..., 1], |
| | :, |
| | : self.head_dim // 2 - self.freq_allocation : 2, |
| | ] |
| | cos_w = cos[ |
| | batch_indices, |
| | position_ids[..., 2], |
| | :, |
| | 1 : self.head_dim // 2 - self.freq_allocation : 2, |
| | ] |
| | cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape( |
| | cos_h.shape[:-1] + (cos_h.shape[-1] * 2,) |
| | ) |
| | cos_thw = torch.cat([cos_hw, cos_t], dim=-1) |
| |
|
| | |
| | sin_pos = ( |
| | torch.stack([sin_thw, sin_thw], dim=-1) |
| | .reshape(sin_thw.shape[:3] + (sin_thw.shape[-1] * 2,)) |
| | .to(current_device) |
| | ) |
| | |
| | cos_pos = ( |
| | torch.stack([cos_thw, cos_thw], dim=-1) |
| | .reshape(cos_thw.shape[:3] + (cos_thw.shape[-1] * 2,)) |
| | .to(current_device) |
| | ) |
| |
|
| | |
| | rotate_half_q = torch.stack( |
| | [-q[:, :, :, 1::2], q[:, :, :, 0::2]], dim=-1 |
| | ).reshape(q.shape) |
| | query = (q.to(torch.float32) * cos_pos) + ( |
| | rotate_half_q.to(torch.float32) * sin_pos |
| | ) |
| | |
| | rotate_half_k = torch.stack( |
| | [-k[:, :, :, 1::2], k[:, :, :, 0::2]], dim=-1 |
| | ).reshape(k.shape) |
| | key = (k.to(torch.float32) * cos_pos) + ( |
| | rotate_half_k.to(torch.float32) * sin_pos |
| | ) |
| | return query, key |
| |
|
| |
|
| | class Ernie4_5_MLP(nn.Module): |
| | """ |
| | Ernie4_5_MLP - Gated Multi-Layer Perceptron module used in Ernie model. |
| | """ |
| |
|
| | def __init__(self, config, layer_idx=0): |
| | """ |
| | Initialize the MLP module with configuration options. |
| | |
| | Args: |
| | config (Ernie4_5_Config): Model configurations. |
| | layer_idx (int): Index of current layer (default: 0) |
| | """ |
| | super().__init__() |
| | self.config = config |
| | self.hidden_size = config.hidden_size |
| | self.intermediate_size = config.intermediate_size |
| |
|
| | self.gate_proj = nn.Linear( |
| | self.hidden_size, self.intermediate_size, bias=config.use_bias |
| | ) |
| | self.up_proj = nn.Linear( |
| | self.hidden_size, self.intermediate_size, bias=config.use_bias |
| | ) |
| | self.down_proj = nn.Linear( |
| | self.intermediate_size, self.hidden_size, bias=config.use_bias |
| | ) |
| |
|
| | def forward(self, x): |
| | """ |
| | Forward pass through the MLP module. |
| | |
| | Args: |
| | x (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] |
| | |
| | Returns: |
| | Tensor: Output tensor of shape [batch_size, seq_len, hidden_size] |
| | """ |
| | current_device = self.gate_proj.weight.data.device |
| | x = x.to(current_device) |
| | down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) |
| | return down_proj |
| |
|
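| | # Illustrative sketch (not part of the model): using the gated (SwiGLU-style) |
| | # FFN above, assuming `cfg` is any config exposing hidden_size, |
| | # intermediate_size and use_bias (e.g. Ernie4_5_VLMoEConfig). |
| | # mlp = Ernie4_5_MLP(cfg) |
| | # x = torch.randn(2, 16, cfg.hidden_size) |
| | # y = mlp(x)   # = down_proj(silu(gate_proj(x)) * up_proj(x)) -> [2, 16, hidden_size] |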
| |
|
| | class Ernie4_5_Attention(nn.Module): |
| | """Multi-headed attention from 'Attention Is All You Need' paper""" |
| |
|
| | def __init__(self, config, layer_idx=0): |
| | """Initialize the attention layer. |
| | |
| | Args: |
| | config (Ernie4_5_Config): Model configuration. |
| | layer_idx (int, optional): Index in transformer stack. Defaults to 0. |
| | """ |
| | super().__init__() |
| | self.layer_idx = layer_idx |
| | self.hidden_size = config.hidden_size |
| | self.num_heads = config.num_attention_heads |
| | self.num_key_value_heads = config.num_key_value_heads |
| | self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads |
| | self.head_dim = self.hidden_size // self.num_heads |
| | self.is_gqa = ( |
| | self.num_key_value_heads is not None |
| | and self.num_key_value_heads != self.num_heads |
| | ) |
| |
|
| | self.freq_allocation = getattr(config, "freq_allocation", 0) |
| | assert ( |
| | self.freq_allocation is not None |
| | ), "freq_allocation must be provided if rope_3d is on." |
| |
|
| | if config.tensor_parallel_degree > 1: |
| | assert ( |
| | self.num_heads % config.tensor_parallel_degree == 0 |
| | ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" |
| | self.num_heads = self.num_heads // config.tensor_parallel_degree |
| | if self.is_gqa: |
| | assert ( |
| | self.num_key_value_heads % config.tensor_parallel_degree == 0 |
| | ), f"num_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" |
| | self.num_key_value_heads = ( |
| | self.num_key_value_heads // config.tensor_parallel_degree |
| | ) |
| | q_hidden_size = self.head_dim * self.num_heads |
| | if self.is_gqa: |
| | logger.info( |
| | f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}" |
| | ) |
| | assert ( |
| | self.num_heads % self.num_key_value_heads == 0 |
| | ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}" |
| | kv_hidden_size = self.head_dim * self.num_key_value_heads |
| | else: |
| | kv_hidden_size = self.head_dim * self.num_heads |
| |
|
| | self.q_proj = nn.Linear(self.hidden_size, q_hidden_size, bias=config.use_bias) |
| | self.k_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias) |
| | self.v_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias) |
| |
|
| | self.o_proj = nn.Linear( |
| | self.hidden_size, |
| | self.hidden_size, |
| | bias=config.use_bias, |
| | ) |
| |
|
| | self.rotary_emb = RopeEmbedding( |
| | self.head_dim, |
| | compression_ratio=config.compression_ratio, |
| | base=config.rope_theta, |
| | freq_allocation=self.freq_allocation, |
| | ) |
| | self.config = config |
| | if self.config.use_flash_attention: |
| | self.attn_func = self._flash_attention_wrapper |
| | else: |
| | self.attn_func = self.core_attn |
| |
|
| | def forward( |
| | self, |
| | hidden_states, |
| | past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | attn_mask_start_row_indices: Optional[torch.Tensor] = None, |
| | position_ids: Optional[Tuple[torch.Tensor]] = None, |
| | output_attentions: bool = False, |
| | use_cache: bool = False, |
| | token_type_ids: Optional[Tuple[torch.Tensor]] = None, |
| | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| | """Compute attention outputs. |
| | |
| | Args: |
| | hidden_states (torch.Tensor): Input tensor [bsz, seq_len, hidden_size] |
| | past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached key/value states |
| | attention_mask (Optional[torch.Tensor]): Attention mask tensor |
| | attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices |
| | position_ids (Optional[torch.Tensor]): Position indices for RoPE |
| | output_attentions (bool): Return attention weights if True |
| | use_cache (bool): Cache key/value states if True |
| | token_type_ids (Optional[torch.Tensor]): Token type ids; the last position is dropped before use |
| | |
| | Returns: |
| | Tuple containing: |
| | - attention_output: [bsz, seq_len, hidden_size] |
| | - attention_weights: Optional attention probabilities |
| | - updated_key_value_cache: Optional updated cache |
| | """ |
| | if token_type_ids is not None: |
| | token_type_ids = token_type_ids[:, :-1] |
| |
|
| | bsz, q_len, _ = hidden_states.shape |
| | query_states = self.q_proj(hidden_states).reshape( |
| | [bsz, q_len, -1, self.head_dim] |
| | ) |
| | key_states = self.k_proj(hidden_states).reshape([bsz, q_len, -1, self.head_dim]) |
| | value_states = self.v_proj(hidden_states).reshape( |
| | [bsz, q_len, -1, self.head_dim] |
| | ) |
| |
|
| | attn_output, attn_weights, past_key_value = self.rope_attn( |
| | query_states=query_states, |
| | key_states=key_states, |
| | value_states=value_states, |
| | attention_mask=attention_mask, |
| | position_ids=position_ids, |
| | output_attentions=output_attentions, |
| | past_key_value=past_key_value, |
| | use_cache=use_cache, |
| | attn_mask_start_row_indices=attn_mask_start_row_indices, |
| | ) |
| | attn_output = self.o_proj(attn_output) |
| |
|
| | if not output_attentions: |
| | attn_weights = None |
| |
|
| | return attn_output, attn_weights, past_key_value |
| |
|
| | def repeat_kv(self, hidden_states, n_rep): |
| | """ |
| | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
| | """ |
| | batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
| | if n_rep == 1: |
| | return hidden_states |
| | hidden_states = hidden_states[:, :, None, :, :].expand( |
| | batch, num_key_value_heads, n_rep, slen, head_dim |
| | ) |
| | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
| |
|
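| | # Illustrative sketch (not part of the model): with 8 query heads and 2 KV |
| | # heads (GQA), each KV head is tiled 4x before the matmul; `attn` stands for |
| | # an already constructed Ernie4_5_Attention instance, and the sizes are made up. |
| | # kv = torch.randn(1, 2, 16, 64)           # [batch, kv_heads, seq, head_dim] |
| | # kv_rep = attn.repeat_kv(kv, n_rep=4)     # -> [1, 8, 16, 64] |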
| | def _flash_attention_wrapper( |
| | self, |
| | q, |
| | k, |
| | v, |
| | attention_mask=None, |
| | attn_mask_start_row_indices=None, |
| | seq_length=None, |
| | ): |
| | """Wrapper for flash attention implementation. |
| | Args: |
| | q (torch.Tensor): Query tensor |
| | k (torch.Tensor): Key tensor |
| | v (torch.Tensor): Value tensor |
| | attention_mask (Optional[torch.Tensor]): Attention mask |
| | attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices |
| | seq_length (Optional[int]): Sequence length |
| | Returns: |
| | Tuple[torch.Tensor, torch.Tensor]: Attention output and weights |
| | """ |
| | q = q.transpose(1, 2) |
| | k = k.transpose(1, 2) |
| | v = v.transpose(1, 2) |
| |
|
| | with sdpa_kernel(SDPBackend.FLASH_ATTENTION): |
| | out = F.scaled_dot_product_attention( |
| | q, |
| | k, |
| | v, |
| | attn_mask=None, |
| | dropout_p=self.config.attention_probs_dropout_prob, |
| | is_causal=q.shape[-2] == k.shape[-2], |
| | scale=1 |
| | / (getattr(self.config, "scale_qk_coeff", 1.0) * self.head_dim**0.5), |
| | enable_gqa=self.is_gqa, |
| | ) |
| | out = out.transpose(1, 2) |
| | out = out.contiguous().view(out.size(0), out.size(1), -1) |
| |
|
| | return out, None |
| |
|
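| | # Illustrative sketch (not part of the model): the wrapper above boils down to |
| | # PyTorch SDPA restricted to the flash-attention backend (enable_gqa needs a |
| | # recent PyTorch, roughly >= 2.5); shapes, dtypes and device are made up. |
| | # q = torch.randn(1, 8, 16, 64, dtype=torch.float16, device="cuda") |
| | # k = torch.randn(1, 2, 16, 64, dtype=torch.float16, device="cuda") |
| | # v = torch.randn(1, 2, 16, 64, dtype=torch.float16, device="cuda") |
| | # with sdpa_kernel(SDPBackend.FLASH_ATTENTION): |
| | #     out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True) |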
| | def core_attn( |
| | self, |
| | q, |
| | k, |
| | v, |
| | attention_mask=None, |
| | attn_mask_start_row_indices=None, |
| | seq_length=None, |
| | ): |
| | """Standard self-attention implementation. |
| | |
| | Args: |
| | q (torch.Tensor): Query tensor |
| | k (torch.Tensor): Key tensor |
| | v (torch.Tensor): Value tensor |
| | attention_mask (Optional[torch.Tensor]): Attention mask |
| | attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices |
| | seq_length (Optional[int]): Sequence length |
| | |
| | Returns: |
| | Tuple[torch.Tensor, torch.Tensor]: Attention output and weights |
| | """ |
| | origin_dtype = q.dtype |
| |
|
| | q = q.permute(0, 2, 1, 3) |
| | k = k.permute(0, 2, 1, 3) |
| | v = v.permute(0, 2, 1, 3) |
| |
|
| | scale_qk_coeff = getattr(self.config, "scale_qk_coeff", 1.0) * ( |
| | self.head_dim**0.5 |
| | ) |
| |
|
| | q = q / scale_qk_coeff |
| |
|
| | |
| | if self.is_gqa: |
| | |
| | repeat_factor = self.num_heads // self.num_key_value_heads |
| | k = self.repeat_kv(k, repeat_factor) |
| | v = self.repeat_kv(v, repeat_factor) |
| |
|
| | product = torch.matmul(q, k.transpose(-2, -1)) |
| |
|
| | product = product.to(torch.float32) |
| | if getattr(self.config, "scale_qk_coeff", 1.0) != 1.0: |
| | product = product * getattr(self.config, "scale_qk_coeff", 1.0) |
| |
|
| | seq_len = product.size(-1) |
| | mask = torch.triu( |
| | torch.ones((seq_len, seq_len), dtype=torch.bool, device=product.device), |
| | diagonal=1, |
| | ) |
| | product = product.masked_fill(mask, float("-inf")) |
| | weights = F.softmax(product, dim=-1) |
| |
|
| | weights = weights.to(origin_dtype) |
| |
|
| | if getattr(self.config, "attention_probs_dropout_prob", 0.0) > 0: |
| | weights = F.dropout( |
| | weights, |
| | self.config.attention_probs_dropout_prob, |
| | training=self.training, |
| | ) |
| |
|
| | out = torch.matmul(weights, v) |
| |
|
| | |
| | out = out.permute(0, 2, 1, 3) |
| | out = out.contiguous().view(out.size(0), out.size(1), -1) |
| |
|
| | return out, weights |
| |
|
| | def rope_attn( |
| | self, |
| | query_states, |
| | key_states, |
| | value_states, |
| | attention_mask, |
| | position_ids, |
| | output_attentions=False, |
| | past_key_value=None, |
| | use_cache=False, |
| | attn_mask_start_row_indices=None, |
| | ): |
| | """Attention computation with rotary embeddings. |
| | |
| | Args: |
| | query_states (torch.Tensor): Query states |
| | key_states (torch.Tensor): Key states |
| | value_states (torch.Tensor): Value states |
| | attention_mask (Optional[torch.Tensor]): Attention mask |
| | position_ids (Optional[torch.Tensor]): Position indices |
| | output_attentions (bool): Return attention weights |
| | past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached states |
| | use_cache (bool): Cache new states |
| | attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices |
| | |
| | Returns: |
| | Tuple containing: |
| | - attention_output: Result tensor |
| | - attention_weights: Optional weights |
| | - updated_key_value_cache: Optional cache |
| | """ |
| |
|
| | query_states_dtype = query_states.dtype |
| |
|
| | assert position_ids is not None, "3D RoPE requires position_ids" |
| | kv_seq_len = position_ids.max() + 1 |
| | offset = 0 |
| | if past_key_value is not None: |
| | offset = position_ids.max() |
| | kv_seq_len = position_ids.max() + 1 |
| | position_ids = position_ids[:, -1:, :] |
| |
|
| | cos_sin = self.rotary_emb(kv_seq_len).permute([0, 2, 1, 3]) |
| | if offset > 0 and position_ids is None: |
| | cos_sin = cos_sin[:, offset:] |
| | query_states, key_states = self.rotary_emb.apply_rotary_3d( |
| | cos_sin, query_states, key_states, position_ids |
| | ) |
| |
|
| | query_states = query_states.to(query_states_dtype) |
| | key_states = key_states.to(query_states_dtype) |
| | if past_key_value is not None: |
| | |
| | key_states = torch.cat([past_key_value[0], key_states], dim=1) |
| | value_states = torch.cat([past_key_value[1], value_states], dim=1) |
| |
|
| | |
| | past_key_value = [key_states, value_states] if use_cache else None |
| | seq_length = query_states.shape[1] |
| | attn_output, attn_weights = self.attn_func( |
| | query_states, |
| | key_states, |
| | value_states, |
| | attention_mask, |
| | attn_mask_start_row_indices, |
| | seq_length, |
| | ) |
| |
|
| | return attn_output, attn_weights, past_key_value |
| |
|
| |
|
| | class FusedDropoutImpl(nn.Module): |
| | """ |
| | Fused dropout implementation with residual connection support. |
| | |
| | This layer combines dropout and residual addition in a single operation for better performance, |
| | particularly on GPU devices. The dropout is conditionally applied based on the probability. |
| | |
| | Args: |
| | prob (float): Dropout probability (between 0 and 1) |
| | mode (str): Dropout mode, either 'upscale_in_train' or 'downscale_in_infer'; |
| | accepted for API compatibility but not used by this implementation |
| | |
| | Attributes: |
| | prob (float): Stores the dropout probability |
| | dropout (nn.Dropout): The actual dropout layer instance |
| | """ |
| |
|
| | def __init__(self, prob, mode): |
| | """ |
| | Initialize the fused dropout layer. |
| | |
| | Args: |
| | prob (float): Dropout probability (0 means no dropout) |
| | mode (str): Dropout mode ('upscale_in_train' or 'downscale_in_infer') |
| | """ |
| | super().__init__() |
| | self.prob = prob |
| | self.dropout = nn.Dropout(p=prob) |
| |
|
| | def forward(self, x, y): |
| | """ |
| | Forward pass of the fused dropout layer. |
| | |
| | Args: |
| | x (Tensor): Input tensor to potentially apply dropout on |
| | y (Tensor): Residual tensor to add to the (possibly dropped out) x |
| | |
| | Returns: |
| | Tensor: Result of x (with optional dropout) + y |
| | """ |
| | if self.prob > 0: |
| | x = self.dropout(x) |
| | output = x + y |
| |
|
| | return output |
| |
|
| |
|
| | class RMSNorm(nn.Module): |
| | """ |
| | Root Mean Square Layer Normalization (RMSNorm) implementation. |
| | |
| | RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs, |
| | omitting the mean-centering operation. This provides computational efficiency while maintaining |
| | good performance. |
| | |
| | """ |
| |
|
| | def __init__(self, config): |
| | """ |
| | Initialize RMSNorm layer. |
| | |
| | Args: |
| | config (Ernie4_5_Config): Model configuration. |
| | """ |
| | super().__init__() |
| | self.hidden_size = config.hidden_size |
| | self.weight = nn.Parameter( |
| | torch.ones(self.hidden_size, dtype=torch.get_default_dtype()) |
| | ) |
| | self.variance_epsilon = config.rms_norm_eps |
| |
|
| | def forward(self, hidden_states): |
| | """ |
| | Apply RMS normalization to input hidden states. |
| | |
| | Args: |
| | hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] |
| | |
| | Returns: |
| | Tensor: Normalized output tensor of same shape as input |
| | |
| | Note: |
| | - computes RMSNorm manually: |
| | 1. Compute variance of features |
| | 2. Apply reciprocal square root normalization |
| | 3. Scale by learned weight parameter |
| | - Computes the variance in float32 for numerical stability, then casts the result back to the weight dtype |
| | """ |
| | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) |
| | hidden_states = torch.rsqrt(variance + self.variance_epsilon) * hidden_states |
| | return hidden_states.to(self.weight.dtype) * self.weight |
| |
|
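| | # Illustrative sketch (not part of the model): for a single vector x the layer |
| | # computes x / sqrt(mean(x**2) + eps) * weight; a tiny numeric example: |
| | # x = torch.tensor([3.0, 4.0]) |
| | # x * torch.rsqrt(x.pow(2).mean() + 1e-6)   # ~ tensor([0.8485, 1.1314]) |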
| |
|
| | class Ernie4_5_MoeMLP(Ernie4_5_MLP): |
| | """Mixture of Experts (MoE) variant of ERNIE's MLP layer.""" |
| |
|
| | def __init__(self, config, layer_idx=0): |
| | """Initialize the MoE MLP layer. |
| | |
| | Args: |
| | config (Ernie4_5_MoEConfig): Configuration for MoE architecture. |
| | layer_idx (int): Index of current layer in transformer stack |
| | """ |
| |
|
| | if getattr(config, "disable_ffn_model_parallel", False): |
| | config = deepcopy(config) |
| | config.tensor_parallel_degree = 1 |
| |
|
| | super().__init__(config, layer_idx=layer_idx) |
| | self.moe_dropout_prob = config.moe_dropout_prob |
| |
|
| | def forward(self, x): |
| | """Forward pass through MoE MLP layer. |
| | |
| | Args: |
| | x (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] |
| | or [seq_len, hidden_size] |
| | |
| | Returns: |
| | torch.Tensor: Output tensor with same shape as input |
| | """ |
| | current_device = self.gate_proj.weight.data.device |
| | x = x.to(current_device) |
| | x = F.silu(self.gate_proj(x)) * self.up_proj(x) |
| | if self.moe_dropout_prob > 0: |
| | x = F.dropout(input=x, p=self.moe_dropout_prob) |
| | ret = self.down_proj(x) |
| | return ret |
| |
|
| |
|
| | def masked_fill(x, mask, value): |
| | """ |
| | Fills elements of the input tensor with a given value where mask is True. |
| | """ |
| | return torch.where(mask, torch.full_like(x, value), x) |
| |
|
| |
|
| | def _squared_l2_norm(x: torch.Tensor) -> torch.Tensor: |
| | """Computes 0.5 * sum(x^2)""" |
| | return 0.5 * torch.sum(x * x) |
| |
|
| |
|
| | @torch.no_grad() |
| | def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): |
| | """ |
| | Computes an approximate optimal transport plan using the Sinkhorn-Knopp algorithm. |
| | Returns the transport matrix P (with NaNs replaced by zeros) and the index of the last iteration performed. |
| | """ |
| | n, _ = M.shape |
| | P = F.softmax(-M / lam, dim=1) |
| | u = torch.zeros(n, dtype=torch.float32, device=M.device) |
| |
|
| | for _ in range(max_iters): |
| | P_sum_1 = P.sum(1) |
| | if (u - P_sum_1).abs().max() < epsilon: |
| | break |
| | u = P_sum_1 |
| | P *= (r / (u + 1e-8)).unsqueeze(1) |
| | P *= (c / (P.sum(0) + 1e-8)).unsqueeze(0) |
| |
|
| | P = torch.where(~P.isnan(), P, torch.zeros_like(P)) |
| | return P, _ |
| |
|
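| | # Illustrative sketch (not part of the model): balancing 4 tokens over 2 |
| | # experts with the Sinkhorn-Knopp routine above; M, r, c are made-up toy data. |
| | # M = torch.rand(4, 2)                 # per-token / per-expert cost |
| | # r = torch.full((4,), 1.0 / 4)        # token marginals |
| | # c = torch.full((2,), 1.0 / 2)        # expert marginals |
| | # P, _ = compute_optimal_transport(M, r, c, lam=0.5) |
| | # After the iterations P.sum(1) is approximately r and P.sum(0) approximately c. |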
| |
|
| | class Top2Gate(nn.Module): |
| | """ |
| | Gate module implementing Top2Gating as described in Gshard paper. |
| | """ |
| |
|
| | def __init__(self, config, layer_idx: int, group=None, gate_weight=None) -> None: |
| | """ |
| | Initialize the MoE (Mixture of Experts) layer. |
| | |
| | Args: |
| | config: Model configuration containing MoE parameters |
| | layer_idx: Index of this layer in the model |
| | group: Distributed communication group |
| | gate_weight: Optional pre-existing gate weight tensor |
| | """ |
| | super().__init__() |
| | self.config = config |
| |
|
| | self.model_dim = config.hidden_size |
| | self.num_experts = config.moe_num_experts |
| | self.num_experts_tensor = ( |
| | sum(config.moe_num_experts) |
| | if config.multimodel_experts |
| | else config.moe_num_experts |
| | ) |
| |
|
| | self.cap = config.moe_capacity |
| | self.group = group |
| |
|
| | self.layer_idx = layer_idx |
| |
|
| | self.sinkhorn_2gate = config.sinkhorn_2gate |
| | self.sinkhorn_temp = config.sinkhorn_temp |
| | self.use_correction_bias = config.moe_use_aux_free |
| | self.use_token_type_bias = config.get("moe_use_token_type_bias", False) |
| |
|
| | self.act = partial(F.softmax, dim=-1) |
| |
|
| | self.no_jitter = True |
| | self.expert_drop = False |
| | self.eye_matrix = None |
| | self.eye_matrix_size = None |
| | self.norm_gate_logits = config.moe_norm_gate_logits |
| | self.one = torch.ones([], dtype=torch.float32) |
| |
|
| | self.moe_aux_loss_lambda = torch.tensor(config.moe_aux_loss_lambda).to( |
| | dtype=torch.float32 |
| | ) |
| | self.moe_z_loss_lambda = torch.tensor(config.moe_z_loss_lambda).to( |
| | dtype=torch.float32 |
| | ) |
| | self.moe_orthogonal_loss_lambda = torch.tensor( |
| | config.moe_orthogonal_loss_lambda |
| | ).to(dtype=torch.float32) |
| |
|
| | if self.moe_aux_loss_lambda.ndim == 0: |
| | self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) |
| | if self.moe_z_loss_lambda.ndim == 0: |
| | self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0) |
| | if self.moe_orthogonal_loss_lambda.ndim == 0: |
| | self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze( |
| | 0 |
| | ) |
| |
|
| | self.experts_type_ids = None |
| |
|
| | self.eps = torch.tensor([1e-12]).to(dtype=torch.float32) |
| | if config.multimodel_experts: |
| | if config.get("moe_use_hard_gate", False): |
| | self.num_experts_list = [] |
| | self.experts_type_mask = [] |
| | |
| | experts_ids = torch.zeros( |
| | [sum(self.num_experts)], dtype=torch.int64 |
| | ).reshape((1, -1)) |
| | offset = 0 |
| | for i, expert_num in enumerate(self.num_experts): |
| | experts_ids[:, offset : offset + expert_num] = i |
| | offset += expert_num |
| | self.experts_type_ids = experts_ids.reshape([-1]) |
| | logger.info( |
| | f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}" |
| | ) |
| | for i, expert_num in enumerate(self.num_experts): |
| | self.experts_type_mask.append( |
| | self.experts_type_ids == i, |
| | ) |
| | self.num_experts_list.append(expert_num) |
| | else: |
| | |
| | assert ( |
| | not config.moe_group_experts |
| | ), "group_experts must use hard_gate when multimodel_experts is True" |
| | else: |
| | self.num_experts_list = [self.num_experts] |
| |
|
| | if gate_weight is not None: |
| | self.weight = gate_weight |
| |
|
| | assert ( |
| | not self.config.moe_use_token_type_bias |
| | ), "gate_weights is from outside, token_type_bias can't be used" |
| | logger.info("moe use gate_weight from outside") |
| | |
| | self._cast_to_low_precision = False |
| | self._cast_to_low_precison = False |
| | else: |
| | self._create_gate_parameter() |
| | logger.info( |
| | f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " |
| | f"use_token_type_bias:{self.use_token_type_bias} " |
| | f"gate_act:{config.moe_gate_act} " |
| | f"norm_gate_logits={self.norm_gate_logits} use_correction_bias={self.use_correction_bias}" |
| | ) |
| |
|
| | def _create_gate_parameter(self): |
| | """ |
| | Create gate weight parameter. |
| | """ |
| | if self.config.multimodel_experts: |
| | |
| | self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand( |
| | len(self.num_experts) |
| | ) |
| | self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand( |
| | len(self.num_experts) |
| | ) |
| | self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand( |
| | len(self.num_experts) |
| | ) |
| |
|
| | for i, num_experts in enumerate(self.num_experts): |
| | if i == 1: |
| | with UniqueNameGuard(f"mm_gate_{self.layer_idx}_"): |
| | p = nn.Parameter( |
| | torch.empty( |
| | self.model_dim, |
| | num_experts, |
| | dtype=torch.float32, |
| | device="cpu", |
| | ) |
| | ) |
| | nn.init.xavier_uniform_(p) |
| | else: |
| | p = nn.Parameter( |
| | torch.empty( |
| | self.model_dim, |
| | num_experts, |
| | dtype=torch.float32, |
| | device="cpu", |
| | ) |
| | ) |
| | nn.init.xavier_uniform_(p) |
| | self.register_parameter( |
| | "weight" if i == 0 else f"weight_{i}", |
| | p, |
| | ) |
| | else: |
| | self.weight = nn.Parameter( |
| | torch.empty(self.model_dim, self.num_experts, dtype=torch.float32) |
| | ) |
| | nn.init.xavier_uniform_(self.weight) |
| | |
| | self._cast_to_low_precision = False |
| | self._cast_to_low_precison = False |
| |
|
| | def get_gate_weight(self, transform_weight, is_multimodel=True): |
| | """ |
| | When `multimodel_experts` is enabled, merge the per-modality gate weights into a single tensor. |
| | transform_weight: bool, interleave the multimodal weights according to local-expert id. |
| | """ |
| | if not is_multimodel or not self.config.multimodel_experts: |
| | return self.weight |
| | else: |
| | return torch.cat( |
| | [ |
| | getattr(self, "weight" if i == 0 else f"weight_{i}") |
| | for i in range(len(self.num_experts)) |
| | ], |
| | -1, |
| | ) |
| |
|
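| | # Illustrative sketch (not part of the model): with moe_num_experts = (64, 64), |
| | # the group-0 gate weight [hidden, 64] and the group-1 weight_1 [hidden, 64] are |
| | # concatenated into [hidden, 128] so one matmul scores every expert; `gate` and |
| | # `x` below stand for a constructed gate instance and its input, made up here. |
| | # w = gate.get_gate_weight(transform_weight=True)   # [hidden_size, total_experts] |
| | # logits = F.linear(x.float(), w.T.float())         # [num_tokens, total_experts] |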
| | def forward( |
| | self, |
| | input: torch.Tensor, |
| | token_type_ids: torch.Tensor = None, |
| | transform_weight: bool = True, |
| | correction_bias: torch.Tensor = None, |
| | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| | """ |
| | Forward pass through the gate. |
| | |
| | Args: |
| | input: Input tensor of shape [Seq, Dim] |
| | token_type_ids: Token type IDs tensor of shape [Seq] |
| | transform_weight: Whether to transform weights for multimodal experts |
| | correction_bias: Bias tensor for correction |
| | |
| | Returns: |
| | tuple: (capacity, dispatch_mask, combine_weights, scatter_index, router_loss, logits) |
| | """ |
| | orig_dtype = input.dtype |
| | current_device = input.device |
| | weight = self.get_gate_weight(transform_weight) |
| |
|
| | logits = F.linear( |
| | input.to(dtype=torch.float32, device=current_device), |
| | weight.T.to(dtype=torch.float32, device=current_device), |
| | ) |
| |
|
| | ( |
| | capacity, |
| | dispatch_mask, |
| | combine_weights, |
| | scatter_index, |
| | l_aux, |
| | l_zloss, |
| | ) = self.top2_gating( |
| | logits, |
| | correction_bias=( |
| | correction_bias.to(device=current_device) |
| | if correction_bias is not None |
| | else None |
| | ), |
| | ) |
| |
|
| | combine_weights = combine_weights.to(orig_dtype) |
| | return capacity, dispatch_mask, combine_weights, scatter_index, None, logits |
| |
|
| | def get_capacity(self, num_tokens, cap_factor=None, is_multimodel=True): |
| | """ |
| | Calculate capacity based on number of tokens. |
| | |
| | Args: |
| | num_tokens: Number of input tokens |
| | cap_factor: Optional capacity factor override |
| | |
| | Returns: |
| | int: Calculated capacity |
| | """ |
| | if is_multimodel and self.config.multimodel_experts: |
| | num_experts = sum(self.num_experts_list) |
| | elif isinstance(self.num_experts, (list, tuple)): |
| | num_experts = self.num_experts[0] |
| | else: |
| | num_experts = self.num_experts |
| | if cap_factor is not None: |
| | cap = cap_factor |
| | else: |
| | if self.training: |
| | cap = self.cap[0] |
| | elif num_tokens < num_experts: |
| | cap = self.cap[2] |
| | else: |
| | cap = self.cap[1] |
| | |
| | capacity = int(cap * num_tokens // num_experts) |
| | assert ( |
| | capacity > 0 |
| | ), f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}" |
| | return capacity |
| |
|
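| | # Illustrative sketch (not part of the model): the capacity formula above is |
| | # int(cap_factor * num_tokens // num_experts); e.g. with 1024 tokens, 64 |
| | # experts and a capacity factor of 1.2: int(1.2 * 1024 // 64) == 19 slots per expert. |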
| | def top2_gating(self, logits, cap=None, correction_bias=None): |
| | """ |
| | Implement Top2 gating mechanism. |
| | |
| | Args: |
| | logits: Input logits tensor |
| | cap: Optional capacity override |
| | correction_bias: Bias tensor for correction |
| | |
| | Returns: |
| | tuple: (capacity, dispatch_masks, combine_weights, scatter_indexes, loss_aux, loss_z) |
| | |
| | Note: |
| | capacity: The maximum number that each token can be dispatched. |
| | dispatch_masks: Masks used for dispatching. The first element is the mask for the first |
| | type of tokens; the second element is the mask for the second type of tokens. |
| | combine_weights: Weights used for combining. The first element is the weight for the first |
| | type of tokens; the second element is the weight for the second type of tokens. |
| | scatter_indexes: Indexes used for scattering. The first element is the index for the first |
| | type of tokens; the second element is the index for the second type of tokens. |
| | loss_aux: Auxiliary loss. |
| | loss_z: Z loss. |
| | """ |
| | gates = self.act(logits) |
| |
|
| | |
| | assert logits.ndim == 2, logits.shape |
| | num_tokens = gates.shape[0] |
| | num_experts = gates.shape[1] |
| | |
| | capacity = self.get_capacity(logits.shape[0], cap) |
| | current_device = logits.device |
| |
|
| | |
| | score_for_argmax = ( |
| | gates + correction_bias.unsqueeze(0) |
| | if correction_bias is not None |
| | else gates |
| | ) |
| | indices1_s = torch.argmax(score_for_argmax, dim=1) |
| | mask1 = F.one_hot(indices1_s, num_classes=num_experts).to( |
| | dtype=torch.int64, device=current_device |
| | ) |
| |
|
| | |
| | |
| | if self.training and not self.no_jitter: |
| | gumbels = ( |
| | -torch.empty_like( |
| | logits, |
| | device=current_device, |
| | ) |
| | .exponential_() |
| | .log() |
| | ) |
| | logits_w_noise = logits + gumbels |
| | else: |
| | logits_w_noise = logits |
| |
|
| | logits_except1 = masked_fill( |
| | logits_w_noise, |
| | mask1.to(dtype=torch.bool, device=current_device), |
| | float("-inf"), |
| | ) |
| | score_for_argmax = ( |
| | self.act(logits_except1) + correction_bias.unsqueeze(0) |
| | if correction_bias is not None |
| | else logits_except1 |
| | ) |
| | indices2_s_original = torch.argmax(score_for_argmax, dim=1) |
| |
|
| | if self.training and self.sinkhorn_2gate: |
| | r = ( |
| | torch.ones(num_tokens, dtype=torch.float32, device=current_device) |
| | / num_tokens |
| | ) |
| | c_mask_sum = mask1.to(dtype=torch.float32, device=current_device).sum(0) |
| | c = capacity - c_mask_sum |
| | c = torch.maximum(c, torch.zeros_like(c, device=current_device)) |
| | c_sum = c.sum() |
| | if c_sum > 0: |
| | c = c / c_sum |
| | else: |
| | c = torch.ones_like(c, device=current_device) / num_experts |
| |
|
| | pi, _ = compute_optimal_transport( |
| | -logits_except1.to(dtype=torch.float32, device=current_device).detach(), |
| | r, |
| | c, |
| | lam=self.sinkhorn_temp, |
| | ) |
| | pi = masked_fill( |
| | pi, mask1.to(dtype=torch.bool, device=current_device), float("-inf") |
| | ) |
| | indices2_s = torch.argmax(pi, dim=1) |
| | else: |
| | indices2_s = indices2_s_original |
| |
|
| | mask2 = F.one_hot(indices2_s, num_classes=num_experts).to( |
| | dtype=torch.int64, device=current_device |
| | ) |
| |
|
| | |
| | locations1 = ( |
| | torch.cumsum(mask1, dim=0) - 1 |
| | ) |
| | locations2 = torch.cumsum(mask2, dim=0) - 1 |
| | |
| | locations2 += torch.sum(mask1, dim=0, keepdim=True) |
| |
|
| | |
| | mask1 = mask1 * (locations1 < capacity).to( |
| | dtype=torch.int64, device=current_device |
| | ) |
| | mask2 = mask2 * (locations2 < capacity).to( |
| | dtype=torch.int64, device=current_device |
| | ) |
| |
|
| | |
| | locations1_s = torch.sum(locations1 * mask1, dim=1) |
| | locations2_s = torch.sum(locations2 * mask2, dim=1) |
| |
|
| | |
| | mask1_float = mask1.to(dtype=torch.float32, device=current_device) |
| | mask2_float = mask2.to(dtype=torch.float32, device=current_device) |
| | gates1_s = (gates * mask1_float).sum(dim=-1) |
| | gates2_s = (gates * mask2_float).sum(dim=-1) |
| | |
| |
|
| | if self.norm_gate_logits: |
| | denom_s = gates1_s + gates2_s |
| | |
| | denom_s = torch.clamp(denom_s, min=1e-6) |
| | gates1_s /= denom_s |
| | gates2_s /= denom_s |
| | if self.training and self.expert_drop: |
| | |
| | gates2_s = torch.where( |
| | 2 * gates2_s < torch.rand_like(gates2_s, device=current_device), |
| | torch.zeros_like(gates2_s, device=current_device), |
| | gates2_s, |
| | ) |
| |
|
| | |
| | gates1 = gates1_s.unsqueeze(1) * mask1_float |
| | gates2 = gates2_s.unsqueeze(1) * mask2_float |
| |
|
| | combine1_weight, expert1_index = torch.max(gates1, dim=-1, keepdim=True) |
| | scatter1_index = expert1_index.squeeze(-1) * capacity + locations1_s |
| | scatter1_index = scatter1_index.to(dtype=torch.int64, device=current_device) |
| | dispatch1_mask = combine1_weight.to( |
| | dtype=torch.bool, device=current_device |
| | ).detach() |
| |
|
| | combine2_weight, expert2_index = torch.max(gates2, dim=-1, keepdim=True) |
| | scatter2_index = expert2_index.squeeze(-1) * capacity + locations2_s |
| | scatter2_index = scatter2_index.to(dtype=torch.int64, device=current_device) |
| | dispatch2_mask = combine2_weight.to( |
| | dtype=torch.bool, device=current_device |
| | ).detach() |
| | |
| |
|
| | return ( |
| | capacity, |
| | torch.cat((dispatch1_mask, dispatch2_mask), 1), |
| | torch.cat((combine1_weight, combine2_weight), 1), |
| | torch.stack((scatter1_index, scatter2_index), 1), |
| | None, |
| | None, |
| | ) |
| |
|
| | def _cal_orthogonal_loss_opt_each_weight(self, weight, use_group): |
| | """ |
| | Calculate optimized orthogonal loss for each weight. |
| | |
| | Args: |
| | weight: Weight tensor |
| | use_group: Whether to use expert groups |
| | |
| | Returns: |
| | Tensor: Calculated orthogonal loss |
| | """ |
| | if weight.dtype != torch.float32: |
| | weight = weight.to(torch.float32) |
| |
|
| | wnorm = torch.norm(weight, p=2, dim=1) |
| | weight = weight / torch.maximum(wnorm, self.eps.to(weight.device)).unsqueeze(1) |
| |
|
| | if use_group: |
| | weight = weight.reshape( |
| | [self.config.moe_k, -1, weight.shape[1]] |
| | ) |
| | eye_matrix = torch.eye( |
| | weight.shape[1], dtype=weight.dtype, device=weight.device |
| | ).unsqueeze(0) |
| | else: |
| | eye_matrix = torch.eye( |
| | weight.shape[0], dtype=weight.dtype, device=weight.device |
| | ) |
| |
|
| | weight_matmul = torch.matmul(weight, weight.T) |
| |
|
| | orthogonal_loss = weight_matmul - eye_matrix |
| | orthogonal_loss = _squared_l2_norm(orthogonal_loss) / ( |
| | orthogonal_loss.size(0) * orthogonal_loss.size(1) |
| | ) |
| | return orthogonal_loss |
| |
|
| |
|
| | class TopKGate(Top2Gate): |
| | """ |
| | Fused version of TopK gate for improved performance. |
| | """ |
| |
|
| | def forward( |
| | self, |
| | input: torch.Tensor, |
| | token_type_ids=None, |
| | transform_weight=True, |
| | is_multimodel=True, |
| | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| | """ |
| | Forward pass for fused gate. |
| | |
| | Args: |
| | input: Input tensor |
| | token_type_ids: Token type IDs |
| | transform_weight: Whether to transform weights |
| | |
| | Returns: |
| | Tensor: Routing logits of shape [num_tokens, num_experts] |
| | """ |
| | current_device = input.device |
| | weight = self.get_gate_weight(transform_weight, is_multimodel=is_multimodel) |
| |
|
| | logits = F.linear( |
| | input.to(dtype=torch.float32, device=current_device), |
| | weight.T.to(dtype=torch.float32, device=current_device), |
| | ) |
| | if self.use_token_type_bias: |
| | assert token_type_ids is not None |
| | assert ( |
| | token_type_ids.max() < self.bias.shape[0] |
| | ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" |
| | bias = self.bias[token_type_ids] |
| | logits = logits + bias |
| |
|
| | return logits |
| |
|
| |
|
| | gate_class = dict( |
| | top2=Top2Gate, |
| | topk=TopKGate, |
| | ) |
| |
|
| |
|
| | def get_gate( |
| | config: Ernie4_5_MoEConfig, |
| | expert: nn.Module, |
| | layer_idx: int, |
| | ) -> Tuple[nn.Module, nn.ModuleList]: |
| | """Initialize and distribute MoE (Mixture of Experts) components. |
| | |
| | Creates gate layer and distributed expert network for MoE architecture. |
| | |
| | Args: |
| | config (Ernie4_5_MoEConfig): Configuration for MoE architecture |
| | expert (nn.Module): Prototype expert network to be replicated |
| | layer_idx (int): Index of current layer in transformer stack |
| | |
| | Returns: |
| | Tuple[nn.Module, nn.ModuleList]: |
| | - gate: Initialized gate layer for routing |
| | - experts: ModuleList containing expert networks |
| | """ |
| | moe_num_experts = ( |
| | sum(config.moe_num_experts) |
| | if config.multimodel_experts |
| | else config.moe_num_experts |
| | ) |
| | experts = nn.ModuleList([]) |
| |
|
| | for expert_id, (experts_num, fc) in enumerate(expert): |
| | experts_to_append = [] |
| | if not hasattr(fc, "__len__"): |
| | experts_to_append.append(fc) |
| | if expert_id == 1: |
| | with UniqueNameGuard("_mm_deepcopy"): |
| | for _ in range(experts_num - 1): |
| | experts_to_append.append(deepcopy(fc)) |
| | else: |
| | for _ in range(experts_num - 1): |
| | experts_to_append.append(deepcopy(fc)) |
| | else: |
| | experts_to_append = fc |
| | for ex in experts_to_append: |
| | for p in ex.parameters(): |
| | p.expert_type = f"expert_type_{expert_id}" |
| | index = 0 |
| | for i in range(experts_num): |
| | if i // experts_num == 0: |
| | experts.append(experts_to_append[index]) |
| | index += 1 |
| | else: |
| | experts.append(None) |
| |
|
| | assert ( |
| | len(experts) == moe_num_experts |
| | ), f"experts.len={len(experts)} != experts_num={experts_num}" |
| | logger.info(f"MOE-GATE:-{config.moe_gate}") |
| |
|
| | gate = gate_class[config.moe_gate.lower()](config, layer_idx=layer_idx) |
| |
|
| | if config.multimodel_experts and config.moe_use_hard_gate and moe_num_experts > 2: |
| | lm_experts = experts[: config.moe_num_experts[0]] |
| | lm_gate = gate |
| | else: |
| | if config.multimodel_experts and config.moe_use_hard_gate: |
| | lm_gate, lm_experts = gate, experts |
| | else: |
| | lm_gate, lm_experts = None, None |
| |
|
| | logger.info(f"LM-experts-{lm_experts} -- experts-{experts}") |
| |
|
| | return gate, experts, lm_gate, lm_experts |
| |
|
| |
|
| | class MoEStatics(nn.Module): |
| | """ |
| | Stores MoE (Mixture of Experts) statistics |
| | and expert usage information. |
| | """ |
| |
|
| | def __init__(self, config, layer_idx): |
| | """ |
| | Initialize MoE statistics tracking. |
| | |
| | Args: |
| | config: Model configuration containing MoE parameters |
| | layer_idx: Index of the MoE layer in the model |
| | """ |
| | super().__init__() |
| | self._cast_to_low_precision = False |
| | self._cast_to_low_precison = False |
| | num_experts = ( |
| | config.moe_num_experts[0] |
| | if config.multimodel_experts |
| | else config.moe_num_experts |
| | ) |
| | if config.multimodel_experts: |
| | assert ( |
| | len(set(config.moe_num_experts)) == 1 |
| | ), "assume expert group has same size, got: {config.moe_num_experts}" |
| |
|
| | with UniqueNameGuard(f"mm_layer_{layer_idx}_"): |
| | num_experts_groups = ( |
| | len(config.moe_num_experts) if config.multimodel_experts else 1 |
| | ) |
| | p = nn.Parameter( |
| | torch.zeros(num_experts_groups, num_experts, dtype=torch.float32), |
| | requires_grad=False, |
| | ) |
| | self.e_score_correction_bias = p |
| | p = torch.zeros(num_experts_groups, num_experts, dtype=torch.int64) |
| | self.expert_usage = p |
| |
|
| |
|
| | def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): |
| | """ |
| | Reorders input tensor based on gate results with capacity truncation and padding. |
| | |
| | Args: |
| | x (Tensor): Input tensor of shape [Seq, Dim] |
| | dispatch_mask (Tensor): Dispatching mask of shape [Seq, 2] |
| | scatter_index (Tensor): Scatter indices of shape [Seq, 2] |
| | num_experts (int): Number of experts |
| | capacity (int): Capacity per expert |
| | |
| | Returns: |
| | Tensor: Dispatched output tensor of shape [Expert*Capacity, Dim] |
| | """ |
| | output = None |
| | orig_dtype = x.dtype |
| | scatter_index_unbound = [scatter_index[:, 0], scatter_index[:, 1]] |
| | dispatch_mask_unbound = [dispatch_mask[:, 0], dispatch_mask[:, 1]] |
| |
|
| | for i_scatter_index, i_dispatch_mask in zip( |
| | scatter_index_unbound, dispatch_mask_unbound |
| | ): |
| | updates = x * i_dispatch_mask.unsqueeze(-1).to(orig_dtype) |
| | init_output = torch.zeros( |
| | num_experts * capacity, x.shape[-1], dtype=orig_dtype, device=x.device |
| | ) |
| |
|
| | index = i_scatter_index.unsqueeze(-1).expand(-1, x.shape[-1]) |
| | if output is None: |
| | output = init_output.scatter_add(0, index, updates) |
| | else: |
| | output = output + init_output.scatter_add(0, index, updates) |
| | if output.dtype != orig_dtype: |
| | output = output.to(orig_dtype) |
| | return output |
| |
|
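| | # Illustrative sketch (not part of the model): a toy dispatch with 4 tokens, |
| | # 2 experts and capacity 2, keeping only each token's first route; all values |
| | # below are made up. |
| | # x = torch.randn(4, 8) |
| | # scatter_index = torch.tensor([[0, 2], [1, 3], [2, 0], [3, 1]]) |
| | # dispatch_mask = torch.tensor([[True, False]] * 4) |
| | # out = dispatching(x, dispatch_mask, scatter_index, num_experts=2, capacity=2) |
| | # out has shape [4, 8]; row e * capacity + slot holds the token routed there. |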
| |
|
| | def combining(x, combine_weights, scatter_index): |
| | """ |
| | Combines and aggregates input matrix using combination weights. |
| | |
| | Args: |
| | x (Tensor): Input tensor of shape [num_experts * capacity, dim] |
| | combine_weights (Tensor): Combination weights of shape [seq, 2] |
| | scatter_index (Tensor): Scatter indices of shape [seq, 2] |
| | |
| | Returns: |
| | Tensor: Combined output tensor of shape [seq, dim] |
| | """ |
| | dim = x.shape[-1] |
| |
|
| | current_device = scatter_index.device |
| | x = x.to(current_device) |
| | scatter_index = scatter_index.reshape([-1]) |
| | num_k = combine_weights.shape[-1] |
| |
|
| | combine_weights = combine_weights.unsqueeze(1).to(current_device) |
| |
|
| | x = x[scatter_index].reshape([-1, num_k, dim]) |
| |
|
| | return torch.matmul(combine_weights, x).squeeze( |
| | 1 |
| | ) |
| |
|
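| | # Illustrative sketch (not part of the model): undoing the toy dispatch above, |
| | # each token gathers its two expert rows and mixes them with its gate weights. |
| | # expert_out = torch.randn(4, 8)                      # [num_experts * capacity, dim] |
| | # combine_weights = torch.tensor([[0.7, 0.3]] * 4)    # top-2 weights per token |
| | # scatter_index = torch.tensor([[0, 2], [1, 3], [2, 0], [3, 1]]) |
| | # y = combining(expert_out, combine_weights, scatter_index)   # -> [4, 8] |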
| |
|
| | class MOELayer(nn.Module): |
| | """ |
| | Mixture of Experts layer implementation based on GShard paper. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | gate: nn.Module, |
| | experts: List[nn.Module], |
| | layer_idx: int, |
| | shared_experts: Optional[List[nn.Module]] = None, |
| | group=None, |
| | recompute: bool = False, |
| | k: int = 2, |
| | all_to_all_dropout: float = 0, |
| | group_experts: bool = False, |
| | moe_statics=None, |
| | moe_num_experts=None, |
| | ): |
| | """ |
| | Initialize MoE layer. |
| | |
| | Args: |
| | gate: Gate network for expert selection |
| | experts: List of expert networks |
| | layer_idx: Index of this layer in the model |
| | shared_experts: Optional shared expert networks applied to every token |
| | group: Distributed communication group |
| | recompute: Whether to enable recomputation |
| | k: Number of experts to select per token |
| | all_to_all_dropout: Dropout rate for all-to-all communication |
| | group_experts: Whether to group experts |
| | moe_statics: MoE statistics tracking object |
| | moe_num_experts: Total number of experts (int, or a per-modality tuple/list) |
| | """ |
| | super().__init__() |
| | self.gate = gate |
| | self.layer_idx = layer_idx |
| |
|
| | if isinstance(experts, nn.ModuleList): |
| | self.experts = experts |
| | else: |
| | logger.info(f"using fused experts, type={type(experts)}") |
| | self.experts = experts |
| | self.shared_experts = shared_experts |
| |
|
| | self.group = group |
| | self.k = k |
| | self.all_to_all_dropout = all_to_all_dropout |
| | self.use_correction_bias = moe_statics is not None |
| | self.moe_statics = moe_statics |
| | if self.use_correction_bias: |
| | logger.info( |
| | f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}" |
| | ) |
| | assert self.gate.config.moe_use_aux_free |
| |
|
| | self.world_size = 1 |
| | self.rank = 0 |
| |
|
| | self.multimodal_experts = ( |
| | isinstance(moe_num_experts, (tuple, list)) and len(moe_num_experts) > 1 |
| | ) |
| | self.num_local_experts = len(self.experts) // self.world_size |
| | if self.multimodal_experts: |
| | self.num_local_multimodal_experts = [ |
| | num // self.world_size for num in moe_num_experts |
| | ] |
| | self.multimodal_expert_index = [0] + list( |
| | itertools.accumulate(moe_num_experts) |
| | ) |
| |
|
| | self.input_preprocess = self.output_postprocess = None |
| | self.group_experts = group_experts |
| | self.config = self.gate.config |
| | self.zero = torch.tensor(0).to(dtype=torch.float32) |
| |
|
| | def forward_experts(self, dispatched_input): |
| | """ |
| | Forward pass through experts sequentially. |
| | |
| | Args: |
| | dispatched_input: Input tensor of shape [num_experts, capacity, dim] |
| | |
| | Returns: |
| | Tensor: Expert outputs of shape [num_experts, capacity, dim] |
| | """ |
| |
|
| | if not self.multimodal_experts: |
| | true_experts = self.experts[ |
| | self.rank |
| | * self.num_local_experts : (self.rank + 1) |
| | * self.num_local_experts |
| | ] |
| | else: |
| | true_experts = [] |
| | for i, num in enumerate(self.num_local_multimodal_experts): |
| | current_modal_experts = self.experts[ |
| | self.multimodal_expert_index[i] : self.multimodal_expert_index[ |
| | i + 1 |
| | ] |
| | ] |
| | true_experts.extend( |
| | current_modal_experts[self.rank * num : (self.rank + 1) * num] |
| | ) |
| |
|
| | dispatched_input = dispatched_input.reshape( |
| | [self.world_size, self.num_local_experts, -1, dispatched_input.shape[-1]] |
| | ) |
| | current_device = dispatched_input.device |
| | expert_outputs = [] |
| | if isinstance(self.experts, nn.ModuleList): |
| | chunks = dispatched_input.permute(1, 0, 2, 3).contiguous().unbind(0) |
| | assert len(chunks) == len( |
| | true_experts |
| | ), f"{len(chunks)}, {len(true_experts)}" |
| | for chunk, expert in zip(chunks, true_experts): |
| | expert_outputs.append(expert(chunk)) |
| | else: |
| | dispatched_input = dispatched_input.permute(1, 0, 2, 3).contiguous() |
| | orig_shape = dispatched_input.shape |
| | chunks = dispatched_input.reshape(orig_shape[0], -1, orig_shape[-1]) |
| | chunks = self.experts(chunks) |
| | chunks = chunks.reshape(orig_shape[:-1] + (chunks.shape[-1],)).unbind(0) |
| | expert_outputs.extend(chunks) |
| |
|
| | for i, expert_output in enumerate(expert_outputs): |
| | expert_outputs[i] = expert_output.to(current_device) |
| | expert_output = torch.stack(expert_outputs, dim=1) |
| | return expert_output |
| |
|
| | def moe_gate_dispatch( |
| | self, |
| | x: torch.Tensor, |
| | gate_logits: torch.Tensor, |
| | k: int, |
| | capacity: Optional[int], |
| | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| | """dispatch input to experts based on gate logits""" |
| |
|
| | S, H = x.shape |
| | E = gate_logits.shape[1] |
| | device = x.device |
| | if self.use_correction_bias: |
| | _, topk_idx = torch.topk(gate_logits + self.moe_statics.e_score_correction_bias[0].detach().to(gate_logits.device), k, dim=-1) |
| | topk_prob = torch.gather(gate_logits, dim=1, index=topk_idx) |
| | else: |
| | topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1) |
| | combine_weights = topk_prob |
| | expert_id = topk_idx |
| | y = x.new_zeros((E, capacity, H)) |
| | scatter_index = x.new_full((k, S), -1, dtype=torch.int32) |
| | |
| | slot_counter = torch.zeros(E, dtype=torch.int32, device=device) |
| |
|
| | for tok in range(S): |
| | for route in range(k): |
| | e = expert_id[tok, route].item() |
| | slot = slot_counter[e].item() |
| | if slot >= capacity: |
| | combine_weights[tok, route] = 0.0 |
| | continue |
| | |
| | scatter_index[route, tok] = e * capacity + slot |
| | y[e, slot] = x[tok] |
| | slot_counter[e] += 1 |
| |
|
| | expert_offset = torch.cumsum(slot_counter, 0, dtype=torch.int64) |
| |
|
| | return y, combine_weights, scatter_index, expert_offset, expert_id |
| |
|
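| | # Illustrative sketch (not part of the model): a toy run of the reference |
| | # dispatch loop above; `layer` stands for a constructed MOELayer, and the |
| | # sizes are made up. |
| | # x = torch.randn(6, 8) |
| | # prob = F.softmax(torch.randn(6, 4), dim=-1) |
| | # y, cw, scat, offs, eid = layer.moe_gate_dispatch(x, prob, k=2, capacity=3) |
| | # y: [4, 3, 8] per-expert buffers; cw: kept top-k probs (zeroed on overflow); |
| | # scat: [2, 6] flat slot per (route, token); offs: cumulative tokens per expert. |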
| | def gate_and_dispatch(self, input, token_type_ids=None, is_multimodel=True): |
| | """ |
| | Calculate gate and dispatch inputs. |
| | |
| | Args: |
| | input: Input tensor of shape [seq, dim] |
| | |
| | Returns: |
| | tuple: (dispatched_input, combine_weights, dispatch_mask, |
| | scatter_index, router_loss, gate_logits, gate_prob) |
| | """ |
| | d_model = input.shape[1] |
| | if isinstance(self.gate, (TopKGate)): |
| | capacity = self.gate.get_capacity( |
| | input.shape[0], is_multimodel=is_multimodel |
| | ) |
| | if token_type_ids is not None: |
| | token_type_ids = token_type_ids.reshape([-1]) |
| | gate_logits = self.gate( |
| | input, token_type_ids=token_type_ids, is_multimodel=is_multimodel |
| | ) |
| | prob = self.gate.act(gate_logits) |
| | ( |
| | dispatched_input, |
| | combine_weights_unnorm, |
| | scatter_index, |
| | dispatch_mask, |
| | _, |
| | ) = self.moe_gate_dispatch(input, prob, k=self.k, capacity=capacity) |
| | dispatch_mask = torch.diff(F.pad(dispatch_mask, (1, 0))) |
| |
|
| | scatter_index.detach() |
| | dispatch_mask.detach() |
| |
|
| | scatter_index = scatter_index.transpose(0, 1) |
| | combine_weights = combine_weights_unnorm / torch.clamp( |
| | combine_weights_unnorm.sum(dim=-1, keepdim=True), min=1e-12 |
| | ) |
| | combine_weights = combine_weights.to(dtype=dispatched_input.dtype) |
| |
|
| | else: |
| | ( |
| | capacity, |
| | dispatch_mask, |
| | combine_weights, |
| | scatter_index, |
| | router_loss, |
| | gate_logits, |
| | ) = self.gate( |
| | input, |
| | ) |
| | prob = None |
| | dispatched_input = dispatching( |
| | input, |
| | dispatch_mask, |
| | scatter_index, |
| | num_experts=self.world_size * self.num_local_experts, |
| | capacity=capacity, |
| | ) |
| |
|
| | dispatched_input = dispatched_input.reshape( |
| | [self.world_size * self.num_local_experts, capacity, d_model] |
| | ) |
| |
|
| | dispatch_mask = dispatch_mask.detach() |
| | scatter_index = scatter_index.detach() |
| | return ( |
| | dispatched_input, |
| | combine_weights, |
| | dispatch_mask, |
| | scatter_index, |
| | None, |
| | gate_logits, |
| | prob, |
| | ) |
| |
|
| | def combine_expert_output(self, expert_output, combine_weights, scatter_index): |
| | """ |
| | Combine expert outputs using combination weights. |
| | |
| | Args: |
| | expert_output: Expert outputs [num_experts, capacity, dim] |
| | combine_weights: Combination weights |
| | scatter_index: Scatter indices |
| | |
| | Returns: |
| | Tensor: Combined output [seqlen, dim] |
| | """ |
| | expert_output = expert_output.reshape( |
| | [-1, expert_output.shape[-1]] |
| | ) |
| |
|
| | combined_output = combining(expert_output, combine_weights, scatter_index) |
| |
|
| | if self.output_postprocess is not None: |
| | combined_output = self.output_postprocess(combined_output) |
| |
|
| | return combined_output |
| |
|
| | def forward( |
| | self, |
| | input: torch.Tensor, |
| | token_type_ids=None, |
| | is_multimodel=True, |
| | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| | """ |
| | Forward pass through MoE layer. |
| | |
| | Args: |
| | input: Input tensor of shape [s, d] |
| | token_type_ids: Optional per-token modality ids forwarded to the gate |
| | is_multimodel: Whether to use the multimodal capacity setting |
| | |
| | Returns: |
| | tuple: (output, combine_weights, router_loss, gate_logits) |
| | """ |
| | if input.dim() == 3: |
| | orig_shape = input.shape |
| | input = input.reshape([-1, input.shape[-1]]) |
| | else: |
| | orig_shape = None |
| | assert ( |
| | input.dim() == 2 |
| | ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" |
| | if token_type_ids is not None: |
| | token_type_ids = token_type_ids.clone()[:, :-1] |
| |
|
| | assert self.gate is not None |
| |
|
| | gate_input = input |
| |
|
| | ( |
| | dispatched_input, |
| | combine_weights, |
| | dispatch_mask, |
| | scatter_index, |
| | router_loss, |
| | gate_logits, |
| | gate_prob, |
| | ) = self.gate_and_dispatch( |
| | gate_input, token_type_ids, is_multimodel=is_multimodel |
| | ) |
| |
|
| | if self.shared_experts is not None: |
| | shared_out = self.shared_experts(input) |
| |
|
| | expert_out = self.forward_experts(dispatched_input) |
| |
|
| | combined_output = self.combine_expert_output( |
| | expert_out, combine_weights, scatter_index |
| | ) |
| |
|
| | if self.shared_experts is not None: |
| | combined_output += shared_out |
| |
|
| | if orig_shape: |
| | combined_output = combined_output.clone().reshape( |
| | orig_shape[:-1] + (combined_output.shape[-1],) |
| | ) |
| | return combined_output, combine_weights, None, gate_logits |
| |
|
| |
|
| | class MOEAllGatherLayerV2(MOELayer): |
| | """ |
| | MoE layer with an all-gather implementation. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | gate: nn.Module, |
| | experts: List[nn.Module], |
| | layer_idx, |
| | shared_experts: Optional[List[nn.Module]] = None, |
| | group=None, |
| | recompute=False, |
| | k=2, |
| | enable_reverse_token_drop=False, |
| | all_to_all_dropout=0, |
| | group_experts=False, |
| | use_expert_out_alltoall=True, |
| | use_expert_alltoall_overlap=False, |
| | use_padding=True, |
| | dense_token_type=3, |
| | moe_statics=None, |
| | moe_num_experts=None, |
| | ): |
| | super().__init__( |
| | gate, |
| | experts, |
| | layer_idx, |
| | shared_experts, |
| | group, |
| | recompute, |
| | k, |
| | all_to_all_dropout, |
| | group_experts, |
| | moe_statics, |
| | moe_num_experts, |
| | ) |
| | self.enable_reverse_token_drop = enable_reverse_token_drop |
| | self.is_allgather_moe_layer = True |
| | self.use_padding = use_padding |
| |
|
| | self.send_rank = None |
| | self.local_expert_id = None |
| | self.dense_experts = None |
| | self.dense_token_type = dense_token_type |
| | self.capacity_tensor = None |
| | logger.info( |
| | f"uisng MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, " |
| | f"use_padding={use_padding}, use_expert_alltoall_overlap={use_expert_alltoall_overlap} " |
| | f"enable_reverse_token_drop={self.enable_reverse_token_drop}" |
| | ) |
| | self.two = torch.tensor(2).to(dtype=torch.float32) |
| | self.zero = torch.tensor(0).to(dtype=torch.float32) |
| |
|
| | def forward( |
| | self, |
| | input: torch.Tensor, |
| | token_type_ids=None, |
| | use_dense_expert=False, |
| | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| | """Implements forward pass for Mixture-of-Experts (MoE) layer with distributed communication. |
| | |
| | Core Functionality: |
| | - Processes input through gating network to determine expert assignments |
| | - Combines expert outputs and calculates routing loss |
| | |
| | Key Features: |
| | 1. Supports both dense and sparse expert computation modes |
| | 2. Implements fused gating and dispatch for performance optimization |
| | 3. Handles sequence length padding/unpadding for irregular inputs |
| | 4. Enables communication-computation overlap through asynchronous operations |
| | |
| | Args: |
| | input (Tensor): Input tensor of shape [seq_len, hidden_dim] |
| | token_type_ids: Optional segmentation markers for heterogeneous inputs |
| | use_dense_expert: Flag to enable dense expert computation bypass |
| | |
| | Returns: |
| | tuple: ( |
| | combined_output: Aggregated expert outputs [seq_len, hidden_dim], |
| | combine_weights: Expert combination coefficients, |
| | router_loss: Placeholder (None), |
| | gate_logits: Raw gating logits, |
| | ) |
| | """ |
| | use_fuse = isinstance(self.gate, (TopKGate)) |
| | assert use_fuse |
| | if input.ndim == 3: |
| | orig_shape = input.shape |
| | input = input.reshape([-1, input.shape[-1]]) |
| | else: |
| | orig_shape = None |
| |
|
| | assert ( |
| | len(input.shape) == 2 |
| | ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" |
| | dispatch_token_type_ids = None |
| | global_dense_expert_mask = None |
| | if token_type_ids is not None: |
| | token_type_ids = token_type_ids[:, :-1].reshape([-1]) |
| | dispatch_token_type_ids = token_type_ids |
| | if use_dense_expert: |
| | global_dense_expert_mask = ( |
| | dispatch_token_type_ids == self.dense_token_type |
| | ) |
| |
|
| | assert self.gate is not None |
| |
|
| | ( |
| | dispatched_input, |
| | global_hidden_states, |
| | local_combine_weights, |
| | expert_num_global_no_token_drop, |
| | expert_num_global, |
| | expert_num_global_list, |
| | local_scatter_index, |
| | scatter_index_rev, |
| | router_loss, |
| | (gate_logits, gate_prob), |
| | (gate_logits_mm, gate_prob_mm), |
| | expert_num_local, |
| | ) = self.fused_gate_and_dispatch( |
| | input, token_type_ids, global_dense_expert_mask |
| | ) |
| |
|
| | seqlen_this_mp = input.shape[0] |
| | if len(scatter_index_rev): |
| | recv_rank_local = scatter_index_rev // seqlen_this_mp |
| | else: |
| | recv_rank_local = scatter_index_rev |
| |
|
| | if self.send_rank is None: |
| | capacity = self.gate.get_capacity(input.shape[0]) |
| | self.send_rank = ( |
| | torch.arange(1) |
| | .repeat_interleave(capacity * self.num_local_experts) |
| | .to(torch.int32) |
| | ) |
| | self.local_expert_id = ( |
| | torch.arange(self.num_local_experts) |
| | .repeat_interleave(capacity) |
| | .repeat(1) |
| | .to(self.send_rank.dtype) |
| | ) |
| | send_rank = self.send_rank |
| | local_expert_id = self.local_expert_id |
| |
|
| | expert_outs = self.forward_experts(*dispatched_input) |
| | for e in expert_outs: |
| | if e is not None: |
| | current_device = e.device |
| | break |
| | expert_outs = torch.cat( |
| | [e.to(current_device) for e in expert_outs if e is not None], dim=0 |
| | ) |
| |
|
| | |
| | combined_output = self.combine_expert_output( |
| | expert_outs, local_combine_weights, local_scatter_index |
| | ) |
| |
|
| | if self.shared_experts is not None: |
| | shared_out = self.shared_experts(input).to(combined_output.device) |
| | combined_output += shared_out |
| |
|
| | if orig_shape: |
| | combined_output = combined_output.reshape( |
| | *orig_shape[:-1], combined_output.shape[-1] |
| | ) |
| |
|
| | return combined_output, local_combine_weights, None, gate_logits |
| |
|
| | def _expand_modality_expert_id( |
| | self, |
| | expert_id: torch.Tensor, |
| | seqlen: int, |
| | k: int, |
| | num_expert_per_modality: int, |
| | group_size: int, |
| | modality_offset: int, |
| | is_group_expert: bool, |
| | ) -> torch.Tensor: |
| | """ |
| | expert_id: tensor of shape (seqlen, k), containing expert ids |
| | Returns: tensor of same shape, with updated expert ids |
| | """ |
| | device = expert_id.device |
| | expert_id = expert_id.clone() |
| |
|
| | if is_group_expert: |
| | |
| | offsets = (torch.arange(k, device=device) * group_size).view( |
| | 1, k |
| | ) |
| | expert_id += offsets |
| |
|
| | if num_expert_per_modality <= 0: |
| | return expert_id |
| |
|
| | |
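| | # Re-index into the global layout where each rank hosts num_expert_per_modality text |
| | # experts followed by the same number of multimodal experts (selected by modality_offset). |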
| | rank = expert_id // num_expert_per_modality |
| | expert_id_in_rank = expert_id % num_expert_per_modality |
| |
|
| | |
| | expert_id_out = ( |
| | rank * (num_expert_per_modality * 2) |
| | + expert_id_in_rank |
| | + modality_offset * num_expert_per_modality |
| | ) |
| |
|
| | return expert_id_out |
| |
|
| | def expand_modality_expert_id( |
| | self, |
| | expert_id, |
| | num_expert_per_modality, |
| | group_size, |
| | modality_offset, |
| | is_group_expert, |
| | ): |
| | """expand expert id for modality aware moe layer""" |
| | seq_len, k = expert_id.shape |
| |
|
| | return self._expand_modality_expert_id( |
| | expert_id, |
| | seq_len, |
| | k, |
| | num_expert_per_modality, |
| | group_size, |
| | modality_offset, |
| | is_group_expert, |
| | ) |
| |
|
| | def fused_gate_logits_process_fused( |
| | self, gate_logits_lm, gate_logits_mm=None, token_type_ids=None |
| | ): |
| | """Process gating logits for expert selection in Mixture-of-Experts (MoE) layers. |
| | |
| | Core Functionality: |
| | - Transforms raw gating logits into expert selection weights and IDs |
| | - Supports both grouped and standard expert selection modes |
| | - Handles bias correction for improved expert load balancing |
| | |
| | Args: |
| | gate_logits_lm (Tensor): Raw text-expert gating scores of shape [batch_size, total_experts] |
| | gate_logits_mm (Tensor, optional): Raw multimodal-expert gating scores |
| | token_type_ids (Tensor, optional): Per-token modality ids used to choose between the two routings |
| | |
| | Returns: |
| | tuple: ( |
| | lm_weight_and_expert_id: Selection weights and expert IDs |
| | [batch_size, 2*top_k], |
| | prob_flat: Flattened expert probabilities [batch_size, total_experts], |
| | prob_mm: Multimodal expert probabilities, or None for text-only input |
| | ) |
| | """ |
| | top_k = self.k |
| | num_expert_per_rank_per_modality = gate_logits_lm.shape[-1] |
| | group_size = gate_logits_lm.shape[-1] // top_k |
| | if self.group_experts: |
| | assert not self.use_correction_bias |
| | gate_logits_lm = gate_logits_lm.reshape( |
| | [gate_logits_lm.shape[0], top_k, -1] |
| | ) |
| | prob_lm = self.gate.act(gate_logits_lm) |
| | prob_lm_ = prob_lm |
| | weight_lm, expert_id_lm = prob_lm_.topk(k=1, dim=-1) |
| | weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) |
| | group_size = gate_logits_lm.shape[-1] |
| | expert_id_lm = expert_id_lm.squeeze(-1) |
| | else: |
| | prob_lm = self.gate.act(gate_logits_lm) |
| | if self.use_correction_bias: |
| | prob_lm_ = prob_lm + self.moe_statics.e_score_correction_bias[ |
| | 0 |
| | ].detach().to(prob_lm.device) |
| | else: |
| | prob_lm_ = prob_lm |
| | weight_lm, expert_id_lm = prob_lm_.topk(k=top_k, dim=-1) |
| |
|
| | if self.use_correction_bias: |
| | batch_idx = ( |
| | torch.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) |
| | ) |
| | weight_lm = prob_lm[batch_idx, expert_id_lm] |
| |
|
| | expert_id_lm = self.expand_modality_expert_id( |
| | expert_id_lm, |
| | num_expert_per_modality=( |
| | num_expert_per_rank_per_modality if token_type_ids is not None else 0 |
| | ), |
| | group_size=group_size, |
| | modality_offset=0, |
| | is_group_expert=self.group_experts, |
| | ) |
| | expert_id_lm = expert_id_lm.reshape(weight_lm.shape) |
| | lm_weight_and_expert_id = torch.cat( |
| | [weight_lm, expert_id_lm.to(torch.float32)], -1 |
| | ) |
| |
|
| | if token_type_ids is None or gate_logits_mm is None: |
| | return ( |
| | lm_weight_and_expert_id, |
| | prob_lm.reshape([prob_lm.shape[0], -1]), |
| | None, |
| | ) |
| |
|
| | prob_mm = self.gate.act(gate_logits_mm) |
| | if self.use_correction_bias: |
| | prob_mm_ = prob_mm + self.moe_statics.e_score_correction_bias[ |
| | 1 |
| | ].detach().to(prob_lm.device) |
| | else: |
| | prob_mm_ = prob_mm |
| | weight_mm, expert_id_mm = prob_mm_.topk(k=top_k, dim=-1) |
| | if self.use_correction_bias: |
| | batch_idx = ( |
| | torch.arange(prob_mm_.shape[0]).unsqueeze(-1).expand_as(expert_id_mm) |
| | ) |
| | weight_mm = prob_mm[batch_idx, expert_id_mm] |
| |
|
| | expert_id_mm = self.expand_modality_expert_id( |
| | expert_id_mm, |
| | num_expert_per_modality=num_expert_per_rank_per_modality, |
| | group_size=group_size, |
| | modality_offset=1, |
| | is_group_expert=False, |
| | ) |
| | expert_id_mm = expert_id_mm.reshape(weight_mm.shape) |
| | mm_weight_and_expert_id = torch.cat( |
| | [weight_mm, expert_id_mm.to(torch.float32)], -1 |
| | ) |
| | weight_and_expert = torch.where( |
| | (token_type_ids == 0).unsqueeze(-1), |
| | lm_weight_and_expert_id.to(token_type_ids.device), |
| | mm_weight_and_expert_id.to(token_type_ids.device), |
| | ) |
| | return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm |
| |
|
| | def moe_gate_dispatch_partial_nosoftmaxtopk( |
| | self, |
| | x, |
| | combine_weights, |
| | expert_id, |
| | k, |
| | num_experts, |
| | ): |
| | """ |
| | Pure-PyTorch implementation of the MoE gate dispatch kernel (weights and expert ids are already computed by the caller). |
| | """ |
| | device = x.device |
| | dtype = x.dtype |
| | num_rows, hidden_size = x.shape |
| | k = expert_id.shape[1] |
| | expert_ids_flat = expert_id.reshape(-1) |
| | combine_weights_flat = combine_weights.reshape(-1) |
| |
|
| | expanded_token_ids = torch.arange(num_rows * k, device=device) |
| |
|
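| | # Each token appears k times (one copy per selected expert); a stable sort by expert id |
| | # groups the copies routed to the same expert together. |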
| | sorted_expert_ids, sorted_indices = torch.sort(expert_ids_flat, stable=True) |
| | sorted_indices = sorted_indices.to(expanded_token_ids.device) |
| |
|
| | sorted_expanded_token_ids = expanded_token_ids[sorted_indices] |
| |
|
| | expert_nums_local = torch.zeros(num_experts, dtype=torch.int64, device=device) |
| |
|
| | for expert_idx in range(num_experts): |
| | count = (sorted_expert_ids == expert_idx).sum().item() |
| | expert_nums_local[expert_idx] = count |
| |
|
| | total_dispatched_tokens = torch.cumsum(expert_nums_local, dim=0)[-1].item() |
| |
|
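| | # Gather hidden states in dispatch order; dividing the expanded index by k recovers the original token index. |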
| | y = x[sorted_indices // k] |
| |
|
| | scatter_index = torch.full((k, num_rows), -1, dtype=torch.int32, device=device) |
| |
|
| | for expanded_idx, sorted_pos in zip( |
| | sorted_expanded_token_ids, range(total_dispatched_tokens) |
| | ): |
| | token_idx = expanded_idx // k |
| | k_idx = expanded_idx % k |
| | scatter_index[k_idx, token_idx] = sorted_pos |
| |
|
| | scatter_index_rev = sorted_indices // k |
| |
|
| | combine_weights_out = combine_weights.clone() |
| |
|
| | return ( |
| | y, |
| | combine_weights_out, |
| | scatter_index, |
| | scatter_index_rev, |
| | expert_nums_local, |
| | expert_nums_local, |
| | ) |
| |
|
| | def fused_gate_and_dispatch( |
| | self, input, token_type_ids=None, global_dense_expert_mask=None |
| | ): |
| | """Implements fused expert gating and token dispatch logic for Mixture-of-Experts (MoE) layers. |
| | |
| | Core Functionality: |
| | - Computes expert selection probabilities and routing weights |
| | - Performs distributed token-to-expert assignment |
| | - Handles communication and synchronization in model-parallel environments |
| | |
| | Args: |
| | input (Tensor): Input tensor of shape [seq_len, hidden_dim] |
| | token_type_ids (Tensor, optional): Per-token modality ids |
| | global_dense_expert_mask (Tensor, optional): Mask of tokens handled by the dense experts |
| | |
| | Returns: |
| | tuple: ( |
| | dispatched_input: Expert-assigned tokens [num_experts, capacity, hidden_dim], |
| | global_hidden_states: Full sequence representations, |
| | local_combine_weights: Local expert combination weights, |
| | expert_num_global_notrunc: Global expert token counts (without capacity truncation), |
| | expert_num_global: Actual expert token counts, |
| | expert_num_global_list: Per-expert token counts, |
| | local_scatter_index: Local token reorganization indices, |
| | scatter_index_rev: Reverse scattering indices, |
| | router_loss: Calculated routing loss, |
| | gate_outputs: Raw gating network outputs, |
| | expert_num_local: Local expert utilization counts |
| | ) |
| | """ |
| | seqlen, d_model = input.shape |
| | args = () |
| | if token_type_ids is not None: |
| | token_type_ids = token_type_ids.reshape([-1]) |
| | args = (token_type_ids,) |
| |
|
| | router_loss = torch.zeros([1], dtype=torch.float32) |
| | top_k = self.k |
| |
|
| | def build_weights_and_expert_id(input): |
| | nonlocal token_type_ids, args |
| | logits = self.gate(input, *args, transform_weight=False) |
| | if self.config.multimodel_experts: |
| | gate_logits_lm, gate_logits_mm = logits.chunk(2, dim=-1) |
| | else: |
| | gate_logits_lm, gate_logits_mm = logits, None |
| |
|
| | weight_and_expert, gate_prob_lm, gate_prob_mm = ( |
| | self.fused_gate_logits_process_fused( |
| | gate_logits_lm, |
| | gate_logits_mm, |
| | token_type_ids if global_dense_expert_mask is None else None, |
| | ) |
| | ) |
| | return ( |
| | weight_and_expert, |
| | gate_logits_lm, |
| | gate_logits_mm, |
| | gate_prob_lm, |
| | gate_prob_mm, |
| | ) |
| |
|
| | capacity = self.gate.get_capacity(input.shape[0]) * self.world_size |
| | global_hidden_states = input |
| | ( |
| | combine_weights_and_expert_id, |
| | gate_logits_lm, |
| | gate_logits_mm, |
| | gate_prob_lm, |
| | gate_prob_mm, |
| | ) = build_weights_and_expert_id(input) |
| |
|
| | combine_weights_unnorm, expert_id = combine_weights_and_expert_id.chunk( |
| | 2, dim=-1 |
| | ) |
| | expert_id = expert_id.to(torch.int32) |
| | num_experts = ( |
| | sum(self.config.moe_num_experts) |
| | if isinstance(self.config.moe_num_experts, (tuple, list)) |
| | else self.config.moe_num_experts |
| | ) |
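| | # Tokens assigned to the dense experts are parked in an extra expert bucket (id == num_experts); |
| | # this bucket is stripped from the expert counts again after dispatch. |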
| | if global_dense_expert_mask is not None: |
| | combine_weights_unnorm[global_dense_expert_mask] = 0.0 |
| | expert_id[global_dense_expert_mask] = num_experts |
| | num_experts += 1 |
| |
|
| | ( |
| | dispatched_input, |
| | combine_weights_unnorm, |
| | scatter_index, |
| | scatter_index_rev, |
| | expert_num_global, |
| | expert_num_local, |
| | ) = self.moe_gate_dispatch_partial_nosoftmaxtopk( |
| | global_hidden_states, |
| | combine_weights_unnorm, |
| | expert_id, |
| | top_k, |
| | num_experts, |
| | ) |
| |
|
| | if self.use_correction_bias: |
| | if self.gate.config.multimodel_experts: |
| | |
| | for i in range(len(self.moe_statics.expert_usage)): |
| | self.moe_statics.expert_usage[i] += ( |
| | expert_num_local[self.gate.experts_type_mask[i]] |
| | .detach() |
| | .to(self.moe_statics.expert_usage.device) |
| | ) |
| | else: |
| | |
| | self.moe_statics.expert_usage[0] += expert_num_local.detach().to( |
| | self.moe_statics.expert_usage.device |
| | ) |
| |
|
| | |
| | if scatter_index_rev.ndim == 0: |
| | assert not self.use_padding |
| | scatter_index_rev = torch.empty([0], dtype=scatter_index_rev.dtype) |
| |
|
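| | # Keep the untruncated per-expert counts for statistics, then clamp the counts to the gate capacity. |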
| | expert_num_global_notrunc = expert_num_global |
| | self.capacity_tensor = torch.tensor(capacity).to(dtype=expert_num_global.dtype) |
| | expert_num_global = torch.minimum(expert_num_global, self.capacity_tensor) |
| |
|
| | if global_dense_expert_mask is not None: |
| | expert_num_global = expert_num_global[:-1] |
| | expert_num_local = expert_num_local[:-1] |
| | expert_num_global_notrunc = expert_num_global_notrunc[:-1] |
| |
|
| | scatter_index = scatter_index.transpose(1, 0) |
| | scatter_index = scatter_index.to(combine_weights_unnorm.device) |
| |
|
| | last_local_expert = 0 |
| | expert_offset_global = expert_num_global.cumsum(-1) |
| |
|
| | expert_num_global_list = expert_num_global |
| | if self.use_padding: |
| | offset = last_local_expert * capacity |
| | else: |
| | offset = 0 |
| | local_combine_weights_unnorm = combine_weights_unnorm.contiguous() |
| | local_scatter_index = torch.where( |
| | combine_weights_unnorm > 0.0, |
| | scatter_index + offset, |
| | scatter_index, |
| | ) |
| | if self.gate.norm_gate_logits: |
| | local_combine_weights = local_combine_weights_unnorm / torch.clip( |
| | local_combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 |
| | ) |
| | else: |
| | local_combine_weights = local_combine_weights_unnorm |
| | local_combine_weights = local_combine_weights.to(dispatched_input.dtype) |
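| | # With padding, every local expert gets a fixed-capacity slice; without it, the dispatched |
| | # tokens are split by the actual per-expert counts (None for experts that received no tokens). |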
| | if self.use_padding: |
| | dispatched_input = dispatched_input.reshape( |
| | [self.num_local_experts, -1, d_model] |
| | ) |
| | dispatched_input = dispatched_input.unbind(0) |
| | else: |
| | s = 0 |
| | e = self.num_local_experts |
| | expert_num_local = expert_num_local.tolist()[s:e] |
| | expert_num_local_valid = [i for i in expert_num_local if i > 0] |
| | valid_pos = [j for j, i in enumerate(expert_num_local) if i > 0] |
| | if expert_num_local_valid: |
| | dispatched_input_list = dispatched_input.split(expert_num_local_valid) |
| | dispatched_input = [None] * len(expert_num_local) |
| | for p, t in zip(valid_pos, dispatched_input_list): |
| | dispatched_input[p] = t |
| | else: |
| | dispatched_input = [dispatched_input] + ( |
| | [None] * (len(expert_num_local) - 1) |
| | ) |
| |
|
| | expert_num_global_list = expert_num_global_list.tolist() |
| |
|
| | return ( |
| | dispatched_input, |
| | global_hidden_states, |
| | local_combine_weights, |
| | expert_num_global_notrunc, |
| | expert_num_global, |
| | expert_num_global_list, |
| | local_scatter_index, |
| | scatter_index_rev, |
| | router_loss, |
| | (gate_logits_lm, gate_prob_lm), |
| | (gate_logits_mm, gate_prob_mm), |
| | expert_num_local, |
| | ) |
| |
|
| | def forward_experts(self, *dispatched_input): |
| | """Execute expert model computations in sequence for Mixture-of-Experts (MoE) layer. |
| | |
| | Core Functionality: |
| | - Distributes dispatched tokens to local expert models |
| | - Handles empty expert inputs with zero-initialized fallback |
| | - Maintains gradient flow for expert outputs |
| | - Aggregates outputs from all active experts |
| | |
| | Args: |
| | *dispatched_input: Variable-length expert-specific input tensors |
| | |
| | Returns: |
| | list: Expert output tensors (None for inactive experts) |
| | |
| | Implementation Details: |
| | 1. Processes valid expert inputs through corresponding expert models |
| | 2. Keeps None placeholders for experts that received no tokens |
| | 3. Folds any dummy outputs into the first active expert to preserve gradient flow |
| | """ |
| | expert_outputs = [] |
| | assert isinstance(self.experts, nn.ModuleList), type(self.experts) |
| |
|
| | no_tokens_expert_outputs = [] |
| | true_experts = self.experts[ |
| | self.rank |
| | * self.num_local_experts : (self.rank + 1) |
| | * self.num_local_experts |
| | ] |
| | for iexpert, chunk in enumerate(dispatched_input): |
| | if chunk is None: |
| | expert_outputs.append(None) |
| | continue |
| |
|
| | expert_out = true_experts[iexpert](chunk.contiguous()) |
| | expert_outputs.append(expert_out) |
| |
|
| | if len(no_tokens_expert_outputs) > 0: |
| | first_has_tokens_idx = 0 |
| | for idx, expert_out in enumerate(expert_outputs): |
| | if expert_out is not None: |
| | first_has_tokens_idx = idx |
| | break |
| | for expert_out in no_tokens_expert_outputs: |
| | expert_outputs[first_has_tokens_idx] += expert_out |
| |
|
| | return expert_outputs |
| |
|
| |
|
| | class Ernie4_5_DecoderLayer(nn.Module): |
| | """A single transformer decoder layer in ERNIE-MoE model. |
| | |
| | Contains self-attention and feed-forward components with optional MoE (Mixture of Experts) |
| | support, residual connections, and layer normalization. |
| | """ |
| |
|
| | _keep_in_fp32_modules = ["mlp.gate", "e_score_correction_bias"] |
| |
|
| | def __init__(self, config, layer_idx): |
| | """Initialize the decoder layer. |
| | |
| | Args: |
| | config (Ernie4_5_MoEConfig): Model configuration. |
| | layer_idx (int): Index of this layer in the transformer stack |
| | """ |
| | super().__init__() |
| | self.hidden_size = config.hidden_size |
| | self.layer_idx = layer_idx |
| | self.config = config |
| | self.use_moe = config.use_moe |
| | self.self_attn = Ernie4_5_Attention(config, layer_idx) |
| |
|
| | moe_layer_start_index = ( |
| | min(config.moe_layer_start_index) |
| | if isinstance(config.moe_layer_start_index, (tuple, list)) |
| | else config.moe_layer_start_index |
| | ) |
| | moe_layer_end_index = ( |
| | max(config.moe_layer_end_index) |
| | if isinstance(config.moe_layer_end_index, (tuple, list)) |
| | else config.moe_layer_end_index |
| | ) |
| |
|
| | if ( |
| | self.use_moe |
| | and ((layer_idx + 1) % config.moe_layer_interval == 0) |
| | and layer_idx >= moe_layer_start_index |
| | and layer_idx <= moe_layer_end_index |
| | ): |
| | gate, experts, lm_gate, lm_experts, moe_statics = ( |
| | self._init_gate_and_experts(layer_idx) |
| | ) |
| | shared_experts = ( |
| | self._init_shared_experts() |
| | if hasattr(config, "moe_num_shared_experts") |
| | else None |
| | ) |
| |
|
| | dense_experts = None |
| | moe_cls = MOELayer |
| | if config.moe_multimodal_dispatch_use_allgather: |
| | logger.info("Enable MOEAllGatherLayerV2!") |
| | moe_cls = partial( |
| | MOEAllGatherLayerV2, |
| | use_expert_out_alltoall="alltoall" |
| | in config.moe_multimodal_dispatch_use_allgather, |
| | use_padding=False, |
| | enable_reverse_token_drop=config.moe_reverse_token_drop, |
| | dense_token_type=config.moe_dense_experts_token_type_id, |
| | ) |
| | else: |
| | assert ( |
| | dense_experts is None |
| | ), "only `MOEAllGatherLayerV2` can process dense experts" |
| |
|
| | self.mlp = moe_cls( |
| | gate=gate, |
| | experts=experts, |
| | layer_idx=layer_idx, |
| | shared_experts=shared_experts, |
| | group=config.moe_group, |
| | recompute=False, |
| | k=config.moe_k, |
| | all_to_all_dropout=config.moe_all_to_all_dropout, |
| | group_experts=False, |
| | moe_statics=moe_statics, |
| | moe_num_experts=config.moe_num_experts, |
| | ) |
| |
|
| | _mlp_text = MOELayer( |
| | gate=lm_gate, |
| | experts=lm_experts, |
| | layer_idx=layer_idx, |
| | shared_experts=shared_experts, |
| | group=config.moe_group, |
| | recompute=False, |
| | k=config.moe_k, |
| | all_to_all_dropout=config.moe_all_to_all_dropout, |
| | group_experts=False, |
| | moe_statics=moe_statics, |
| | moe_num_experts=config.moe_num_experts, |
| | ) |
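| | # Stored behind a lambda so the text-only MoE layer is not registered as a child module. |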
| | self.mlp_text = ( |
| | lambda: _mlp_text |
| | ) |
| | else: |
| | self.mlp = Ernie4_5_MLP(config) |
| |
|
| | Norm = RMSNorm |
| |
|
| | self.input_layernorm = Norm(config) |
| | self.post_attention_layernorm = Norm(config) |
| |
|
| | self.residual_add1 = FusedDropoutImpl( |
| | config.hidden_dropout_prob, mode="upscale_in_train" |
| | ) |
| | self.residual_add2 = FusedDropoutImpl( |
| | config.hidden_dropout_prob, mode="upscale_in_train" |
| | ) |
| |
|
| | def _init_shared_experts(self): |
| | """init shared experts |
| | |
| | Returns: |
| | Ernie4_5_MoeMLP or None: Shared-expert MLP, or None when moe_num_shared_experts is 0 |
| | """ |
| | cfg = deepcopy(self.config) |
| | if cfg.moe_num_shared_experts > 0: |
| | if cfg.moe_intermediate_size: |
| | inter_size = ( |
| | next(iter(cfg.moe_intermediate_size)) |
| | if isinstance(cfg.moe_intermediate_size, (tuple, list)) |
| | else cfg.moe_intermediate_size |
| | ) |
| | cfg.intermediate_size = inter_size * cfg.moe_num_shared_experts |
| | else: |
| | cfg.intermediate_size = ( |
| | cfg.intermediate_size * cfg.moe_num_shared_experts |
| | ) |
| | cfg.disable_ffn_model_parallel = False |
| | shared_experts = Ernie4_5_MoeMLP(cfg, True) |
| | else: |
| | shared_experts = None |
| | return shared_experts |
| |
|
| | def _init_gate_and_experts(self, layer_idx): |
| | """Initialize MoE gate and expert networks. |
| | |
| | Args: |
| | layer_idx (int): Current layer index |
| | |
| | Returns: |
| | Tuple: Contains: |
| | - gate: MoE routing gate |
| | - experts: List of expert networks |
| | - moe_statics: Optional statistics tracker |
| | """ |
| | cfg = deepcopy(self.config) |
| | fc_cls = Ernie4_5_MoeMLP |
| | if cfg.moe_intermediate_size: |
| | if isinstance(cfg.moe_intermediate_size, (tuple, list)): |
| | assert isinstance(cfg.moe_num_experts, (tuple, list)) and len( |
| | cfg.moe_num_experts |
| | ) == len(cfg.moe_intermediate_size) |
| | fc = [] |
| | for _i, (num_experts, intermediate_size) in enumerate( |
| | zip(cfg.moe_num_experts, cfg.moe_intermediate_size) |
| | ): |
| | ex_cfg = deepcopy(cfg) |
| | ex_cfg.intermediate_size = intermediate_size |
| | cur_modality_start_layer_idx = ( |
| | cfg.moe_layer_start_index[_i] |
| | if isinstance(cfg.moe_layer_start_index, (tuple, list)) |
| | else cfg.moe_layer_start_index |
| | ) |
| | cur_modality_end_layer_idx = ( |
| | cfg.moe_layer_end_index[_i] |
| | if isinstance(cfg.moe_layer_end_index, (tuple, list)) |
| | else cfg.moe_layer_end_index |
| | ) |
| | if ( |
| | layer_idx >= cur_modality_start_layer_idx |
| | and layer_idx <= cur_modality_end_layer_idx |
| | ): |
| | if _i == 1: |
| | with UniqueNameGuard(f"mm_expert_{layer_idx}_") as guard: |
| | fc.append((num_experts, fc_cls(ex_cfg))) |
| | else: |
| | fc.append((num_experts, fc_cls(ex_cfg))) |
| | else: |
| | logger.info( |
| | f"moe multimodal experts use Identity layer_idx: {layer_idx}" |
| | ) |
| | fc.append((num_experts, nn.Identity())) |
| | else: |
| | cfg.intermediate_size = cfg.moe_intermediate_size |
| | fc = [(cfg.moe_num_experts, fc_cls(cfg, layer_idx))] |
| | else: |
| | fc = [(cfg.moe_num_experts, fc_cls(cfg, layer_idx))] |
| | if cfg.multimodel_experts: |
| | gate, experts, lm_gate, lm_experts = get_gate(self.config, fc, layer_idx) |
| | else: |
| | gate, experts = get_gate(self.config, fc, layer_idx) |
| | lm_gate, lm_experts = None, None |
| |
|
| | |
| | if cfg.moe_use_aux_free: |
| | moe_statics = MoEStatics(cfg, layer_idx) |
| | else: |
| | moe_statics = None |
| | return gate, experts, lm_gate, lm_experts, moe_statics |
| |
|
| | def forward( |
| | self, |
| | hidden_states: torch.Tensor, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | attn_mask_start_row_indices: Optional[torch.Tensor] = None, |
| | position_ids: Optional[torch.Tensor] = None, |
| | token_type_ids: Optional[torch.Tensor] = None, |
| | output_attentions: Optional[bool] = False, |
| | past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| | use_cache: Optional[bool] = False, |
| | output_gate_logits=True, |
| | ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: |
| | """Forward pass through the decoder layer. |
| | |
| | Args: |
| | hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size] |
| | attention_mask (Optional[torch.Tensor]): Attention mask tensor |
| | attn_mask_start_row_indices (Optional[torch.Tensor]): Indices for variable length attention |
| | position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings |
| | output_attentions (Optional[bool]): Whether to return attention weights |
| | past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states |
| | use_cache (Optional[bool]): Whether to cache key/value states |
| | output_gate_logits (bool): Whether to return MoE gate logits |
| | |
| | Returns: |
| | Union: Various output combinations depending on arguments: |
| | - Base case: Hidden states tensor |
| | - With attention: Tuple of (hidden_states, attention_weights) |
| | - With cache: Tuple of (hidden_states, cached_key_value) |
| | - With MoE: May include gate logits in output tuple |
| | """ |
| | residual = hidden_states |
| |
|
| | if token_type_ids is not None: |
| | is_multimodel_token = token_type_ids.any() |
| | has_dense_experts_token = ( |
| | token_type_ids == self.config.moe_dense_experts_token_type_id |
| | ).any() |
| | is_multimodel_token_cpu = is_multimodel_token.cpu() |
| | has_dense_experts_token_cpu = has_dense_experts_token.cpu() |
| | else: |
| | is_multimodel_token_cpu = None |
| | has_dense_experts_token_cpu = None |
| |
|
| | hidden_states = self.input_layernorm(hidden_states) |
| |
|
| | |
| | (hidden_states, self_attn_weights, present_key_value, *router_loss_attn) = ( |
| | self.self_attn( |
| | hidden_states=hidden_states, |
| | past_key_value=past_key_value, |
| | attention_mask=attention_mask, |
| | attn_mask_start_row_indices=attn_mask_start_row_indices, |
| | position_ids=position_ids, |
| | output_attentions=output_attentions, |
| | use_cache=use_cache, |
| | token_type_ids=token_type_ids, |
| | ) |
| | ) |
| | hidden_states = self.residual_add1(hidden_states, residual) |
| |
|
| | |
| | residual = hidden_states |
| | hidden_states = self.post_attention_layernorm(hidden_states) |
| |
|
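| | # Use the multimodal MoE layer only when the batch contains multimodal tokens; |
| | # otherwise route everything through the text-only MoE layer. |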
| | if isinstance(self.mlp, MOELayer): |
| | if is_multimodel_token_cpu: |
| | hidden_states, _, router_loss, gate_logits = self.mlp( |
| | hidden_states, token_type_ids |
| | ) |
| | else: |
| | hidden_states, _, router_loss, gate_logits = self.mlp_text()( |
| | hidden_states, None, is_multimodel=False |
| | ) |
| | else: |
| | hidden_states = self.mlp(hidden_states) |
| | gate_logits, router_loss = None, None |
| |
|
| | hidden_states = self.residual_add2(hidden_states, residual) |
| |
|
| | outputs = (hidden_states,) |
| |
|
| | if output_attentions: |
| | outputs += (self_attn_weights,) |
| |
|
| | if use_cache: |
| | outputs += (present_key_value,) |
| |
|
| | if self.use_moe: |
| | |
| | if router_loss_attn: |
| | router_loss_attn = router_loss_attn[0] |
| | router_loss = router_loss + router_loss_attn |
| |
|
| | if output_gate_logits: |
| | outputs += (gate_logits,) |
| |
|
| | |
| | if type(outputs) is tuple and len(outputs) == 1: |
| | outputs = outputs[0] |
| |
|
| | return outputs |
| |
|
| |
|
| | class Ernie4_5_PretrainedModel(PreTrainedModel): |
| | """Base class for ERNIE pretrained models.""" |
| |
|
| | config_class = Ernie4_5_MoEConfig |
| | base_model_prefix = "ernie" |
| | _no_split_modules = ["Ernie4_5_DecoderLayer"] |
| |
|
| |
|
| | class Ernie4_5_Model(Ernie4_5_PretrainedModel): |
| | """The core ERNIE transformer model with MoE (Mixture of Experts) support.""" |
| |
|
| | def __init__(self, config: Ernie4_5_MoEConfig): |
| | """Initialize the ERNIE model architecture. |
| | |
| | Args: |
| | config (Ernie4_5_MoEConfig): Model configuration. |
| | """ |
| | super().__init__(config) |
| | self.padding_idx = config.pad_token_id |
| | self.vocab_size = config.vocab_size |
| | self.hidden_size = config.hidden_size |
| | self.config = config |
| |
|
| | self.embed_tokens = nn.Embedding( |
| | self.vocab_size, |
| | self.hidden_size, |
| | ) |
| |
|
| | self.layers = nn.ModuleList( |
| | [Ernie4_5_DecoderLayer(config, i) for i in range(config.num_hidden_layers)] |
| | ) |
| | Norm = RMSNorm |
| | self.norm = Norm(config) |
| |
|
| | self.gradient_checkpointing = False |
| |
|
| | def get_input_embeddings(self): |
| | """Get the input embedding layer. |
| | |
| | Returns: |
| | nn.Embedding: The embedding layer for input tokens |
| | """ |
| | return self.embed_tokens |
| |
|
| | def set_input_embeddings(self, value): |
| | """Set new input embeddings. |
| | |
| | Args: |
| | value (nn.Embedding): New embedding layer to use |
| | """ |
| | self.embed_tokens = value |
| |
|
| | def forward( |
| | self, |
| | input_ids=None, |
| | position_ids=None, |
| | token_type_ids=None, |
| | attention_mask=None, |
| | attn_mask_start_row_indices=None, |
| | inputs_embeds=None, |
| | use_cache=None, |
| | past_key_values=None, |
| | output_attentions=False, |
| | output_hidden_states=None, |
| | return_dict=False, |
| | ): |
| | """Forward pass through the ERNIE model. |
| | |
| | Args: |
| | input_ids (Optional[torch.Tensor]): Input token IDs |
| | position_ids (Optional[torch.Tensor]): Position indices |
| | attention_mask (Optional[torch.Tensor]): Attention mask |
| | attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices |
| | inputs_embeds (Optional[torch.Tensor]): Precomputed embeddings |
| | use_cache (Optional[bool]): Whether to cache key/value states |
| | past_key_values (Optional[Tuple[Tuple[torch.Tensor]]]): Cached key/value states |
| | output_attentions (Optional[bool]): Whether to output attention weights |
| | output_hidden_states (Optional[bool]): Whether to output all hidden states |
| | return_dict (Optional[bool]): Whether to return dict or tuple |
| | |
| | Returns: |
| | Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: |
| | Various outputs depending on configuration, including: |
| | - last_hidden_state: Final layer hidden states |
| | - past_key_values: Cached key/value states if use_cache=True |
| | - hidden_states: All hidden states if output_hidden_states=True |
| | - attentions: Attention weights if output_attentions=True |
| | - router_loss: MoE router loss if use_moe=True |
| | - gate_logits: MoE gate logits if use_moe=True |
| | """ |
| | output_attentions = ( |
| | output_attentions |
| | if output_attentions is not None |
| | else self.config.output_attentions |
| | ) |
| | output_hidden_states = ( |
| | output_hidden_states |
| | if output_hidden_states is not None |
| | else self.config.output_hidden_states |
| | ) |
| | use_cache = use_cache if use_cache is not None else self.config.use_cache |
| |
|
| | return_dict = ( |
| | return_dict if return_dict is not None else self.config.use_return_dict |
| | ) |
| |
|
| | |
| | if input_ids is not None and inputs_embeds is not None: |
| | raise ValueError( |
| | "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" |
| | ) |
| | elif input_ids is not None: |
| | _, seq_length = input_ids.shape |
| | elif inputs_embeds is not None: |
| | _, seq_length, _ = inputs_embeds.shape |
| | else: |
| | raise ValueError( |
| | "You have to specify either decoder_input_ids or decoder_inputs_embeds" |
| | ) |
| |
|
| | if past_key_values is None: |
| | past_key_values = tuple([None] * len(self.layers)) |
| |
|
| | seq_length_with_past = seq_length |
| | cache_length = 0 |
| | if past_key_values[0] is not None: |
| | cache_length = past_key_values[0][0].shape[1] |
| | seq_length_with_past += cache_length |
| | if inputs_embeds is None: |
| | inputs_embeds = self.embed_tokens(input_ids) |
| |
|
| | inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype) |
| |
|
| | hidden_states = inputs_embeds |
| |
|
| | |
| | all_hidden_states = () if output_hidden_states else None |
| | all_self_attns = () if output_attentions else None |
| | next_decoder_cache = () if use_cache else None |
| | if getattr(self.config, "use_moe", False): |
| | all_router_loss = torch.tensor(0.0).to(device=inputs_embeds.device) |
| | else: |
| | all_router_loss = None |
| | all_gate_logits = () |
| |
|
| | for idx, decoder_layer in enumerate(self.layers): |
| | if output_hidden_states: |
| | all_hidden_states += (hidden_states,) |
| |
|
| | past_key_value = ( |
| | past_key_values[idx] if past_key_values is not None else None |
| | ) |
| | layer_outputs = decoder_layer( |
| | hidden_states, |
| | attention_mask, |
| | attn_mask_start_row_indices, |
| | position_ids, |
| | token_type_ids, |
| | output_attentions, |
| | past_key_value, |
| | use_cache, |
| | ) |
| |
|
| | if isinstance(layer_outputs, (tuple, list)): |
| | hidden_states = layer_outputs[0] |
| | else: |
| | hidden_states = layer_outputs |
| |
|
| | if use_cache: |
| | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) |
| |
|
| | if output_attentions: |
| | all_self_attns += (layer_outputs[1],) |
| | if self.config.use_moe: |
| | layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1] |
| | all_gate_logits = all_gate_logits + (gate_logits,) |
| |
|
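| | # With a KV cache present (incremental decoding), only the last position is needed downstream. |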
| | if past_key_value is not None: |
| | hidden_states = hidden_states[:, -1:, :] |
| |
|
| | hidden_states = self.norm(hidden_states) |
| |
|
| | |
| | if output_hidden_states: |
| | all_hidden_states += (hidden_states,) |
| |
|
| | next_cache = next_decoder_cache if use_cache else None |
| |
|
| | if not return_dict: |
| | return tuple( |
| | v |
| | for v in [ |
| | hidden_states, |
| | next_cache, |
| | all_hidden_states, |
| | all_self_attns, |
| | all_router_loss, |
| | all_gate_logits, |
| | ] |
| | if v is not None |
| | ) |
| |
|
| | |
| | return BaseModelOutputWithPastAndCrossAttentions( |
| | last_hidden_state=hidden_states, |
| | past_key_values=next_cache, |
| | hidden_states=all_hidden_states, |
| | attentions=all_self_attns, |
| | cross_attentions=None, |
| | router_loss=all_router_loss, |
| | gate_logits=all_gate_logits, |
| | ) |
| |
|
| |
|
| | def parallel_matmul( |
| | x, |
| | y, |
| | bias=None, |
| | transpose_y=False, |
| | ): |
| | """ |
| | Performs the matrix multiplication used by the language model head (tensor-parallel handling is not applied in this implementation). |
| | |
| | Args: |
| | x (torch.Tensor): Input tensor with shape [batch_size, seq_len, hidden_size] |
| | y (torch.Tensor): Weight matrix |
| | bias (Optional[torch.Tensor]): Optional bias tensor |
| | transpose_y (bool): Whether to transpose the 'y' matrix before multiplication |
| | # tensor_parallel_degree (int): Degree of tensor model parallelism (default: 1) |
| | # tensor_parallel_output (bool): Whether to keep output in tensor parallel format |
| | or gather across devices (default: True) |
| | # fuse_linear (bool): Whether to use a fused linear operation for optimization |
| | |
| | Returns: |
| | torch.Tensor |
| | |
| | Raises: |
| | AssertionError: If tensor parallel is enabled but weight is not distributed |
| | AttributeError: If called without distributed.launch context |
| | """ |
| | if transpose_y: |
| | logits = torch.matmul(x, y.T) |
| | else: |
| | logits = torch.matmul(x, y) |
| | if bias is not None: |
| | logits += bias |
| | return logits |
| |
|
| |
|
| | def calc_lm_head_logits( |
| | config, hidden_states, weight, bias, tensor_parallel_output=None, training=True |
| | ): |
| | """ |
| | Calculate language model head logits with support for various parallelization strategies. |
| | |
| | This is the core function that computes the final output logits for a language model, |
| | handling sequence parallelism and tensor parallelism configurations. |
| | |
| | Args: |
| | config (Ernie4_5_Config): Model configuration. |
| | hidden_states (Tensor): Hidden states from the transformer layers |
| | weight (Tensor): Weight matrix for the language model head |
| | bias (Tensor): Bias vector for the language model head |
| | tensor_parallel_output (bool, optional): Override for tensor parallel output behavior. |
| | If None, uses config.tensor_parallel_output. |
| | Defaults to None. |
| | training (bool, optional): Whether in training mode. Defaults to True. |
| | |
| | Returns: |
| | Tensor: The computed logits for language modeling. |
| | """ |
| | if tensor_parallel_output is None: |
| | tensor_parallel_output = config.tensor_parallel_output |
| | logits = parallel_matmul( |
| | hidden_states, |
| | weight, |
| | bias=bias, |
| | transpose_y=config.tie_word_embeddings, |
| | ) |
| |
|
| | return logits |
| |
|
| |
|
| | def calc_multimodal_logits( |
| | last_hidden_state: torch.Tensor, |
| | lm_head_weight: torch.Tensor, |
| | lm_head_bias: torch.Tensor, |
| | mm_head_weight: torch.Tensor, |
| | mm_head_bias: torch.Tensor, |
| | token_type_ids_shifted: torch.Tensor, |
| | config: Ernie4_5_VLMoEConfig, |
| | ): |
| | """ |
| | Calculate logits for pure text, multimodal text, and image positions. |
| | Args: |
| | last_hidden_state: Hidden states of the last layer (split along the sequence when sequence parallelism is enabled). |
| | ... |
| | token_type_ids_shifted: Token type ids shifted to the label positions (not sequence-parallel split); |
| | used to select the lm-head that scores each token. |
| | Note: In a sequence that alternates between images and text, the last text token predicts an image id |
| | and vice versa, so the head weight must be chosen by the label's type rather than the input's. |
| | """ |
| | |
| | |
| | assert last_hidden_state.shape[:2] == token_type_ids_shifted.shape, ( |
| | last_hidden_state.shape, |
| | token_type_ids_shifted.shape, |
| | ) |
| | parallel_matmul_tp = partial( |
| | parallel_matmul, |
| | ) |
| |
|
| | if mm_head_weight is None: |
| | if config.use_recompute_loss_fn: |
| | return last_hidden_state, None, None |
| | score_text = parallel_matmul_tp(last_hidden_state, lm_head_weight, lm_head_bias) |
| | return score_text, None, None |
| |
|
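| | # Select label positions by modality: image positions are scored with the multimodal head, |
| | # text positions with the language-model head. |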
| | image_mask_shifted = token_type_ids_shifted == TokenType.image |
| | text_pos_shifted = token_type_ids_shifted == TokenType.text |
| |
|
| | if text_pos_shifted.any().item(): |
| | score_text = parallel_matmul_tp( |
| | last_hidden_state[text_pos_shifted], lm_head_weight, lm_head_bias |
| | ) |
| | else: |
| | score_text = None |
| |
|
| | if mm_head_weight is not None and image_mask_shifted.any().item(): |
| | score_image = parallel_matmul_tp( |
| | last_hidden_state[image_mask_shifted], mm_head_weight, mm_head_bias |
| | ) |
| | else: |
| | score_image = None |
| |
|
| | return score_text, score_image, None |
| |
|
| |
|
| | class Ernie4_5_MoeLMHead(nn.Module): |
| | """Language model head for ERNIE with support for tensor parallelism.""" |
| |
|
| | def __init__(self, config): |
| | """Initialize the language model head. |
| | |
| | Args: |
| | config (Ernie4_5_Config): Model configuration containing: |
| | - vocab_size: Size of vocabulary |
| | - hidden_size: Dimension of hidden states |
| | # - tensor_parallel_degree: Degree of tensor parallelism |
| | - tie_word_embeddings: Whether to tie input/output embeddings |
| | - weight_share_add_bias: Whether to add bias when weight sharing |
| | - use_bias: Whether to use bias term |
| | - use_recompute_loss_fn: Whether to defer logits computation to loss function |
| | - use_sparse_head_and_loss_fn: Whether to use sparse head computation |
| | """ |
| |
|
| | super(Ernie4_5_MoeLMHead, self).__init__() |
| | self.config = config |
| | if config.tensor_parallel_degree > 1: |
| | vocab_size = config.vocab_size // config.tensor_parallel_degree |
| | else: |
| | vocab_size = config.vocab_size |
| |
|
| | if config.tie_word_embeddings: |
| | self.weight = nn.Parameter( |
| | torch.empty( |
| | vocab_size, config.hidden_size, dtype=torch.get_default_dtype() |
| | ) |
| | ) |
| | else: |
| | self.weight = nn.Parameter( |
| | torch.empty( |
| | config.hidden_size, vocab_size, dtype=torch.get_default_dtype() |
| | ) |
| | ) |
| | nn.init.xavier_uniform_(self.weight) |
| |
|
| | logger.info( |
| | f"output-weight:{self.weight.shape} tie_word_embeddings:{config.tie_word_embeddings}" |
| | ) |
| |
|
| | if config.weight_share_add_bias and config.use_bias: |
| | self.bias = nn.Parameter( |
| | torch.zeros(vocab_size, dtype=torch.get_default_dtype()) |
| | ) |
| | else: |
| | self.bias = None |
| |
|
| | |
| | self.weight.is_distributed = vocab_size != config.vocab_size |
| | if config.weight_share_add_bias and config.use_bias: |
| | self.bias.is_distributed = vocab_size != config.vocab_size |
| |
|
| | if self.weight.is_distributed: |
| | self.weight.split_axis = 1 |
| | if ( |
| | config.weight_share_add_bias |
| | and config.use_bias |
| | and self.bias.is_distributed |
| | ): |
| | self.bias.split_axis = 0 |
| |
|
| | if self.config.use_recompute_loss_fn: |
| | logger.info( |
| | "Using recompute_loss_fn, the calculation of logits will be moved into " |
| | "loss_fn for memory optimization" |
| | ) |
| |
|
| | def forward(self, hidden_states, tensor_parallel_output=None): |
| | """Project hidden states to vocabulary logits. |
| | |
| | Args: |
| | hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] |
| | tensor_parallel_output (Optional[bool]): Whether to output parallel results. Defaults to None. |
| | |
| | Returns: |
| | Union[ |
| | Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: |
| | # When use_recompute_loss_fn or use_sparse_head_and_loss_fn |
| | - hidden_states: Original input |
| | - weight: Projection weights |
| | - bias: Optional bias term |
| | Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], bool]: # With tensor_parallel_output |
| | Same as above plus tensor_parallel_output flag |
| | torch.Tensor: # Normal case |
| | Logits tensor of shape [batch_size, seq_len, vocab_size] |
| | ] |
| | """ |
| | return calc_lm_head_logits( |
| | self.config, |
| | hidden_states, |
| | self.weight, |
| | self.bias, |
| | tensor_parallel_output, |
| | training=self.training, |
| | ) |
| |
|
| |
|
| | class Ernie4_5_MoeForCausalLM(Ernie4_5_PretrainedModel, GenerationMixin): |
| | """ERNIE Mixture of Experts (MoE) model for causal language modeling.""" |
| |
|
| | _keys_to_ignore_on_load_missing = [r"lm_head.weight"] |
| |
|
| | def __init__(self, config): |
| | """ |
| | Initializes the ERNIE MoE model for causal language modeling. |
| | |
| | Args: |
| | config (dict): Model configuration. |
| | """ |
| | super().__init__(config) |
| |
|
| | |
| | |
| | new_initializer_range = math.sqrt(0.3333 / config.hidden_size) |
| | logger.info( |
| | f"change initializer-range from {config.initializer_range} to {new_initializer_range}" |
| | ) |
| | config.initializer_range = new_initializer_range |
| | self.config = config |
| | self.model = Ernie4_5_Model(config) |
| | self.lm_head = Ernie4_5_MoeLMHead(config) |
| |
|
| | self.tie_weights() |
| |
|
| | def get_input_embeddings(self): |
| | """Returns the input embeddings layer.""" |
| | return self.model.embed_tokens |
| |
|
| | def set_input_embeddings(self, value): |
| | """Sets the input embeddings layer.""" |
| | self.model.embed_tokens = value |
| |
|
| | def get_output_embeddings(self): |
| | """Returns the output embeddings (LM head).""" |
| | return self.lm_head |
| |
|
| | def set_output_embeddings(self, new_embeddings): |
| | """Sets the output embeddings layer.""" |
| | self.lm_head = new_embeddings |
| |
|
| | def set_decoder(self, decoder): |
| | """Sets the ERNIE decoder model.""" |
| | self.model = decoder |
| |
|
| | def get_decoder(self): |
| | """Get the transformer decoder. |
| | |
| | Returns: |
| | nn.Layer: The decoder module |
| | """ |
| | return self.model |
| |
|
| | |
| | def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False): |
| | """ |
| | Updates model kwargs for generation. |
| | |
| | Args: |
| | outputs (Any): Model outputs. |
| | model_kwargs (dict): Current model kwargs. |
| | is_encoder_decoder (bool): Whether using encoder-decoder architecture. |
| | |
| | Returns: |
| | dict: Updated model kwargs. |
| | """ |
| | |
| | if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], torch.Tensor): |
| | model_kwargs["past_key_values"] = outputs[1] |
| |
|
| | if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs: |
| | model_kwargs["past_key_values"] = outputs.past_key_values |
| |
|
| | |
| | if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None: |
| | token_type_ids = model_kwargs["token_type_ids"] |
| | model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1:]], dim=-1) |
| |
|
| | if not is_encoder_decoder and model_kwargs.get("attention_mask", None) is not None: |
| | |
| | attention_mask = model_kwargs["attention_mask"] |
| | model_kwargs["attention_mask"] = torch.cat( |
| | [ |
| | attention_mask, |
| | torch.ones((attention_mask.shape[0], 1), dtype=torch.int64, device=attention_mask.device), |
| | ], |
| | dim=-1, |
| | ) |
| |
|
| | |
| | if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None: |
| | role_ids = model_kwargs["role_ids"] |
| | model_kwargs["role_ids"] = torch.cat([role_ids, role_ids[:, -1:]], dim=-1) |
| |
|
| | if getattr(self.config, "rope_3d", False): |
| | assert "position_ids" in model_kwargs, "position_ids must be provided if rope_3d is on" |
| | position_ids = model_kwargs["position_ids"] |
| | bsz = position_ids.shape[0] |
| |
|
| | max_position = position_ids.max(dim=1, keepdim=True)[0] |
| | new_positions = max_position + 1 |
| | |
| | model_kwargs["position_ids"] = torch.cat( |
| | [position_ids, new_positions], |
| | dim=1 |
| | ) |
| |
|
| | return model_kwargs |
| |
|
| |
|
| | class VisionMlp(nn.Module): |
| | """VisionMLP""" |
| |
|
| | def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: |
| | super().__init__() |
| | self.fc1 = nn.Linear(dim, hidden_dim) |
| | self.act = ACT2FN[hidden_act] |
| | self.fc2 = nn.Linear(hidden_dim, dim) |
| |
|
| | def forward(self, x) -> torch.Tensor: |
| | """ |
| | Args: |
| | x (torch.Tensor): input tensor |
| | |
| | Returns: |
| | torch.Tensor: VisionMLP output tensor |
| | """ |
| | return self.fc2(self.act(self.fc1(x))) |
| |
|
| |
|
| | class PatchEmbed(nn.Module): |
| | """PatchEmbed""" |
| |
|
| | def __init__( |
| | self, |
| | patch_size: int = 14, |
| | in_channels: int = 3, |
| | embed_dim: int = 1152, |
| | ) -> None: |
| | """ |
| | Args: |
| | patch_size (int, optional): patch size. Defaults to 14. |
| | in_channels (int, optional): number of channels. Defaults to 3. |
| | embed_dim (int, optional): embedding dimension. Defaults to 1152. |
| | """ |
| | super().__init__() |
| | self.patch_size = patch_size |
| | self.in_channels = in_channels |
| | self.embed_dim = embed_dim |
| | self.proj = nn.Linear( |
| | in_channels * patch_size * patch_size, embed_dim, bias=False |
| | ) |
| |
|
| | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Args: |
| | hidden_states (torch.Tensor): hidden states |
| | |
| | Returns: |
| | torch.Tensor: output tensor |
| | """ |
| | target_dtype = self.proj.weight.dtype |
| |
|
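| | # Cast the flattened patches to the projection weight's dtype before the linear projection. |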
| | hidden_states = self.proj(hidden_states.to(target_dtype)) |
| |
|
| | return hidden_states |
| |
|
| |
|
| | class VisionRotaryEmbedding(nn.Module): |
| | """VisionRotaryEmbedding""" |
| |
|
| | def __init__(self, dim: int, theta: float = 10000.0) -> None: |
| | """ |
| | Args: |
| | dim (int): the dimension of each token. |
| | theta (float, optional): the frequency factor. Defaults to 10000.0. |
| | """ |
| | super().__init__() |
| | self.inv_freq = 1.0 / theta ** ( |
| | torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim |
| | ) |
| |
|
| | def forward(self, seqlen: int) -> torch.Tensor: |
| | """ |
| | Args: |
| | seqlen (int): length of sequence. |
| | |
| | Returns: |
| | torch.Tensor: rotary position embedding |
| | """ |
| | seq = torch.arange(seqlen).to(self.inv_freq.dtype) |
| | freqs = torch.outer(input=seq, vec2=self.inv_freq) |
| | return freqs |
| |
|
| |
|
| | def rotate_half(x): |
| | """Rotates half the hidden dims of the input.""" |
| | x1 = x[..., : x.shape[-1] // 2] |
| | x2 = x[..., x.shape[-1] // 2 :] |
| | return torch.cat((-x2, x1), dim=-1) |
| |
|
| |
|
| | def apply_rotary_pos_emb_vision( |
| | tensor: torch.Tensor, freqs: torch.Tensor |
| | ) -> torch.Tensor: |
| | """Applies Rotary Position Embedding to the input tensors. |
| | |
| | Args: |
| | tensor (torch.Tensor): The input tensor. |
| | freqs (torch.Tensor): The frequencies used for the rotation. |
| | Returns: |
| | output (torch.Tensor): the tensor rotated using the Rotary Position Embedding. |
| | """ |
| | orig_dtype = tensor.dtype |
| |
|
| | tensor = tensor.type(dtype=torch.float32) |
| | cos = freqs.cos() |
| | sin = freqs.sin() |
| | cos = cos.unsqueeze(1).tile(1, 1, 2).unsqueeze(0).type(dtype=torch.float32) |
| | sin = sin.unsqueeze(1).tile(1, 1, 2).unsqueeze(0).type(dtype=torch.float32) |
| | output = tensor * cos + rotate_half(tensor) * sin |
| | output = output.to(orig_dtype) |
| | return output |
| |
|
| |
|
| | class VisionAttention(nn.Module): |
| | """VisionAttention""" |
| |
|
| | def __init__(self, dim: int, num_heads: int = 16) -> None: |
| | super().__init__() |
| | self.num_heads = num_heads |
| | self.qkv = nn.Linear(dim, dim * 3, bias=True) |
| | self.proj = nn.Linear(dim, dim) |
| | self.head_dim = dim // num_heads |
| |
|
| | def forward( |
| | self, |
| | hidden_states: torch.Tensor, |
| | cu_seqlens: torch.Tensor, |
| | rotary_pos_emb: Optional[torch.Tensor] = None, |
| | ) -> torch.Tensor: |
| | """forward function for vision attention""" |
| | seq_length = hidden_states.shape[0] |
| | qkv = ( |
| | self.qkv(hidden_states) |
| | .reshape([seq_length, 3, self.num_heads, -1]) |
| | .permute(1, 0, 2, 3) |
| | ) |
| | q, k, v = qkv.unbind(dim=0) |
| |
|
| | q = apply_rotary_pos_emb_vision(q.unsqueeze(dim=0), rotary_pos_emb).squeeze( |
| | dim=0 |
| | ) |
| | k = apply_rotary_pos_emb_vision(k.unsqueeze(dim=0), rotary_pos_emb).squeeze( |
| | dim=0 |
| | ) |
| | |
| | q = q.transpose(0, 1) |
| | k = k.transpose(0, 1) |
| | v = v.transpose(0, 1) |
| | |
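| | # Split q/k/v along the sequence dimension into per-sample segments given by cu_seqlens |
| | # and run attention independently within each segment. |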
| | lengths = cu_seqlens[1:] - cu_seqlens[:-1] |
| | splits = [ |
| | torch.split(tensor, lengths.tolist(), dim=1) for tensor in (q, k, v) |
| | ] |
| | |
| | attn_output = [] |
| | for q, k, v in zip(*splits): |
| | attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) |
| | attn_weights = nn.functional.softmax( |
| | attn_weights, dim=-1, dtype=torch.float32 |
| | ).to(q.dtype) |
| | attn_output_splited = torch.matmul(attn_weights, v) |
| | attn_output_splited = attn_output_splited.transpose(0, 1) |
| | attn_output.append(attn_output_splited) |
| | attn_output = torch.cat(attn_output, dim=0) |
| | attn_output = attn_output.reshape(seq_length, -1).contiguous() |
| | attn_output = self.proj(attn_output) |
| | return attn_output |
| |
|
| |
|
| | class DFNRopeVisionBlock(nn.Module): |
| | """DFNRopeVisionBlock""" |
| |
|
| | def __init__(self, config, attn_implementation: str = "sdpa") -> None: |
| | """ |
| | Args: |
| | config (dict): model configuration. |
| | attn_implementation (str, optional): attention implementation. Defaults to "sdpa". |
| | """ |
| | super().__init__() |
| | self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6) |
| | self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6) |
| | mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) |
| |
|
| | self.attn = VisionAttention(config.embed_dim, num_heads=config.num_heads) |
| | self.mlp = VisionMlp( |
| | dim=config.embed_dim, |
| | hidden_dim=mlp_hidden_dim, |
| | hidden_act=config.hidden_act, |
| | ) |
| | self.config = config |
| |
|
| | def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: |
| | """ |
| | Args: |
| | hidden_states(torch.Tensor): hidden states |
| | cu_seqlens (torch.Tensor): cumulative sequence lengths |
| | rotary_pos_emb: rotary position embedding |
| | |
| | Returns: |
| | torch.Tensor: output tensor |
| | """ |
| | hidden_states = hidden_states + self.attn( |
| | self.norm1(hidden_states), |
| | cu_seqlens=cu_seqlens, |
| | rotary_pos_emb=rotary_pos_emb, |
| | ) |
| | hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) |
| | return hidden_states |
| |
|
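| | # Illustrative sketch (not part of the model): the block is a standard pre-norm |
| | # residual pair, packed-sequence attention followed by an MLP. The stand-in |
| | # config object and sizes below are assumptions for demonstration only. |
| | def _example_vision_block(): |
| | from types import SimpleNamespace |
| | cfg = SimpleNamespace(embed_dim=32, num_heads=4, mlp_ratio=4, hidden_act="gelu") |
| | block = DFNRopeVisionBlock(cfg) |
| | hidden_states = torch.randn(10, 32) |
| | cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32) |
| | rotary_pos_emb = torch.zeros(10, 32 // 4 // 2) |
| | out = block(hidden_states, cu_seqlens, rotary_pos_emb) |
| | assert out.shape == (10, 32) |
| | return out |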
| |
|
| | class DFNRopeVisionTransformerPreTrainedModel(PreTrainedModel): |
| | """DFNRopeVisionTransformerPreTrainedModel""" |
| |
|
| | config_class = DFNRopeVisionTransformerConfig |
| | _tp_plan = {} |
| |
|
| | def __init__(self, config) -> None: |
| | """ |
| | Args: |
| | config (dict): model configuration |
| | """ |
| | super().__init__(config) |
| | self.spatial_merge_size = config.spatial_merge_size |
| |
|
| | self.patch_embed = PatchEmbed( |
| | patch_size=config.patch_size, |
| | in_channels=config.in_channels, |
| | embed_dim=config.embed_dim, |
| | ) |
| |
|
| | head_dim = config.embed_dim // config.num_heads |
| | self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) |
| |
|
| | self.blocks = nn.ModuleList( |
| | [DFNRopeVisionBlock(config) for _ in range(config.depth)] |
| | ) |
| |
|
| | assert ( |
| | config.hidden_size == config.embed_dim |
| | ), "in DFNRope, vit's config.hidden must be equal to config.embed_dim" |
| | self.ln = nn.LayerNorm(config.hidden_size, eps=1e-6) |
| |
|
| | def rot_pos_emb(self, grid_thw, num_pad=0): |
| | """rot_pos_emb |
| | |
| | Args: |
| | grid_thw (torch.Tensor): grid thw of input |
| | num_pad (int): number of trailing padding positions. Defaults to 0. |
| | |
| | Returns: |
| | torch.Tensor: rotary position embedding |
| | """ |
| | pos_ids = [] |
| | grid_hw_array = np.array(grid_thw.cpu(), dtype=np.int64) |
| | for t, h, w in grid_hw_array: |
| | hpos_ids = np.arange(h).reshape([-1, 1]) |
| | hpos_ids = np.tile(hpos_ids, (1, w)) |
| | hpos_ids = hpos_ids.reshape( |
| | h // self.spatial_merge_size, |
| | self.spatial_merge_size, |
| | w // self.spatial_merge_size, |
| | self.spatial_merge_size, |
| | ) |
| | hpos_ids = np.transpose(hpos_ids, (0, 2, 1, 3)) |
| | hpos_ids = hpos_ids.flatten() |
| |
|
| | wpos_ids = np.arange(w).reshape([1, -1]) |
| | wpos_ids = np.tile(wpos_ids, (h, 1)) |
| | wpos_ids = wpos_ids.reshape( |
| | h // self.spatial_merge_size, |
| | self.spatial_merge_size, |
| | w // self.spatial_merge_size, |
| | self.spatial_merge_size, |
| | ) |
| | wpos_ids = np.transpose(wpos_ids, (0, 2, 1, 3)) |
| | wpos_ids = wpos_ids.flatten() |
| |
|
| | stacked_ids = np.stack([hpos_ids, wpos_ids], axis=-1) |
| | tiled_ids = np.tile(stacked_ids, (t, 1)) |
| | pos_ids.append(tiled_ids) |
| |
|
| | pos_ids = np.concatenate(pos_ids, axis=0) |
| | if num_pad > 0: |
| | pos_ids = np.concatenate( |
| | [pos_ids, np.zeros((num_pad, 2), dtype=pos_ids.dtype)] |
| | ) |
| | max_grid_size = np.amax(grid_hw_array[:, 1:]) |
| | rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) |
| | rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_dim=1) |
| | return rotary_pos_emb |
| |
|
| | def forward( |
| | self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0 |
| | ) -> torch.Tensor: |
| | """ |
| | Args: |
| | hidden_states (torch.Tensor): input tensor |
| | grid_thw (torch.Tensor): grid thw of input |
| | num_pad (int): number of padding tokens |
| | |
| | Returns: |
| | torch.Tensor: output tensor |
| | """ |
| | hidden_states = self.patch_embed(hidden_states) |
| |
|
| | rotary_pos_emb = self.rot_pos_emb(grid_thw, num_pad=num_pad) |
| | rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) |
| |
|
| | cu_seqlens = torch.repeat_interleave( |
| | grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] |
| | ).cumsum(dim=0, dtype=torch.int32) |
| |
|
| | if num_pad > 0: |
| | cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) |
| | cu_seqlens[-1] = cu_seqlens[-2] + num_pad |
| | else: |
| | cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) |
| |
|
| | for idx, blk in enumerate(self.blocks): |
| | hidden_states = blk( |
| | hidden_states, |
| | cu_seqlens=cu_seqlens, |
| | rotary_pos_emb=rotary_pos_emb, |
| | ) |
| |
|
| | ret = self.ln(hidden_states) |
| | return ret |
| |
|
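| | # Worked example (illustrative only): for one 1x4x4 image and one 2x4x2 video, |
| | # every frame of h*w patches becomes its own attention segment, so the cumulative |
| | # lengths built in `DFNRopeVisionTransformerPreTrainedModel.forward` above are [0, 16, 24, 32]. |
| | def _example_cu_seqlens(): |
| | grid_thw = torch.tensor([[1, 4, 4], [2, 4, 2]]) |
| | cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32) |
| | cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) |
| | assert cu_seqlens.tolist() == [0, 16, 24, 32] |
| | return cu_seqlens |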
| |
|
| | class VariableResolutionResamplerModel(nn.Module): |
| | """ |
| | Feature resampler that supports variable-resolution image and video inputs. |
| | """ |
| |
|
| | def __init__(self, in_dim, out_dim, spatial_conv_size, temporal_conv_size, config): |
| | super().__init__() |
| | self.in_dim = in_dim |
| | self.out_dim = out_dim |
| | self.config = config |
| | self.spatial_conv_size = spatial_conv_size |
| | self.temporal_conv_size = temporal_conv_size |
| | self.use_temporal_conv = config.use_temporal_conv |
| |
|
| | |
| | self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size |
| | |
| | self.temporal_dim = ( |
| | self.in_dim |
| | * self.spatial_conv_size |
| | * self.spatial_conv_size |
| | * self.temporal_conv_size |
| | ) |
| |
|
| | |
| | with UniqueNameGuard("mm_resampler_") as guard: |
| |
|
| | self.spatial_linear = nn.Sequential( |
| | nn.Linear(self.spatial_dim, self.spatial_dim), |
| | nn.GELU(), |
| | nn.Linear(self.spatial_dim, self.spatial_dim), |
| | nn.LayerNorm(self.spatial_dim, eps=1e-6), |
| | ) |
| |
|
| | if self.use_temporal_conv: |
| | self.temporal_linear = nn.Sequential( |
| | nn.Linear(self.temporal_dim, self.spatial_dim), |
| | nn.GELU(), |
| | nn.Linear(self.spatial_dim, self.spatial_dim), |
| | nn.LayerNorm(self.spatial_dim, eps=1e-6), |
| | ) |
| |
|
| | self.mlp = nn.Linear(self.spatial_dim, self.out_dim) |
| |
|
| | out_config = deepcopy(config) |
| | out_config.hidden_size = out_dim |
| | self.after_norm = RMSNorm(out_config) |
| |
|
| | def spatial_conv_reshape(self, x, spatial_conv_size): |
| | """ |
| | Reshape [S, C] into [S / spatial_conv_size**2, C * spatial_conv_size**2] so the |
| | following linear layers imitate a spatial convolution. |
| | """ |
| | S, C = x.shape |
| | x = x.reshape([-1, C * (spatial_conv_size**2)]) |
| | return x |
| |
|
| | def forward(self, x, image_mask, token_type_ids, image_type_ids, grid_thw): |
| | """ |
| | x: image_features |
| | image_mask: [B] |
| | token_type_ids: [B] |
| | image_type_ids: [B_image] |
| | grid_thw: [B_image, 3] |
| | """ |
| | assert image_type_ids is not None |
| |
|
| | def fwd_spatial(x): |
| | """ |
| | x has shape [S, H]. |
| | S is ordered as [[patch_h * patch_w (row-major traversal)] * patch_time]. |
| | H is the hidden size. |
| | """ |
| | x = self.spatial_conv_reshape(x, self.spatial_conv_size) |
| |
|
| | x = self.spatial_linear(x) |
| |
|
| | return x |
| |
|
| | def fwd_placeholder(x, grid_thw, to_tensor=False): |
| | """ |
| | x: [S, H] |
| | grid_thw: [B_image, 3] |
| | the second dimension holds [t, h, w] |
| | """ |
| |
|
| | grid_thw_cpu = grid_thw.cpu().numpy() |
| | grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] |
| | grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size**2) |
| |
|
| | tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // (self.spatial_conv_size**2) |
| | batch_offset = np.empty( |
| | tokens_per_img_or_vid.size, dtype=tokens_per_img_or_vid.dtype |
| | ) |
| | batch_offset[0] = 0 |
| | batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] |
| |
|
| | assert ( |
| | self.temporal_conv_size == 2 |
| | ), f"Hard Code: temporal_conv_size==2, got:{self.temporal_conv_size}" |
| |
|
| | |
| | slice_offsets = [] |
| | for temporal_size, spatial_size, b_offset in zip( |
| | grid_t, grid_hw_after_conv, batch_offset |
| | ): |
| | for temp_offset in range(0, temporal_size, 2): |
| | slice_offsets.append( |
| | np.arange( |
| | b_offset + (temp_offset) * spatial_size, |
| | b_offset + (temp_offset + 1) * spatial_size, |
| | ) |
| | ) |
| | slice_offsets = torch.tensor(np.concatenate(slice_offsets, axis=-1)).to( |
| | x.device |
| | ) |
| |
|
| | slice_offsets2 = [] |
| | for temporal_size, spatial_size, b_offset in zip( |
| | grid_t, grid_hw_after_conv, batch_offset |
| | ): |
| | for temp_offset in range( |
| | 1 if temporal_size > 1 else 0, temporal_size, 2 |
| | ): |
| | slice_offsets2.append( |
| | np.arange( |
| | b_offset + (temp_offset) * spatial_size, |
| | b_offset + (temp_offset + 1) * spatial_size, |
| | ) |
| | ) |
| | slice_offsets2 = torch.tensor(np.concatenate(slice_offsets2, axis=-1)).to( |
| | x.device |
| | ) |
| |
|
| | x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) |
| | x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) |
| | x = torch.concat([x_timestep_1, x_timestep_2], dim=-1) |
| | return x |
| |
|
| | def fwd_temporal(x): |
| | x = self.temporal_linear(x) |
| | return x |
| |
|
| | def fwd_mlp(x): |
| | x = self.mlp(x) |
| | x = self.after_norm(x) |
| | return x |
| |
|
| | x = fwd_spatial(x) |
| | if self.use_temporal_conv: |
| | x = fwd_placeholder(x, grid_thw) |
| | x = fwd_temporal(x) |
| | x = fwd_mlp(x) |
| | return x |
| |
|
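| | # Worked example (illustrative only): the spatial "conv" above is a reshape that |
| | # packs spatial_conv_size**2 consecutive patch features into one token, and the |
| | # temporal step concatenates adjacent frames. For a 2x4x4 video with both conv |
| | # sizes equal to 2, 32 ViT patches become 8 spatial tokens, then 4 wider tokens. |
| | def _example_resampler_token_counts(): |
| | in_dim, spatial_conv_size = 16, 2 |
| | t, h, w = 2, 4, 4 |
| | x = torch.randn(t * h * w, in_dim)  # [32, 16] ViT patch features |
| | x = x.reshape(-1, in_dim * spatial_conv_size**2)  # [8, 64] spatial merge |
| | x = torch.cat([x[: x.shape[0] // 2], x[x.shape[0] // 2 :]], dim=-1)  # [4, 128] frame pairing |
| | assert x.shape == (4, 128) |
| | return x |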
| |
|
| | class Ernie4_5_MoeVLHead(Ernie4_5_MoeLMHead): |
| | """Ernie4_5_MoeVLHead""" |
| |
|
| | def __init__(self, config): |
| | super().__init__(config) |
| | self.config = config |
| | if config.mm_vocab_size > 0: |
| | mm_vocab_config = deepcopy(config) |
| | mm_vocab_config.vocab_size = config.mm_vocab_size |
| | assert mm_vocab_config.vocab_size > 0, mm_vocab_config |
| | assert ( |
| | mm_vocab_config.im_patch_id >= mm_vocab_config.max_text_id |
| | ), mm_vocab_config |
| | self.mm_head = Ernie4_5_MoeLMHead(mm_vocab_config) |
| | else: |
| | self.mm_head = None |
| |
|
| | def forward(self, hidden_state, token_type_ids_labels, use_cache=False): |
| | """ |
| | Args: |
| | hidden_state(torch.Tensor): hidden state |
| | token_type_ids_labels(torch.Tensor): token type ids aligned with the next-token labels |
| | use_cache(bool): whether to use cache, default is False |
| | |
| | Returns: |
| | logits_text(torch.Tensor): text logits |
| | logits_image(torch.Tensor): image logits |
| | """ |
| | if not use_cache: |
| | mm_head_weight = self.mm_head.weight if self.mm_head is not None else None |
| | mm_head_bias = self.mm_head.bias if self.mm_head is not None else None |
| | logits_text, logits_image, _ = calc_multimodal_logits( |
| | hidden_state, |
| | self.weight, |
| | self.bias, |
| | mm_head_weight, |
| | mm_head_bias, |
| | token_type_ids_labels, |
| | self.config, |
| | ) |
| | return logits_text, logits_image, None |
| | else: |
| | |
| | return ( |
| | parallel_matmul( |
| | hidden_state[:, -1:, :], |
| | self.weight, |
| | self.bias, |
| | transpose_y=self.config.tie_word_embeddings, |
| | ), |
| | None, |
| | None, |
| | ) |
| |
|
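| | # Illustrative sketch (assumptions only; the real work happens in |
| | # `calc_multimodal_logits`, defined elsewhere in this file): text and image |
| | # positions are scored by separate heads, selected via the token-type labels. |
| | def _example_split_by_token_type(): |
| | hidden_state = torch.randn(1, 6, 8) |
| | token_type_ids_labels = torch.tensor([[0, 0, 1, 1, 0, 0]]) |
| | text_positions = hidden_state[token_type_ids_labels == TokenType.text]  # [4, 8] |
| | image_positions = hidden_state[token_type_ids_labels == TokenType.image]  # [2, 8] |
| | return text_positions, image_positions |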
| |
|
| | class Ernie4_5_VLMoeForConditionalGeneration(Ernie4_5_MoeForCausalLM): |
| | """Ernie4_5_VLMoeForConditionalGeneration""" |
| |
|
| | config_class = Ernie4_5_VLMoEConfig |
| | main_input_name = "pixel_values" |
| | _keep_in_fp16_modules = ["vision_model"] |
| | _tp_plan = {} |
| |
|
| | def __init__( |
| | self, config: Ernie4_5_VLMoEConfig, vision_model=None, resampler_model=None |
| | ): |
| | """ |
| | initialize Ernie4_5_VLMoeForConditionalGeneration |
| | |
| | Args: |
| | config(Ernie4_5_VLMoEConfig): Model configuration. |
| | vision_model(nn.Module): vision model |
| | resampler_model(nn.Module): resampler model |
| | """ |
| | super().__init__(config) |
| |
|
| | self.vision_model = DFNRopeVisionTransformerPreTrainedModel( |
| | config.vision_config |
| | ) |
| |
|
| | self.model.resampler_model = VariableResolutionResamplerModel( |
| | config.pixel_hidden_size, |
| | config.hidden_size, |
| | config.spatial_conv_size, |
| | config.temporal_conv_size, |
| | config=config, |
| | ) |
| |
|
| | self.image_preprocess = None |
| | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| |
|
| | self.post_init() |
| |
|
| | def add_image_preprocess(self, processor): |
| | """add image preprocess""" |
| | logger.info("image preprocess is set") |
| |
|
| | image_preprocess = processor.image_processor |
| | image_preprocess.image_mean_tensor = torch.tensor( |
| | image_preprocess.image_mean, dtype=torch.float32 |
| | ).reshape([1, 3, 1, 1]) |
| | image_preprocess.image_std_tensor = torch.tensor( |
| | image_preprocess.image_std, dtype=torch.float32 |
| | ).reshape([1, 3, 1, 1]) |
| | image_preprocess.rescale_factor = torch.tensor( |
| | image_preprocess.rescale_factor, dtype=torch.float32 |
| | ) |
| | image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze( |
| | [-2, -1] |
| | ).repeat_interleave(self.config.vision_config.patch_size**2 * 1, -1) |
| | image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze( |
| | [-2, -1] |
| | ).repeat_interleave(self.config.vision_config.patch_size**2 * 1, -1) |
| |
|
| | self.image_preprocess = image_preprocess |
| |
|
| | def vision_forward( |
| | self, |
| | images, |
| | image_position_ids, |
| | image_attention_mask, |
| | grid_thw, |
| | ): |
| | """vision_forward""" |
| | if self.image_preprocess is not None: |
| | assert images.dtype == torch.uint8, images.dtype |
| | current_device = images.device |
| | self.image_preprocess.image_mean_tensor = ( |
| | self.image_preprocess.image_mean_tensor.to(current_device) |
| | ) |
| | self.image_preprocess.image_std_tensor = ( |
| | self.image_preprocess.image_std_tensor.to(current_device) |
| | ) |
| | images = self.image_preprocess.rescale_factor * images.to(torch.float32) |
| | images = ( |
| | images - self.image_preprocess.image_mean_tensor |
| | ) / self.image_preprocess.image_std_tensor |
| | images = images.to(torch.bfloat16) |
| | else: |
| | assert images.dtype == torch.bfloat16, images.dtype |
| | |
| | if grid_thw is not None: |
| | grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3]) |
| | grid_thw = F.pad( |
| | torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), |
| | [1, 0, 0, 0], |
| | value=1, |
| | ) |
| | image_features = self.vision_model(images, grid_thw) |
| | return image_features |
| |
|
| | def vision_mapping_forward( |
| | self, |
| | token_type_ids, |
| | token_type_ids_w_video, |
| | input_ids, |
| | mm_input_ids, |
| | image_features, |
| | inputs_embeds, |
| | image_type_ids, |
| | grid_thw, |
| | ): |
| | """vision_mapping_forward""" |
| | image_mask = input_ids == self.config.im_patch_id |
| | image_features = self.model.resampler_model( |
| | image_features, |
| | image_mask, |
| | token_type_ids_w_video, |
| | image_type_ids, |
| | grid_thw, |
| | ) |
| |
|
| | # flatten a batched [B, N, C] feature tensor before scattering into the embeddings |
| | if image_features.dim() == 3: |
| | B, N, C = image_features.shape |
| | image_features = image_features.reshape([B * N, C]).to(inputs_embeds.dtype) |
| | |
| | inputs_embeds[image_mask.to(inputs_embeds.device)] = image_features.to( |
| | inputs_embeds.device |
| | ) |
| | return inputs_embeds |
| |
|
| | def prepare_inputs_for_generation( |
| | self, |
| | input_ids, |
| | images=None, |
| | use_cache=False, |
| | past_key_values=None, |
| | inputs_embeds=None, |
| | image_position_ids=None, |
| | image_attention_mask=None, |
| | token_type_ids=None, |
| | image_type_ids=None, |
| | grid_thw=None, |
| | **kwargs, |
| | ): |
| | """ |
| | Prepare inputs for the decoder that can be used for generation. |
| | |
| | Args: |
| | input_ids (torch.Tensor): Input ids. |
| | images (torch.Tensor): Images. Default to None. |
| | use_cache (bool): Whether to use cache. Default to False. |
| | past_key_values (list): Past key values. Default to None. |
| | inputs_embeds (torch.Tensor): Input embeddings. Default to None. |
| | image_position_ids (torch.Tensor): Image position ids. Default to None. |
| | image_attention_mask (torch.Tensor): Image attention mask. Default to None. |
| | token_type_ids (torch.Tensor): Token type ids. Default to None. |
| | image_type_ids (torch.Tensor): Image type ids. Default to None. |
| | grid_thw (torch.Tensor): Grid thw. Default to None. |
| | """ |
| | if past_key_values: |
| | input_ids = input_ids[:, -1:] |
| | token_type_ids = token_type_ids[:, -1:] |
| | image_type_ids = ( |
| | image_type_ids[:, -1:] if image_type_ids is not None else None |
| | ) |
| |
|
| | if self.config.use_flash_attention: |
| | attention_mask = None |
| | else: |
| | attention_mask = kwargs.get("attention_mask", None) |
| |
|
| | |
| | if inputs_embeds is not None and past_key_values is None: |
| | model_inputs = {"inputs_embeds": inputs_embeds} |
| | else: |
| | model_inputs = {"input_ids": input_ids} |
| |
|
| | model_inputs.update( |
| | { |
| | "past_key_values": past_key_values, |
| | "use_cache": True, |
| | "attention_mask": attention_mask, |
| | "images": images, |
| | "image_position_ids": image_position_ids, |
| | "image_attention_mask": image_attention_mask, |
| | "image_type_ids": image_type_ids, |
| | "token_type_ids": torch.cat( |
| | [ |
| | token_type_ids, |
| | torch.zeros( |
| | [len(token_type_ids), 1], dtype=token_type_ids.dtype |
| | ).to(token_type_ids.device), |
| | ], |
| | dim=-1, |
| | ), |
| | "grid_thw": grid_thw, |
| | } |
| | ) |
| | if self.config.rope_3d: |
| | model_inputs.update({"position_ids": kwargs["position_ids"]}) |
| |
|
| | return model_inputs |
| |
|
| | def _post_init(self, original_init, *args, **kwargs): |
| | """ |
| | Label the multimodal parameters of the model (only the head and embedding here); |
| | expert parameters are already labeled. |
| | """ |
| | super()._post_init(original_init, *args, **kwargs) |
| | if self.lm_head.mm_head is not None: |
| | self.lm_head.mm_head.weight.expert_type = "expert_type_1" |
| | if getattr(self.lm_head.mm_head, "bias", None) is not None: |
| | self.lm_head.mm_head.bias.expert_type = "expert_type_1" |
| |
|
| | def forward( |
| | self, |
| | input_ids: torch.Tensor, |
| | position_ids: Optional[torch.Tensor] = None, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | past_key_values: Optional[List[torch.Tensor]] = None, |
| | use_cache: Optional[bool] = None, |
| | output_attentions: Optional[bool] = None, |
| | output_hidden_states: Optional[bool] = None, |
| | labels: Optional[torch.Tensor] = None, |
| | images: Optional[torch.Tensor] = None, |
| | ignored_index: Optional[int] = 0, |
| | return_dict: Optional[bool] = None, |
| | image_position_ids: Optional[torch.Tensor] = None, |
| | image_attention_mask: Optional[torch.Tensor] = None, |
| | token_type_ids: Optional[torch.Tensor] = None, |
| | image_type_ids: Optional[torch.Tensor] = None, |
| | grid_thw: Optional[torch.Tensor] = None, |
| | **kwargs, |
| | ): |
| | """ |
| | Forward for Ernie4_5_VLMoeForConditionalGeneration |
| | |
| | Args: |
| | input_ids (torch.Tensor): Input ids. |
| | position_ids (Optional[torch.Tensor], optional): Position ids. Defaults to None. |
| | attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None. |
| | past_key_values (Optional[List[torch.Tensor]], optional): Past key values. Defaults to None. |
| | use_cache (Optional[bool], optional): Use cache. Defaults to None. |
| | output_attentions (Optional[bool], optional): Output attentions. Defaults to None. |
| | output_hidden_states (Optional[bool], optional): Output hidden states. Defaults to None. |
| | labels (Optional[torch.Tensor], optional): Labels. Defaults to None. |
| | images (Optional[torch.Tensor]): Images. Defaults to None. |
| | ignored_index (Optional[int], optional): Ignored index. Defaults to 0. |
| | return_dict (Optional[bool], optional): Return dict. Defaults to None. |
| | image_position_ids (Optional[torch.Tensor], optional): Image position ids. Defaults to None. |
| | image_attention_mask (Optional[torch.Tensor], optional): Image attention mask. Defaults to None. |
| | token_type_ids (Optional[torch.Tensor], optional): Token type ids. Defaults to None. |
| | image_type_ids (Optional[torch.Tensor], optional): Image type ids. Defaults to None. |
| | grid_thw (Optional[torch.Tensor], optional): Grid thw. Defaults to None. |
| | """ |
| | if grid_thw is not None: |
| | grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3]) |
| | return_dict = ( |
| | return_dict if return_dict is not None else self.config.use_return_dict |
| | ) |
| |
|
| | image_mask = input_ids == self.config.im_patch_id |
| |
|
| | image_rate = image_mask.to(torch.float32).mean() |
| |
|
| | if past_key_values is None: |
| | if images is not None: |
| | assert (image_mask).any().item(), ( |
| | image_mask.detach().cpu().numpy().tolist(), |
| | input_ids.detach().cpu().numpy().tolist(), |
| | self.config.im_patch_id, |
| | images.shape, |
| | ) |
| | image_features = self.vision_forward( |
| | images, |
| | image_position_ids, |
| | image_attention_mask, |
| | grid_thw, |
| | ) |
| | else: |
| | image_features = None |
| | else: |
| | image_features = None |
| | if token_type_ids is None: |
| | token_type_ids = image_mask.to(torch.int64) |
| | token_type_ids_labels = torch.cat( |
| | [token_type_ids[:, 1:], token_type_ids[:, -1:]], 1 |
| | ) |
| | else: |
| | assert ( |
| | token_type_ids.shape[1] == input_ids.shape[1] + 1 |
| | ), f"token_type:{token_type_ids.shape}, ids:{input_ids.shape}" |
| | token_type_ids_labels = token_type_ids[..., 1:] |
| |
|
| | lm_input_ids = input_ids.clone() |
| | mm_input_ids = input_ids.clone() |
| |
|
| | inputs_embeds = self.model.embed_tokens(lm_input_ids) |
| | token_type_ids_w_video = token_type_ids[..., :-1].clone() |
| | token_type_ids[token_type_ids == TokenType.video] = TokenType.image |
| |
|
| | if images is not None and image_features is not None: |
| | inputs_embeds = self.vision_mapping_forward( |
| | token_type_ids[..., :-1], |
| | token_type_ids_w_video, |
| | input_ids, |
| | mm_input_ids, |
| | image_features, |
| | inputs_embeds, |
| | image_type_ids, |
| | grid_thw, |
| | ) |
| | else: |
| | pass |
| |
|
| | outputs = self.model( |
| | position_ids=position_ids, |
| | attention_mask=attention_mask, |
| | token_type_ids=token_type_ids, |
| | inputs_embeds=inputs_embeds, |
| | use_cache=use_cache, |
| | past_key_values=past_key_values, |
| | output_attentions=output_attentions, |
| | output_hidden_states=output_hidden_states, |
| | return_dict=True, |
| | ) |
| |
|
| | if not use_cache: |
| | assert outputs.last_hidden_state.shape[:2] == token_type_ids_labels.shape, ( |
| | outputs.last_hidden_state.shape, |
| | token_type_ids_labels.shape, |
| | ) |
| | if self.config.use_recompute_loss_fn: |
| | logits = outputs.last_hidden_state |
| | else: |
| | logits = self.lm_head(outputs.last_hidden_state) |
| | else: |
| | logits = self.lm_head(outputs.last_hidden_state[:, -1:, :]) |
| |
|
| | router_loss = outputs.router_loss |
| |
|
| | |
| | loss = None |
| | return CausalLMOutputWithCrossAttentions( |
| | loss=loss, |
| | logits=logits, |
| | past_key_values=outputs.past_key_values, |
| | hidden_states=outputs.hidden_states, |
| | attentions=outputs.attentions, |
| | router_loss=outputs.router_loss, |
| | ) |
| |
|
| | @staticmethod |
| | def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False): |
| | """_resolve_prefix_keys""" |
| | |
| | state_keys_map = {} |
| |
|
| | state_keys_base = set(state_keys_base) |
| | state_keys_real = set(state_keys_real) |
| |
|
| | for key in state_keys_base: |
| | for x in state_keys_real: |
| | if "mm_embed_tokens" in x: |
| | if "mm_embed_tokens" in key: |
| | state_keys_map[key] = x |
| | break |
| | elif x.endswith(key): |
| | state_keys_map[key] = x |
| | break |
| | if key not in state_keys_map: |
| | if not ignore_error: |
| | logger.error(f"could not find name {key} in loaded state dict!") |
| | else: |
| | state_keys_real.remove(state_keys_map[key]) |
| |
|
| | return state_keys_map |
| |
|
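| | # Worked example (illustrative only): `vision_forward` flattens every video entry |
| | # of grid_thw into one row per frame before calling the vision tower, e.g. a |
| | # 2-frame 4x4 video [[2, 4, 4]] becomes [[1, 4, 4], [1, 4, 4]]. |
| | def _example_expand_grid_thw(): |
| | grid_thw = torch.tensor([[2, 4, 4]]) |
| | grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3]) |
| | grid_thw = F.pad(torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), [1, 0, 0, 0], value=1) |
| | assert grid_thw.tolist() == [[1, 4, 4], [1, 4, 4]] |
| | return grid_thw |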
| |
|
| | @dataclass |
| | class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): |
| | """ |
| | Base class for model outputs with past key values and cross attention layers, |
| | with additional support for router components in mixture-of-experts models. |
| | |
| | This extends the base model output to include: |
| | 1. Router-related outputs for expert selection |
| | 2. Maintains all existing functionality from the parent class |
| | """ |
| |
|
| | last_hidden_state: Optional[Tuple[torch.Tensor]] = None |
| | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None |
| | hidden_states: Optional[Tuple[torch.Tensor]] = None |
| | attentions: Optional[Tuple[torch.Tensor]] = None |
| | cross_attentions: Optional[Tuple[torch.Tensor]] = None |
| | router_loss: Optional[torch.Tensor] = None |
| | gate_logits: Optional[Tuple[torch.Tensor]] = None |
| |
|
| |
|
| | @dataclass |
| | class CausalLMOutputWithCrossAttentions(ModelOutput): |
| | """ |
| | Base class for causal language model (or autoregressive) outputs. |
| | |
| | Args: |
| | loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): |
| | Language modeling loss (for next-token prediction). |
| | logits (`torch.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): |
| | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). |
| | hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` |
| | is passed or when `config.output_hidden_states=True`): |
| | Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + |
| | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. |
| | |
| | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. |
| | attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or |
| | when `config.output_attentions=True`): |
| | Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
| | sequence_length)`. |
| | |
| | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
| | heads. |
| | router_loss (Optional[torch.Tensor]): |
| | The routing loss computed by the gating network in mixture-of-experts models. |
| | This is typically the load balancing loss that encourages equal expert utilization. |
| | None when not using mixture-of-experts routing. |
| | """ |
| |
|
| | loss: Optional[torch.Tensor] = None |
| | logits: torch.Tensor = None |
| | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None |
| | hidden_states: Optional[Tuple[torch.Tensor]] = None |
| | attentions: Optional[Tuple[torch.Tensor]] = None |
| | router_loss: Optional[Tuple[torch.Tensor]] = None |
| |
|