"""Audio watermarking models: embedder U-Net, training-time distortion
(attack) simulation, watermark detector, and the combined training losses."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np  # noqa: F401  (kept: may be used by code outside this view)

# torchaudio is not referenced anywhere in this module; import it lazily so
# the models remain importable in environments where it is not installed.
try:
    import torchaudio  # noqa: F401
except ImportError:  # pragma: no cover
    torchaudio = None


class WatermarkUNet(nn.Module):
    """U-Net that embeds a fixed-length watermark vector into a waveform.

    The network predicts a small additive *delta* (tanh output scaled by 0.1)
    which is added to the input, so the identity mapping is trivially
    available and the perturbation stays bounded.
    """

    def __init__(self, input_channels=1, watermark_len=32):
        super().__init__()
        self.watermark_len = watermark_len
        # Encoder: each stride-2 conv halves the time axis.
        self.enc1 = nn.Conv1d(input_channels, 16, kernel_size=4, stride=2, padding=1)  # -> L/2
        self.enc2 = nn.Conv1d(16, 32, kernel_size=4, stride=2, padding=1)              # -> L/4
        self.enc3 = nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1)              # -> L/8
        self.enc4 = nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1)             # -> L/16
        # Bottleneck: the watermark bits are broadcast along time and fused
        # channel-wise with the deepest features.
        self.bottleneck_conv = nn.Conv1d(128 + watermark_len, 128, kernel_size=3, padding=1)
        # Decoder: each transpose conv doubles the time axis; its input is the
        # channel-concatenation of the previous decoder output and the
        # matching encoder activation (skip connection).
        self.dec4 = nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)
        self.dec3 = nn.ConvTranspose1d(128, 32, kernel_size=4, stride=2, padding=1)  # cat(d4, e3): 64+64
        self.dec2 = nn.ConvTranspose1d(64, 16, kernel_size=4, stride=2, padding=1)   # cat(d3, e2): 32+32
        self.dec1 = nn.ConvTranspose1d(32, input_channels, kernel_size=4, stride=2, padding=1)  # cat(d2, e1): 16+16
        self.final_act = nn.Tanh()

    @staticmethod
    def _match_len(a, b):
        """Trim both tensors to their common (minimum) time length.

        Stride-2 down/upsampling of odd lengths can leave the decoder output
        either one sample longer OR one sample shorter than the encoder skip;
        trimming only the decoder side fails in the shorter case, so both
        sides are cut to the minimum.
        """
        n = min(a.shape[2], b.shape[2])
        return a[:, :, :n], b[:, :, :n]

    def forward(self, x, watermark):
        """Embed `watermark` into `x`.

        Args:
            x: waveform, shape (B, input_channels, L).
            watermark: bit/float vector, shape (B, watermark_len).
        Returns:
            Watermarked waveform with the same shape as `x`.
        """
        # Encoder
        e1 = F.leaky_relu(self.enc1(x), 0.2)
        e2 = F.leaky_relu(self.enc2(e1), 0.2)
        e3 = F.leaky_relu(self.enc3(e2), 0.2)
        e4 = F.leaky_relu(self.enc4(e3), 0.2)

        # Bottleneck: tile the watermark across the time axis so every
        # bottleneck frame sees all watermark bits.
        B, _, L_bot = e4.shape
        w_expanded = watermark.unsqueeze(2).expand(B, self.watermark_len, L_bot)
        bottleneck = torch.cat([e4, w_expanded], dim=1)
        bottleneck = F.leaky_relu(self.bottleneck_conv(bottleneck), 0.2)

        # Decoder with skip connections; lengths aligned before each concat.
        d4 = F.relu(self.dec4(bottleneck))
        d4, s3 = self._match_len(d4, e3)
        d3 = F.relu(self.dec3(torch.cat([d4, s3], dim=1)))
        d3, s2 = self._match_len(d3, e2)
        d2 = F.relu(self.dec2(torch.cat([d3, s2], dim=1)))
        d2, s1 = self._match_len(d2, e1)
        d1 = self.dec1(torch.cat([d2, s1], dim=1))

        # Residual formulation: the network outputs the watermark *delta*,
        # bounded by tanh and scaled by 0.1 to keep the perturbation small.
        # Crop or zero-pad the delta to the exact input length so the
        # residual addition is always valid for any input length.
        delta = self.final_act(d1) * 0.1
        L = x.shape[2]
        if delta.shape[2] > L:
            delta = delta[:, :, :L]
        elif delta.shape[2] < L:
            delta = F.pad(delta, (0, L - delta.shape[2]))
        return x + delta


class DistortionLayer(nn.Module):
    """Randomized attack simulation between embedder and detector.

    Active only in training mode (identity in eval). Every distortion
    preserves the input's (B, 1, L) shape so batched training stays valid.
    """

    def __init__(self, sample_rate=16000):
        super().__init__()
        # NOTE(review): sample_rate is stored but not yet used by any
        # distortion below — presumably reserved for future resample attacks.
        self.sample_rate = sample_rate

    def forward(self, x):
        # Identity outside training: the detector should see clean audio at
        # evaluation time unless distortions are applied explicitly.
        if not self.training:
            return x

        # 1. Additive Gaussian noise.
        if torch.rand(1) < 0.5:
            noise = torch.randn_like(x) * 0.01
            x = x + noise

        # 2. Random amplitude scaling in [0.8, 1.2].
        if torch.rand(1) < 0.5:
            scale = 0.8 + torch.rand(1) * 0.4  # 0.8 to 1.2
            x = x * scale.to(x.device)

        # 3. Time stretch (cheap resampling approximation), then crop/pad
        # back to the pre-stretch length. The bookkeeping must use the length
        # recorded *before* interpolation — comparing the stretched tensor
        # against its own length is a no-op and silently changes batch length.
        if torch.rand(1) < 0.3:
            orig_len = x.shape[2]
            factor = 0.9 + torch.rand(1) * 0.2  # 0.9 to 1.1
            new_len = int(orig_len * factor)
            x = F.interpolate(x, size=new_len, mode='linear', align_corners=False)
            if new_len > orig_len:
                x = x[:, :, :orig_len]  # Crop
            elif new_len < orig_len:
                x = F.pad(x, (0, orig_len - new_len))

        # 4. MP3 compression approximation via spectral dropout.
        if torch.rand(1) < 0.3:
            # Explicit rectangular window to suppress warning
            window = torch.ones(512, device=x.device)
            x_stft = torch.stft(x.squeeze(1), n_fft=512, return_complex=True, window=window)
            mask = torch.rand_like(x_stft.real) > 0.1  # Drop 10% of bins
            x_stft = x_stft * mask
            # istft with length= restores the exact original sample count.
            x = torch.istft(x_stft, n_fft=512, length=x.shape[2], window=window, center=True).unsqueeze(1)
        return x


class WatermarkDiscriminator(nn.Module):
    """Binary detector: predicts P(watermark present) for a waveform.

    The output is sigmoid-activated, so pair it with nn.BCELoss (not
    BCEWithLogitsLoss).
    """

    def __init__(self, input_channels=1):
        super().__init__()
        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=15, stride=2, padding=7)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=15, stride=2, padding=7)
        self.conv3 = nn.Conv1d(64, 128, kernel_size=15, stride=2, padding=7)
        self.conv4 = nn.Conv1d(128, 128, kernel_size=15, stride=2, padding=7)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, 1)

    def forward(self, x):
        """x: (B, input_channels, L) -> (B, 1) probability in [0, 1]."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.leaky_relu(conv(x), 0.2)
        x = self.global_pool(x).squeeze(2)  # (B, 128)
        return torch.sigmoid(self.fc(x))


class MultiResolutionSTFTLoss(nn.Module):
    """Average of spectral-convergence + log-magnitude L1 losses computed at
    several STFT resolutions (a standard perceptual audio loss)."""

    # Tuples (not lists) as defaults: mutable default arguments are shared
    # across all instances that rely on them.
    def __init__(self, fft_sizes=(1024, 2048, 512), hop_sizes=(120, 240, 50),
                 win_lengths=(600, 1200, 240)):
        super().__init__()
        self.fft_sizes = fft_sizes
        self.hop_sizes = hop_sizes
        self.win_lengths = win_lengths

    def forward(self, x, y):
        """x, y: (B, 1, L) waveforms; returns a scalar loss tensor."""
        loss = 0.0
        for n_fft, hop_length, win_length in zip(self.fft_sizes, self.hop_sizes, self.win_lengths):
            window = torch.hann_window(win_length).to(x.device)
            x_stft = torch.stft(x.squeeze(1), n_fft=n_fft, hop_length=hop_length,
                                win_length=win_length, window=window, return_complex=True)
            y_stft = torch.stft(y.squeeze(1), n_fft=n_fft, hop_length=hop_length,
                                win_length=win_length, window=window, return_complex=True)
            x_mag = torch.abs(x_stft)
            y_mag = torch.abs(y_stft)
            # Spectral convergence: relative Frobenius error of magnitudes.
            sc_loss = torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + 1e-6)
            # Log-magnitude L1; the epsilon avoids log(0).
            log_loss = F.l1_loss(torch.log(x_mag + 1e-6), torch.log(y_mag + 1e-6))
            loss += sc_loss + log_loss
        return loss / len(self.fft_sizes)


class WatermarkLoss(nn.Module):
    """Combined training objective: sample-level MSE + detection BCE +
    multi-resolution STFT perceptual loss, with per-term weights."""

    def __init__(self, lambda_audio=1.0, lambda_detect=1.0, lambda_perceptual=0.1):
        super().__init__()
        self.lambda_audio = lambda_audio
        self.lambda_detect = lambda_detect
        self.lambda_perceptual = lambda_perceptual
        self.mse = nn.MSELoss()
        # BCELoss expects sigmoid probabilities, which the discriminator
        # in this module produces.
        self.bce = nn.BCELoss()
        self.stft_loss = MultiResolutionSTFTLoss()

    def forward(self, original, watermarked, prediction, target_label):
        """Return (total, l_audio, l_detect, l_perceptual) loss tensors."""
        l_audio = self.mse(original, watermarked)
        l_detect = self.bce(prediction, target_label)
        l_perceptual = self.stft_loss(original, watermarked)
        total_loss = (self.lambda_audio * l_audio
                      + self.lambda_detect * l_detect
                      + self.lambda_perceptual * l_perceptual)
        return total_loss, l_audio, l_detect, l_perceptual