import torch
import torchaudio
import librosa
import numpy as np

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, logging as transformers_logging

# Suppress Transformers warnings about weights that are newly initialized
# (expected when loading pretraining checkpoints).
transformers_logging.set_verbosity_error()
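# A minimal loader sketch: both extractor classes below assume 16 kHz mono
# audio shaped (1, T). This shows one way to produce that input with
# torchaudio; the helper name `load_wav_16k_mono` is hypothetical and is not
# referenced elsewhere in this module.
def load_wav_16k_mono(path, target_sr=16000):
    """Load an audio file, downmix to mono, and resample to 16 kHz."""
    waveform, sr = torchaudio.load(path)                   # (C, T)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)      # downmix -> (1, T)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    return waveform                                        # (1, T) at 16 kHz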
class FeatureExtractor_v2:
    def __init__(self, device='cuda'):
        self.device = device
        self.sample_rate = 16000

        # Load ASR model for phoneme features (lightweight version)
        # print("Loading ASR model for Phoneme Branch...")
        try:
            self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
            self.asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
            self.asr_model.eval()
            # print("ASR model loaded successfully.")

            # Load WavLM (strong self-supervised features for deepfake detection)
            # print("Loading WavLM model for Semantic Branch...")
            from transformers import WavLMModel
            self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base").to(device)
            self.wavlm.eval()
            # print("WavLM model loaded successfully.")

            # Freeze both pretrained models
            for param in self.asr_model.parameters():
                param.requires_grad = False
            for param in self.wavlm.parameters():
                param.requires_grad = False
        except Exception as e:
            print(f"CRITICAL WARNING: Could not load pretrained models (WavLM/ASR): {e}")
            import traceback
            traceback.print_exc()
            self.asr_model = None
            self.wavlm = None

    def extract_f0(self, waveform):
        """
        Extract pitch (F0) using librosa's PYIN.
        Input:  (1, T) tensor
        Output: (1, T_frame) tensor, normalized to [0, 1]
        """
        # Move to CPU for librosa
        wav_np = waveform.squeeze().cpu().numpy()

        # PYIN is accurate but slow; fine for offline training, though a
        # faster tracker may be needed for real-time use.
        f0, voiced_flag, voiced_probs = librosa.pyin(
            wav_np,
            fmin=librosa.note_to_hz('C2'),
            fmax=librosa.note_to_hz('C7'),
            sr=self.sample_rate,
        )

        # Unvoiced frames come back as NaN
        f0 = np.nan_to_num(f0)

        # Normalize to [0, 1]
        if f0.max() > 0:
            f0 = f0 / f0.max()

        return torch.tensor(f0, dtype=torch.float32).unsqueeze(0).to(self.device)

    def extract_energy(self, waveform, n_fft=1024, hop_length=256):
        """
        Extract frame-level log energy (sum of mel bins per frame).
        Output: (1, T_frame)
        """
        # Mel spectrogram: (B, C, T) -> (B, C, n_mels, T_frame)
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=80,
        ).to(self.device)(waveform)

        # Energy = sum over the mel-frequency axis (dim -2)
        energy = torch.sum(mel, dim=-2)     # (B, C, T_frame)
        energy = torch.log(energy + 1e-6)

        # Squeeze the channel axis so the caller gets (B, T_frame):
        # a (B=1, C=1) input yields (1, 1, T) -> squeeze(1) -> (1, T);
        # for a (B, T) input the squeeze is a no-op.
        return energy.squeeze(1)

    def extract_wavlm(self, waveform):
        """
        Extract the WavLM last hidden state.
        Output: (1, T_wavlm, 768)
        """
        if self.wavlm is None:
            return None

        with torch.no_grad():
            # WavLM expects (B, T) mono input
            if waveform.dim() == 3:
                input_values = waveform.squeeze(1)
            else:
                input_values = waveform
            outputs = self.wavlm(input_values)
            hidden_states = outputs.last_hidden_state  # (B, T, 768)
        return hidden_states

    def extract_phonemes(self, waveform):
        """
        Extract phoneme probabilities (CTC logits).
        Output: (1, T_frame, Vocab)
        """
        if self.asr_model is None:
            return None

        with torch.no_grad():
            # The ASR model expects (B, T) mono input
            if waveform.dim() == 3 and waveform.shape[1] == 1:
                input_values = waveform.squeeze(1)
            else:
                input_values = waveform

            # Wav2Vec2 expects roughly normalized input. The processor
            # normally handles this; here we assume audio is roughly in [-1, 1].
            logits = self.asr_model(input_values).logits  # (B, T_frame, Vocab)
        return logits

    def compute_all(self, waveform):
        """
        Return a dict of all features for one utterance.
        """
        # Ensure (1, T)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # 1. Mel (standard)
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=80,
        ).to(self.device)
        mel = mel_transform(waveform)  # (B, C, 80, T) for (B, C, T) input

        # Convert to log-mel here so every consumer sees the same scaling
        mel = torch.log(mel + 1e-6)

        # If C=1, squeeze it to match the SpectralNet expectation (B, 80, T)
        if mel.dim() == 4 and mel.shape[1] == 1:
            mel = mel.squeeze(1)

        # 2. F0
        # The librosa F0 track can differ in length from the mel frames
        # (different hop/padding/centering), so interpolate it onto the mel
        # time axis.
        f0 = self.extract_f0(waveform)  # (B, T_f0)
        target_len = mel.shape[-1]
        # (B, T) -> (B, 1, T) -> interpolate -> (B, 1, T_new) -> (B, T_new)
        f0 = torch.nn.functional.interpolate(
            f0.unsqueeze(1), size=target_len, mode='linear', align_corners=False
        ).squeeze(1)

        # 3. Energy (aligned to the mel frames the same way)
        energy = self.extract_energy(waveform)  # (B, T)
        energy = torch.nn.functional.interpolate(
            energy.unsqueeze(1), size=target_len, mode='linear', align_corners=False
        ).squeeze(1)

        # 4. Phonemes
        phonemes = self.extract_phonemes(waveform)  # (1, T_asr, Vocab)

        # 5. WavLM
        wavlm_feat = self.extract_wavlm(waveform)  # (1, T_wavlm, 768)

        return {
            'mel': mel,            # (1, 80, T)
            'f0': f0,              # (1, T)
            'energy': energy,      # (1, T)
            'phonemes': phonemes,  # (1, T_asr, 32)
            'wavlm': wavlm_feat,   # (1, T_wavlm, 768)
            'waveform': waveform,
        }
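# A minimal sketch of driving FeatureExtractor_v2, assuming a CUDA-capable
# machine (falls back to CPU) and using random noise as a stand-in for a real
# 16 kHz mono waveform. `_demo_feature_extractor_v2` is a hypothetical helper;
# the printed shapes should match the comments in compute_all above.
def _demo_feature_extractor_v2():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    extractor = FeatureExtractor_v2(device=device)
    waveform = torch.randn(1, 16000, device=device)  # 1 s of dummy audio
    feats = extractor.compute_all(waveform)
    for name, feat in feats.items():
        # phonemes/wavlm are None if the pretrained models failed to load
        if feat is not None:
            print(name, tuple(feat.shape))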
class FeatureExtractor:
    def __init__(self, device='cuda'):
        self.device = device
        self.sample_rate = 16000

        # Load ASR model for phoneme features (lightweight version)
        print("Loading ASR model for Phoneme Branch...")
        try:
            self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
            self.asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
            self.asr_model.eval()
            print("ASR model loaded successfully.")

            # Freeze ASR
            for param in self.asr_model.parameters():
                param.requires_grad = False
        except Exception as e:
            print(f"Warning: Could not load ASR model: {e}")
            self.asr_model = None

    def extract_f0(self, waveform):
        """
        Extract pitch (F0) using librosa's PYIN.
        Input:  (1, T) tensor
        Output: (1, T_frame) tensor, normalized to [0, 1]
        """
        # Move to CPU for librosa
        wav_np = waveform.squeeze().cpu().numpy()

        # PYIN is accurate but slow; fine for offline training, though a
        # faster tracker may be needed for real-time use.
        f0, voiced_flag, voiced_probs = librosa.pyin(
            wav_np,
            fmin=librosa.note_to_hz('C2'),
            fmax=librosa.note_to_hz('C7'),
            sr=self.sample_rate,
        )

        # Unvoiced frames come back as NaN
        f0 = np.nan_to_num(f0)

        # Normalize to [0, 1]
        if f0.max() > 0:
            f0 = f0 / f0.max()

        return torch.tensor(f0, dtype=torch.float32).unsqueeze(0).to(self.device)

    def extract_energy(self, waveform, n_fft=1024, hop_length=256):
        """
        Extract frame-level log energy (sum of mel bins per frame).
        Output: (1, T_frame)
        """
        # Mel spectrogram: (B, C, T) -> (B, C, n_mels, T_frame)
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=80,
        ).to(self.device)(waveform)

        # Energy = sum over the mel-frequency axis (dim -2)
        energy = torch.sum(mel, dim=-2)     # (B, C, T_frame)
        energy = torch.log(energy + 1e-6)

        # Squeeze the channel axis so the caller gets (B, T_frame):
        # a (B=1, C=1) input yields (1, 1, T) -> squeeze(1) -> (1, T);
        # for a (B, T) input the squeeze is a no-op.
        return energy.squeeze(1)

    def extract_phonemes(self, waveform):
        """
        Extract phoneme probabilities (CTC logits).
        Output: (1, T_frame, Vocab)
        """
        if self.asr_model is None:
            return None

        with torch.no_grad():
            # The ASR model expects (B, T) mono input
            if waveform.dim() == 3 and waveform.shape[1] == 1:
                input_values = waveform.squeeze(1)
            else:
                input_values = waveform

            # Wav2Vec2 expects roughly normalized input. The processor
            # normally handles this; here we assume audio is roughly in [-1, 1].
            logits = self.asr_model(input_values).logits  # (B, T_frame, Vocab)
        return logits

    def compute_all(self, waveform):
        """
        Return a dict of all features for one utterance.
        """
        # Ensure (1, T)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # 1. Mel (standard)
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=80,
        ).to(self.device)
        mel = mel_transform(waveform)  # (B, C, 80, T) for (B, C, T) input

        # Convert to log-mel; the epsilon avoids log(0)
        mel = torch.log(mel + 1e-6)

        # If C=1, squeeze it to match the SpectralNet expectation (B, 80, T)
        if mel.dim() == 4 and mel.shape[1] == 1:
            mel = mel.squeeze(1)

        # 2. F0
        # The librosa F0 track can differ in length from the mel frames
        # (different hop/padding/centering), so interpolate it onto the mel
        # time axis.
        f0 = self.extract_f0(waveform)  # (B, T_f0)
        target_len = mel.shape[-1]
        # (B, T) -> (B, 1, T) -> interpolate -> (B, 1, T_new) -> (B, T_new)
        f0 = torch.nn.functional.interpolate(
            f0.unsqueeze(1), size=target_len, mode='linear', align_corners=False
        ).squeeze(1)

        # 3. Energy (aligned to the mel frames the same way)
        energy = self.extract_energy(waveform)  # (B, T)
        energy = torch.nn.functional.interpolate(
            energy.unsqueeze(1), size=target_len, mode='linear', align_corners=False
        ).squeeze(1)

        # 4. Phonemes
        phonemes = self.extract_phonemes(waveform)  # (1, T_asr, Vocab)
        # Frame rates differ: wav2vec2 downsamples the waveform by a factor
        # of 320, while the mel transform above uses torchaudio's default hop
        # of 200 samples. The phoneme logits are deliberately kept at the ASR
        # frame rate rather than resampled to the mel time axis; by design,
        # each branch consumes its own input.

        return {
            'mel': mel,            # (1, 80, T)
            'f0': f0,              # (1, T)
            'energy': energy,      # (1, T)
            'phonemes': phonemes,  # (1, T_asr, 32)
            'waveform': waveform,
        }
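# The extract_phonemes comment notes that Wav2Vec2Processor normally handles
# input preprocessing. A minimal sketch of that path, assuming `extractor` is
# an already-constructed FeatureExtractor; the helper name
# `phoneme_logits_with_processor` is illustrative, not part of this module.
def phoneme_logits_with_processor(extractor, waveform):
    """Run the ASR branch through the processor instead of raw tensors."""
    wav_np = waveform.squeeze().cpu().numpy()
    inputs = extractor.processor(
        wav_np, sampling_rate=extractor.sample_rate, return_tensors="pt"
    )
    # Preprocessed (and normalized, if the checkpoint's config enables it)
    input_values = inputs.input_values.to(extractor.device)  # (1, T)
    with torch.no_grad():
        return extractor.asr_model(input_values).logits      # (1, T_asr, 32)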