| | from dataclasses import dataclass |
| | from typing import Tuple, Optional |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| | from transformers.file_utils import ModelOutput |
| | from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel, Wav2Vec2Model |
| |
|
| |
|
@dataclass
class SpeechClassifierOutput(ModelOutput):
    """Output container returned by ``Wav2Vec2ForSpeechClassification.forward``.

    Every field defaults to ``None``; the model fills in whichever values it
    computed for the current call.
    """

    # Loss value; the model only sets it when ``labels`` were passed to ``forward``.
    loss: Optional[torch.FloatTensor] = None
    # Raw, un-normalised output of the classification head.
    logits: Optional[torch.FloatTensor] = None
    # Copied from the wav2vec2 encoder output's ``hidden_states`` attribute.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Copied from the wav2vec2 encoder output's ``attentions`` attribute.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
| |
|
| |
|
class Wav2Vec2ClassificationHead(nn.Module):
    """Two-layer projection head mapping a feature vector to label logits.

    The pipeline is ``dropout -> dense -> tanh -> dropout -> out_proj``.

    Args:
        config: Object exposing ``hidden_size`` (input and intermediate width),
            ``final_dropout`` (dropout probability used at both dropout
            positions) and ``num_labels`` (output width).
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Project ``features`` to logits.

        Args:
            features: Tensor whose last dimension equals ``config.hidden_size``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor with the same leading dimensions as ``features`` and a last
            dimension of ``config.num_labels``.
        """
        x = self.dropout(features)
        x = torch.tanh(self.dense(x))
        # The same dropout module is applied a second time after the non-linearity.
        x = self.dropout(x)
        return self.out_proj(x)
| |
|
| |
|
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 encoder followed by time-pooling and a classification head.

    ``forward`` runs the encoder, pools its first output over dimension 1,
    feeds the pooled vector to ``Wav2Vec2ClassificationHead`` and, when labels
    are supplied, computes a loss selected by ``config.problem_type``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        """Invoke ``_freeze_parameters()`` on the encoder's feature extractor.

        Presumably this stops gradient updates for the convolutional front-end;
        the exact effect is defined by the ``transformers`` implementation.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(self, hidden_states, mode="mean"):
        """Reduce ``hidden_states`` along dimension 1.

        Args:
            hidden_states: Tensor to pool (assumed ``(batch, time, hidden)`` —
                TODO confirm against the encoder output).
            mode: One of ``"mean"``, ``"sum"`` or ``"max"``.

        Returns:
            ``hidden_states`` with dimension 1 reduced away.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        if mode == "mean":
            return torch.mean(hidden_states, dim=1)
        if mode == "sum":
            return torch.sum(hidden_states, dim=1)
        if mode == "max":
            # torch.max with ``dim`` returns (values, indices); keep the values.
            return torch.max(hidden_states, dim=1)[0]
        raise ValueError(
            "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']"
        )

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        """Encode audio, pool over time, classify and optionally compute a loss.

        Args:
            input_values: Passed positionally to the wav2vec2 encoder.
            attention_mask: Forwarded unchanged to the encoder.
            output_attentions: Forwarded unchanged to the encoder.
            output_hidden_states: Forwarded unchanged to the encoder.
            return_dict: When ``None`` falls back to
                ``self.config.use_return_dict``.
            labels: Optional targets. When given, a loss is computed.

        Returns:
            ``SpeechClassifierOutput`` when ``return_dict`` is truthy, otherwise
            a tuple ``(loss?, logits, *encoder_extras)`` where ``loss`` is only
            present if it was computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # NOTE(review): pooling does not consult ``attention_mask``, so any
        # padded frames in the encoder output would contribute to the mean —
        # confirm whether masked pooling is wanted.
        pooled = self.merged_strategy(outputs[0], mode="mean")
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Infer the task once from the head width and label dtype and
                # remember the decision on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    # Flatten both sides: comparing (B, 1) logits against
                    # (B,) labels would silently broadcast to (B, B).
                    loss = loss_fct(logits.reshape(-1), labels.reshape(-1))
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Assumes encoder tuple entries from index 2 onward are the
            # optional hidden states / attentions — TODO confirm against the
            # installed transformers version.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|