import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import os
import soundfile as sf

# Phase 1: Global Parameters
SAMPLE_RATE = 16000
N_FFT = 1024
HOP_LENGTH = N_FFT // 4  # 256
WIN_LENGTH = N_FFT
WATERMARK_KEY = 42  # Default key (seed for the pseudo-random watermark)


def load_audio(path: str) -> Tuple[torch.Tensor, int]:
    """Robust audio loading using soundfile.

    Returns:
        (waveform, sample_rate) where waveform is float32 with shape
        (channels, samples).

    Raises:
        Re-raises any soundfile error after logging it.
    """
    try:
        data, sr = sf.read(path)
        # Soundfile returns (Frames, Channels) for multichannel or (Frames,)
        # for mono; normalize both to a (C, T) layout.
        if data.ndim == 1:
            waveform = torch.from_numpy(data).unsqueeze(0)  # (1, T)
        else:
            waveform = torch.from_numpy(data.T)  # (C, T)
        return waveform.float(), sr
    except Exception as e:
        print(f"Error loading {path} with soundfile: {e}")
        raise


def save_audio(path: str, waveform: torch.Tensor, sample_rate: int):
    """Robust audio saving using soundfile.

    Expects waveform as (C, T); soundfile wants (Frames, Channels), hence
    the transpose.
    """
    data = waveform.detach().cpu().numpy().T
    sf.write(path, data, sample_rate)


class WatermarkEmbedder:
    """Embeds a key-seeded spread-spectrum watermark into the STFT magnitude
    of an audio signal, scaled by a simplified psychoacoustic masking
    threshold so the added energy stays below (approximate) audibility.
    """

    def __init__(self,
                 sample_rate: int = SAMPLE_RATE,
                 n_fft: int = N_FFT,
                 hop_length: int = HOP_LENGTH,
                 key: int = WATERMARK_KEY,
                 alpha: float = 3.0):  # Increased alpha for better detection
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.key = key          # RNG seed shared with the detector
        self.alpha = alpha      # Watermark strength multiplier
        self.window = torch.hann_window(self.n_fft)

    def _get_masking_threshold(self, magnitude: torch.Tensor) -> torch.Tensor:
        """Calculates a simplified psychoacoustic masking threshold.

        Combines the Absolute Threshold of Hearing (Terhardt-style
        approximation) with a crude tonal-masking estimate obtained by
        spreading each bin's magnitude onto neighbouring frequency bins.

        Args:
            magnitude: STFT magnitude of shape (C, F, T).

        Returns:
            Per-bin linear-amplitude threshold of shape (C, F, T).
        """
        freqs = torch.linspace(0, self.sample_rate / 2, magnitude.shape[-2])

        # 1. Absolute Threshold of Hearing (ATH), in dB, then to amplitude.
        f_khz = freqs / 1000.0
        f_khz = torch.clamp(f_khz, min=0.02)  # Clamp to 20Hz to prevent overflow
        ath = (3.64 * (f_khz ** -0.8)
               - 6.5 * torch.exp(-0.6 * (f_khz - 3.3) ** 2)
               + 1e-3 * (f_khz ** 4))
        ath = 10 ** (ath / 20)  # Convert dB to amplitude
        ath = ath.view(1, -1, 1).to(magnitude.device)

        # 2. Tonal Masking: spread energy across +/-2 frequency bins.
        mag_unsqueezed = magnitude.unsqueeze(1)
        kernel = torch.tensor([0.1, 0.3, 1.0, 0.3, 0.1],
                              device=magnitude.device).view(1, 1, -1, 1)
        spread_energy = torch.nn.functional.conv2d(
            mag_unsqueezed, kernel, padding=(2, 0)
        ).squeeze(1)

        # Threshold is the louder of ATH and 10% of the spread local energy.
        masking_threshold = torch.max(ath, spread_energy * 0.1)
        return masking_threshold

    def embed(self, audio_path: str, output_path: str, visualize: bool = False):
        """Embed the watermark into `audio_path` and write `output_path`.

        Stereo input is converted to mid/side; only the mid channel is
        watermarked and the untouched side channel is restored afterwards.

        Returns:
            The watermarked waveform tensor (C, T).
        """
        # 1. Input Normalization
        waveform, sr = load_audio(audio_path)
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)
        max_val = torch.abs(waveform).max()
        if max_val > 0:
            waveform = waveform / max_val

        original_channels = waveform.shape[0]
        if original_channels == 2:
            mid = (waveform[0] + waveform[1]) / 2
            side = (waveform[0] - waveform[1]) / 2
            target_signal = mid.unsqueeze(0)
        else:
            target_signal = waveform

        # 2. STFT
        self.window = self.window.to(waveform.device)
        stft = torch.stft(
            target_signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            return_complex=True,
            center=True,
        )
        magnitude = torch.abs(stft)
        phase = torch.angle(stft)

        # 3. Psychoacoustic Masking
        masking_threshold = self._get_masking_threshold(magnitude)

        # 4. Watermark Generation
        # One block of 48 frames (~0.75 sec at hop 256 / 16 kHz), tiled over
        # the whole signal so the detector can still find it in cropped audio.
        frames_per_block = 48
        freq_bins = magnitude.shape[1]
        g = torch.Generator(device=waveform.device)
        g.manual_seed(self.key)
        watermark_block = (torch.rand((1, freq_bins, frames_per_block),
                                      generator=g,
                                      device=waveform.device) * 2) - 1
        total_frames = magnitude.shape[2]
        num_repeats = (total_frames // frames_per_block) + 1
        watermark_full = watermark_block.repeat(1, 1, num_repeats)
        watermark_full = watermark_full[:, :, :total_frames]

        # 5. Injection
        injection_signal = self.alpha * watermark_full * masking_threshold
        # BUGFIX: clamp at zero. The bipolar watermark can push a quiet bin
        # negative, and torch.polar with a negative "magnitude" silently
        # flips the phase there, corrupting the embedded sign pattern.
        magnitude_mod = torch.clamp(magnitude + injection_signal, min=0.0)

        # Visualization
        if visualize:
            self._plot_embedding_stats(magnitude, masking_threshold,
                                       injection_signal, magnitude_mod,
                                       output_path)

        # 6. Reconstruction
        stft_mod = torch.polar(magnitude_mod, phase)
        reconstructed = torch.istft(
            stft_mod,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            center=True,
            length=target_signal.shape[-1],
        )

        if original_channels == 2:
            rec_mid = reconstructed.squeeze(0)
            rec_l = rec_mid + side
            rec_r = rec_mid - side
            final_audio = torch.stack([rec_l, rec_r])
        else:
            final_audio = reconstructed

        save_audio(output_path, final_audio, self.sample_rate)
        return final_audio

    def _plot_embedding_stats(self, magnitude, masking_threshold, injection,
                              magnitude_mod, output_path):
        """Generates plots for the embedding process (saved next to output)."""
        mag_np = 20 * torch.log10(magnitude[0] + 1e-6).cpu().numpy()
        mask_np = 20 * torch.log10(masking_threshold[0] + 1e-6).cpu().numpy()
        inj_np = 20 * torch.log10(torch.abs(injection[0]) + 1e-6).cpu().numpy()
        mod_np = 20 * torch.log10(magnitude_mod[0] + 1e-6).cpu().numpy()

        plt.figure(figsize=(15, 10))

        plt.subplot(2, 2, 1)
        plt.imshow(mag_np, aspect='auto', origin='lower', cmap='inferno')
        plt.title("Original Spectrogram (dB)")
        plt.colorbar(format='%+2.0f dB')

        plt.subplot(2, 2, 2)
        plt.imshow(mask_np, aspect='auto', origin='lower', cmap='viridis')
        plt.title("Masking Threshold (dB)")
        plt.colorbar(format='%+2.0f dB')

        plt.subplot(2, 2, 3)
        plt.imshow(inj_np, aspect='auto', origin='lower', cmap='magma')
        plt.title("Injected Watermark Signal (dB)")
        plt.colorbar(format='%+2.0f dB')

        plt.subplot(2, 2, 4)
        plt.imshow(mod_np, aspect='auto', origin='lower', cmap='inferno')
        plt.title("Watermarked Spectrogram (dB)")
        plt.colorbar(format='%+2.0f dB')

        plt.tight_layout()
        plot_path = os.path.splitext(output_path)[0] + "_embedding_analysis_aa.png"
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved embedding analysis to {plot_path}")


class WatermarkDetector:
    """Detects the key-seeded watermark via normalized cross-correlation of a
    whitened spectrogram against the reference watermark block.
    """

    def __init__(self,
                 sample_rate: int = SAMPLE_RATE,
                 n_fft: int = N_FFT,
                 hop_length: int = HOP_LENGTH,
                 key: int = WATERMARK_KEY):
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.key = key
        self.window = torch.hann_window(self.n_fft)

        # Generate the reference watermark block (must match the embedder:
        # same seed, same shape, same onesided bin count n_fft // 2 + 1).
        self.frames_per_block = 48
        freq_bins = n_fft // 2 + 1
        g = torch.Generator()
        g.manual_seed(self.key)
        self.watermark_block = (torch.rand((1, freq_bins, self.frames_per_block),
                                           generator=g) * 2) - 1

    def detect(self, audio_path: str, threshold: float = 0.05,
               visualize: bool = False) -> bool:
        """Return True if the watermark correlation exceeds `threshold`."""
        # 1. Preprocessing: mono downmix, resample to the working rate.
        waveform, sr = load_audio(audio_path)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)

        # 2. Synchronization (sliding correlation handles unknown offset)
        self.window = self.window.to(waveform.device)
        stft = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window,
            return_complex=True,
            center=True,
        )
        magnitude = torch.abs(stft)

        # 3. Whitening
        # Subtract a moving average along time to remove speech formants.
        # With kernel (1, 15), stride 1, padding (0, 7), the output size is
        # Input + 2*7 - 15 + 1 = Input, so shapes match exactly.
        mag_unsqueezed = magnitude.unsqueeze(1)
        smoothed = torch.nn.functional.avg_pool2d(
            mag_unsqueezed, kernel_size=(1, 15), stride=1, padding=(0, 7)
        )
        whitened = mag_unsqueezed - smoothed
        whitened = whitened.squeeze(1)
        # Normalize variance
        whitened = whitened / (torch.std(whitened) + 1e-6)

        # 4. Correlation Check
        input_signal = whitened.unsqueeze(1)
        # Zero-mean, unit-norm kernel -> conv output approximates a
        # sliding Pearson correlation.
        kernel = self.watermark_block.to(waveform.device).unsqueeze(1)
        kernel = kernel - torch.mean(kernel)
        kernel = kernel / (torch.norm(kernel) + 1e-6)

        # Guard: audio shorter than one watermark block cannot be checked.
        if input_signal.shape[-1] < kernel.shape[-1]:
            print("Warning: Input audio too short for detection.")
            return False

        correlation_map = torch.nn.functional.conv2d(input_signal, kernel)
        scores = correlation_map.squeeze()
        if scores.numel() == 0:
            return False

        max_score = torch.max(scores).item()
        print(f"Max Correlation Score: {max_score}")

        if visualize:
            self._plot_detection_stats(whitened, scores, max_score,
                                       threshold, audio_path)

        if max_score > threshold:
            return True
        return False

    def _plot_detection_stats(self, whitened, scores, max_score, threshold,
                              audio_path):
        """Generates plots for the detection process (saved next to input)."""
        whitened_np = whitened[0].cpu().numpy()
        scores_np = scores.cpu().numpy()

        plt.figure(figsize=(15, 8))

        plt.subplot(2, 1, 1)
        plt.imshow(whitened_np, aspect='auto', origin='lower',
                   cmap='coolwarm', vmin=-3, vmax=3)
        plt.title("Whitened Spectrogram (Signal - Smoothed Background)")
        plt.colorbar()

        plt.subplot(2, 1, 2)
        plt.plot(scores_np)
        plt.axhline(y=threshold, color='r', linestyle='--',
                    label=f'Threshold ({threshold})')
        plt.axhline(y=max_score, color='g', linestyle=':',
                    label=f'Max Score ({max_score:.2f})')
        plt.title("Correlation Score (Sliding Window)")
        plt.xlabel("Time Frame Index")
        plt.ylabel("Pearson Correlation")
        plt.legend()

        plt.tight_layout()
        plot_path = os.path.splitext(audio_path)[0] + "_detection_analysis_aa.png"
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved detection analysis to {plot_path}")


if __name__ == "__main__":
    # Simple smoke test: embed into synthetic audio, then detect in the
    # full file and in a 2-second crop.  (Duplicate `import os` removed;
    # os is already imported at module level.)
    embedder = WatermarkEmbedder(alpha=5.0)  # Stronger watermark for test
    detector = WatermarkDetector()

    # Create dummy audio: 440 Hz tone plus noise, 5 seconds, mono.
    sr = 16000
    duration = 5
    t = torch.linspace(0, duration, sr * duration)
    audio = 0.5 * torch.sin(2 * torch.pi * 440 * t) + 0.1 * torch.randn_like(t)
    audio = audio.unsqueeze(0)

    test_file = "test_original.wav"
    watermarked_file = "test_watermarked.wav"
    save_audio(test_file, audio, sr)

    print("Embedding watermark...")
    embedder.embed(test_file, watermarked_file, visualize=True)

    print("Detecting watermark...")
    detected = detector.detect(watermarked_file, threshold=0.02, visualize=True)
    print(f"Detected: {detected}")

    print("Testing with crop...")
    wm_audio, _ = load_audio(watermarked_file)
    crop_start = sr * 1
    crop_end = sr * 3  # 2 seconds
    cropped_audio = wm_audio[:, crop_start:crop_end]
    cropped_file = "test_cropped.wav"
    save_audio(cropped_file, cropped_audio, sr)
    detected_crop = detector.detect(cropped_file, threshold=0.02)
    print(f"Detected in crop: {detected_crop}")