"""Universal audio deepfake / watermark detector.

Fuses six expert branches into a single 2-logit classification head:

1. ``WatermarkExpertBranch`` - differentiable correlation against a known
   spectral watermark block (3 statistics).
2. ``SynthArtifactBranch``  - small CNN over log-mel spectrograms (32-dim).
3. ``LFCCBranch``           - small CNN over LFCC features (32-dim).
4. ``WatermarkDiscriminator`` (imported, pre-trained) - scalar probability.
5. ``AudioSealBranch``      - frozen pre-trained AudioSeal detector (32-dim).
6. ``RawWaveBranch``        - SincNet-style raw-waveform encoder (32-dim).
"""

import math
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

# Make sibling modules importable when this file is executed directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from deep_watermark import WatermarkDiscriminator
from robust_watermark import WatermarkDetector

try:
    # Optional dependency: if missing, AudioSealBranch degrades to a
    # zero embedding instead of crashing at import time.
    import audioseal
except ImportError:
    print("Warning: 'audioseal' not found. AudioSeal branch will fail if used.")


class WatermarkExpertBranch(nn.Module):
    """Differentiable implementation of the Robust Watermark Detector.

    Whitens the STFT magnitude, correlates it with the known watermark
    block, and returns three correlation statistics.

    Output shape: (B, 3) = [max, mean, std] of the correlation scores.
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256):
        super().__init__()
        self.detector = WatermarkDetector(
            sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length
        )
        # Register the watermark block as a buffer so it moves with the model
        # (device / dtype) and is saved in the state_dict.
        self.register_buffer(
            'watermark_kernel', self.detector.watermark_block.unsqueeze(1)
        )
        # Plain attribute (not a buffer) to keep state_dict keys unchanged;
        # moved to the input's device explicitly in forward().
        self.window = torch.hann_window(n_fft)
        self.n_fft = n_fft
        self.hop_length = hop_length

    def forward(self, waveform):
        """waveform: (B, 1, T) -> (B, 3) correlation features."""
        # 1. STFT magnitude.
        window = self.window.to(waveform.device)
        stft = torch.stft(
            waveform.squeeze(1),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            return_complex=True,
            center=True,
        )
        magnitude = torch.abs(stft)  # (B, F, T)

        # 2. Whitening: remove a local moving average along time, then
        # normalise per sample so correlation scores are comparable.
        mag_unsqueezed = magnitude.unsqueeze(1)  # (B, 1, F, T)
        smoothed = F.avg_pool2d(
            mag_unsqueezed, kernel_size=(1, 15), stride=1, padding=(0, 7)
        )
        whitened = mag_unsqueezed - smoothed
        whitened = whitened / (
            torch.std(whitened, dim=(2, 3), keepdim=True) + 1e-6
        )

        # 3. Correlation with the zero-mean, unit-norm watermark kernel.
        # Kernel shape: (1, 1, F, T_block); full-height conv collapses the
        # frequency axis, leaving a 1-D score track over time.
        kernel = self.watermark_kernel
        kernel = kernel - torch.mean(kernel)
        kernel = kernel / (torch.norm(kernel) + 1e-6)
        correlation_map = F.conv2d(whitened, kernel)  # (B, 1, 1, T_out)
        scores = correlation_map.squeeze(1).squeeze(1)  # (B, T_out)

        # 4. Summary statistics of the score track.
        max_score = torch.max(scores, dim=1, keepdim=True)[0]
        mean_score = torch.mean(scores, dim=1, keepdim=True)
        std_score = torch.std(scores, dim=1, keepdim=True)
        return torch.cat([max_score, mean_score, std_score], dim=1)  # (B, 3)


class SynthArtifactBranch(nn.Module):
    """ResNet-like CNN detecting synthetic artifacts in log-mel spectrograms.

    NOTE: the original file defined ``__init__`` twice; the first (which
    ignored ``sample_rate``) was dead code silently shadowed by the second
    and has been removed.

    Output shape: (B, 32).
    """

    def __init__(self, sample_rate=16000):
        super().__init__()
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_mels=64
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 32)  # Embedding size 32

    def forward(self, x):
        """x: (B, 1, T) -> (B, 32) embedding."""
        mel = self.mel_transform(x)       # (B, 1, n_mels, time)
        mel = self.amplitude_to_db(mel)   # log-scale for perceptual range
        x = self.conv1(mel)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x  # (B, 32)


class LFCCBranch(nn.Module):
    """CNN over Linear Frequency Cepstral Coefficients.

    LFCCs weight all frequencies linearly, which helps expose
    high-frequency artifacts typical of synthetic speech.

    Output shape: (B, 32).
    """

    def __init__(self, sample_rate=16000):
        super().__init__()
        self.lfcc_transform = torchaudio.transforms.LFCC(
            sample_rate=sample_rate,
            n_lfcc=40,
            speckwargs={"n_fft": 1024, "win_length": 400, "hop_length": 160},
        )
        # ResNet-style encoder (same topology as SynthArtifactBranch).
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 32)  # Embedding size 32

    def forward(self, x):
        """x: (B, 1, T) -> (B, 32) embedding."""
        lfcc = self.lfcc_transform(x)  # (B, 1, n_lfcc, time)
        # LFCCs are already cepstral coefficients; no dB conversion needed.
        x = self.conv1(lfcc)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x  # (B, 32)


class AudioSealBranch(nn.Module):
    """Frozen pre-trained AudioSeal detector mapped to a 32-dim embedding.

    If AudioSeal cannot be loaded, forward() returns a zero embedding so
    the fused model still runs.
    """

    def __init__(self):
        super().__init__()
        try:
            # We only need the detector half of AudioSeal.
            self.detector = audioseal.AudioSeal.load_detector(
                "audioseal_detector_16bits"
            )
            self.detector.eval()
            # Freeze: the pre-trained detector is used as a fixed feature
            # extractor, never fine-tuned.
            for param in self.detector.parameters():
                param.requires_grad = False
        except Exception as e:
            print(f"Error loading AudioSeal: {e}")
            self.detector = None

        # Feature extraction head over AudioSeal's (B, 2, T) probability map.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(2, 32)  # Map to 32-dim embedding

    def forward(self, waveform):
        """waveform: (B, 1, T) at 16 kHz -> (B, 32) embedding."""
        if self.detector is None:
            return torch.zeros(waveform.size(0), 32).to(waveform.device)

        with torch.no_grad():
            # AudioSeal expects (B, 1, T) and a sample rate; it returns a
            # (B, 2, T) map (or a tuple whose first element is that map).
            # NOTE(review): assumes 16 kHz input — confirm against callers.
            result = self.detector(waveform, 16000)

        probs = result[0] if isinstance(result, tuple) else result
        if not isinstance(probs, torch.Tensor):
            # Defensive fallback if the API shape changes.
            return torch.zeros(waveform.size(0), 32).to(waveform.device)

        # 1. Global statistics: mean probability across time. (B, 2)
        global_stats = self.pool(probs).squeeze(2)
        # 2. Map to embedding. (B, 32)
        embedding = self.fc(global_stats)
        return embedding


class SincConv(nn.Module):
    """Sinc-based 1-D convolution (SincNet-style).

    Each output channel is a band-pass filter whose low cut-off and
    bandwidth are the only learnable parameters; the filter shape itself
    is the analytic difference of two windowed sinc low-pass filters.
    """

    def __init__(self, out_channels, kernel_size, sample_rate=16000,
                 min_low_hz=50, min_band_hz=50):
        super().__init__()
        # Force an odd kernel so the filter has a well-defined centre tap.
        if kernel_size % 2 == 0:
            kernel_size = kernel_size + 1

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Initialise band edges on the mel scale between 30 Hz and Nyquist
        # (minus the minimum low/band margins).
        low_hz = 30
        high_hz = sample_rate / 2 - (min_low_hz + min_band_hz)
        mel = np.linspace(
            self.to_mel(low_hz), self.to_mel(high_hz), self.out_channels + 1
        )
        hz = self.to_hz(mel)

        # Learnable filter parameters: low cut-off and bandwidth per channel.
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Fixed Hamming window (moved to the filter's device at use time).
        self.window_ = torch.hamming_window(self.kernel_size, periodic=False)

    def to_mel(self, hz):
        """Hz -> mel (HTK formula)."""
        return 2595 * np.log10(1 + hz / 700)

    def to_hz(self, mel):
        """mel -> Hz (inverse HTK formula)."""
        return 700 * (10 ** (mel / 2595) - 1)

    def forward(self, x):
        """x: (B, 1, T) -> (B, out_channels, T - kernel_size + 1)."""
        # Map unconstrained parameters to valid band edges:
        # low >= min_low_hz, high clamped to Nyquist.
        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(
            low + self.min_band_hz + torch.abs(self.band_hz_),
            self.min_low_hz,
            self.sample_rate / 2,
        )

        # Symmetric time axis in seconds, one row per channel.
        t = torch.arange(
            -(self.kernel_size - 1) / 2, (self.kernel_size - 1) / 2 + 1
        ).to(x.device).view(1, -1) / self.sample_rate
        t = t.repeat(self.out_channels, 1)

        # Band-pass = high-cut sinc minus low-cut sinc:
        #   2*f*sinc(2*pi*f*t) = sin(2*pi*f*t) / (pi*t)
        # The 1e-6 only guards division by zero at t == 0; the centre tap
        # is then overwritten with the analytic limit 2*f below.
        band_pass_left = torch.sin(2 * math.pi * high * t) / (math.pi * t + 1e-6)
        band_pass_right = torch.sin(2 * math.pi * low * t) / (math.pi * t + 1e-6)

        center_idx = int((self.kernel_size - 1) / 2)
        band_pass_left[:, center_idx] = 2 * high[:, 0]
        band_pass_right[:, center_idx] = 2 * low[:, 0]

        filters = band_pass_left - band_pass_right
        # Hamming window reduces spectral leakage of the truncated sinc.
        filters = filters * self.window_.to(filters.device)

        return F.conv1d(
            x, filters.view(self.out_channels, 1, self.kernel_size)
        )


class RawWaveBranch(nn.Module):
    """Raw-waveform analysis using SincNet-style layers.

    Detects fine-grained temporal artifacts. Output shape: (B, 32).
    """

    def __init__(self, sample_rate=16000):
        super().__init__()
        self.sinc_conv = SincConv(
            out_channels=32, kernel_size=129, sample_rate=sample_rate
        )
        # Standard CNN stack following the sinc front-end.
        self.layer1 = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2),
            nn.MaxPool1d(2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            nn.MaxPool1d(2),
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, 32)

    def forward(self, x):
        """x: (B, 1, T) -> (B, 32) embedding."""
        x = self.sinc_conv(x)  # (B, 32, T')
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x).flatten(1)
        x = self.fc(x)
        return x


class UniversalDetector(nn.Module):
    """Fusion of all six expert branches.

    Returns two logits per sample: [Logit_Watermarked, Logit_Synth].
    """

    def __init__(self, sample_rate=16000):
        super().__init__()
        # Branch 1: Watermark Expert (correlation statistics, 3 dims).
        self.watermark_expert = WatermarkExpertBranch(sample_rate)
        # Branch 2: Synth Artifacts (mel CNN, 32 dims).
        self.synth_artifact = SynthArtifactBranch(sample_rate)
        # Branch 3: LFCC Features (32 dims).
        self.lfcc_branch = LFCCBranch(sample_rate)
        # Branch 4: Deep Watermark pre-trained discriminator (1 dim).
        self.deep_watermark = WatermarkDiscriminator()
        # Branch 5: AudioSeal (32 dims).
        self.audioseal_branch = AudioSealBranch()
        # Branch 6: RawWave SincNet encoder (32 dims).
        self.raw_wave_branch = RawWaveBranch(sample_rate)

        # Fusion head. Input dims: 3 + 32 + 1 + 32 + 32 + 32 = 132.
        self.fusion_head = nn.Sequential(
            nn.Linear(132, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),  # [Logit_Watermarked, Logit_Synth]
        )

    def forward(self, waveform):
        """waveform: (B, 1, T) -> (B, 2) logits."""
        wm_feats = self.watermark_expert(waveform)       # (B, 3)
        synth_emb = self.synth_artifact(waveform)        # (B, 32)
        lfcc_emb = self.lfcc_branch(waveform)            # (B, 32)
        deep_prob = self.deep_watermark(waveform)        # (B, 1)
        audioseal_emb = self.audioseal_branch(waveform)  # (B, 32)
        raw_emb = self.raw_wave_branch(waveform)         # (B, 32)

        features = torch.cat(
            [wm_feats, synth_emb, deep_prob, lfcc_emb, audioseal_emb, raw_emb],
            dim=1,
        )  # (B, 132)
        logits = self.fusion_head(features)  # (B, 2)
        return logits