import torch import os import sys import argparse import numpy as np import matplotlib.pyplot as plt import torchaudio import torch.nn.functional as F from transformers import WavLMForSequenceClassification, AutoFeatureExtractor # Add path to root sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from Universal_Audio_Detector.new_model import UniversalDetectorV2 as UniversalDetector from Universal_Audio_Detector.features import FeatureExtractor_v2 as FeatureExtractor from Universal_Audio_Detector.robust_watermark import load_audio class WavLMClassifier: def __init__(self, model_path, device=None): self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Loading WavLM Classifier from: {model_path}") try: self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path) self.model = WavLMForSequenceClassification.from_pretrained(model_path).to(self.device) self.model.eval() print("Model loaded successfully.") except Exception as e: print(f"Error loading model: {e}") raise def predict_file(self, audio_path, chunk_duration=3.0, stride=1.5): """ Predicts whether an audio file is Real or Synthetic using sliding window. """ try: waveform, sr = torchaudio.load(audio_path) except Exception as e: print(f"Error reading file {audio_path}: {e}") return None # Resample if sr != 16000: resampler = torchaudio.transforms.Resample(sr, 16000) waveform = resampler(waveform) sr = 16000 # Mono if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) waveform = waveform.squeeze() # (T,) # Chunking Logic (Sliding Window) chunk_size = int(chunk_duration * sr) stride_size = int(stride * sr) # Handle short files if waveform.shape[0] < chunk_size: pad_len = chunk_size - waveform.shape[0] waveform = F.pad(waveform, (0, pad_len)) chunks = [waveform] else: chunks = [] for start in range(0, waveform.shape[0] - chunk_size + 1, stride_size): chunks.append(waveform[start:start+chunk_size]) # Add remainder if meaningful if (waveform.shape[0] - chunk_size) % stride_size != 0: chunks.append(waveform[-chunk_size:]) probs = [] print(f"Processing {len(chunks)} chunks...") with torch.no_grad(): for chunk in chunks: inputs = self.feature_extractor( chunk, sampling_rate=16000, return_tensors="pt", padding="max_length", max_length=chunk_size, truncation=True ) input_values = inputs['input_values'].to(self.device) logits = self.model(input_values).logits probabilities = F.softmax(logits, dim=1) # Probability of Class 1 (Synthetic) synth_prob = probabilities[0, 1].item() probs.append(synth_prob) # Aggregation avg_prob = np.mean(probs) max_prob = np.max(probs) # Decision (Conservative) verdict = "SYNTHETIC" if avg_prob > 0.5 else "REAL" return { "verdict": verdict, "avg_confidence": avg_prob, "max_confidence": max_prob, "chunk_probs": probs } # def explain_single_file_for_synth(audio_path, model_path="universal_detector_v2_sota.pth", output_image="detection_timeline.png"): # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # try: # model = UniversalDetector(sample_rate=16000).to(device) # model.eval() # if os.path.exists(model_path): # checkpoint = torch.load(model_path, map_location=device) # try: # model.load_state_dict(checkpoint) # except: # model.load_state_dict(checkpoint, strict=False) # else: # print("Error: Model weights not found!") # return # # print("Loading Features (WavLM)...") # feature_extractor = FeatureExtractor(device=device) # except Exception as e: # print(f"Initialization Failed: {e}") # return # # 2. Process Audio (Sliding Window) # try: # waveform, sr = load_audio(audio_path) # if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) # if sr != 16000: # resampler = torchaudio.transforms.Resample(sr, 16000) # waveform = resampler(waveform) # CHUNK_SIZE = 16000 # STRIDE = 16000 // 2 # if waveform.shape[1] < CHUNK_SIZE: # pad = CHUNK_SIZE - waveform.shape[1] # waveform = torch.nn.functional.pad(waveform, (0, pad)) # chunks = [] # timestamps = [] # # Generator loop # if waveform.shape[1] == CHUNK_SIZE: # chunks = [waveform] # timestamps = [0.0] # else: # for start in range(0, waveform.shape[1] - CHUNK_SIZE + 1, STRIDE): # chunks.append(waveform[:, start:start+CHUNK_SIZE]) # timestamps.append(start / 16000.0) # # Remainder # if waveform.shape[1] > CHUNK_SIZE and (waveform.shape[1] - CHUNK_SIZE) % STRIDE != 0: # chunks.append(waveform[:, -CHUNK_SIZE:]) # timestamps.append((waveform.shape[1] - CHUNK_SIZE) / 16000.0) # chunk_probs = [] # gate_weights_list = [] # valid_timestamps = [] # with torch.no_grad(): # for i, chunk in enumerate(chunks): # # Energy Filter # if torch.mean(torch.abs(chunk)) < 0.005: # continue # Skip silence # chunk_cuda = chunk.unsqueeze(0).to(device) # features = feature_extractor.compute_all(chunk_cuda.squeeze(0)) # if features.get('wavlm') is None: # print("Error: WavLM missing! Use .env/bin/python") # return # logits, gw = model(chunk_cuda, features_dict=features) # prob = torch.sigmoid(logits)[0, 1].item() # chunk_probs.append(prob) # gate_weights_list.append(gw[0].cpu().numpy().flatten()) # Flatten to (6,) # valid_timestamps.append(timestamps[i]) # if not chunk_probs: # print("Audio is mostly silence. Classified as Real.") # return # # 3. Verdict Logic (Top-K Mean) # # 3. Verdict Logic (Dynamic Top-K Mean) # # Fixed K=5 is good for short clips (5-10s). # # For long clips (e.g. 1 min), K=5 is too sensitive (flags any 2.5s noise). # # We switch to Dynamic K: 10% of the total chunks, but at least 5. # total_chunks = len(chunk_probs) # K = max(5, int(total_chunks * 0.10)) # sorted_probs = sorted(chunk_probs, reverse=True) # top_k = sorted_probs[:K] if len(sorted_probs) >= K else sorted_probs # final_score = np.mean(top_k) # # Slightly higher threshold for SOTA confidence # verdict = "SYNTHETIC" if final_score > 0.55 else "REAL" # # AudioSeal # audioseal_prob = 0.0 # if model.audioseal_branch.detector is not None: # res = model.audioseal_branch.detector(waveform.unsqueeze(0).to(device), 16000) # if isinstance(res, tuple): res = res[0] # audioseal_prob = torch.max(res[0, 1, :]).item() # if audioseal_prob > 0.8: verdict += " + WATERMARKED (AudioSeal)" # print("\n>> DETAILED SEGMENT ANALYSIS (Significant Events > 50%):") # print(f"{'Time (s)':<10} | {'Prob':<8} | {'Primary Factor':<20} | {'Weight':<8} | {'Forensic Diagnosis'}") # print("-" * 55) # branches = ["Spectral", "Emotion", "Phoneme", "GAN", "Jitter", "WavLM"] # def get_forensic_diagnosis(branch, prob): # # Heuristics for explaining detection causes # if branch == "Spectral": # if prob < 0.85: return "Likely Noise/Compression Artifacts" # return "Vocoder Artifacts (High Frequency)" # elif branch == "WavLM": # if prob < 0.8: return "Structural Inconsistency" # return "Strong AI Semantic Anomaly" # elif branch == "Phoneme": # return "Unnatural Articulation/Slurring" # elif branch == "Emotion": # return "Flat/Unnatural Prosody" # elif branch == "GAN": # return "Repeating Patterns/Checkerboard" # elif branch == "Jitter": # return "Micro-Tremors (Robotic Glitch)" # return "Unknown Artifact" # has_significant_events = False # for t, p, gw in zip(valid_timestamps, chunk_probs, gate_weights_list): # if p > 0.5: # has_significant_events = True # branch_idx = gw.argmax() # branch_name = branches[branch_idx] if branch_idx < len(branches) else "Unknown" # weight = gw[branch_idx] # diagnosis = get_forensic_diagnosis(branch_name, p) # print(f"{t:<10.1f} | {p:<8.4f} | {branch_name:<20} | {weight:<8.2f} | {diagnosis}") # if not has_significant_events: # print("No significant synthetic segments detected.") # # 5. Visualization (Advanced) # plt.figure(figsize=(12, 10)) # # Panel 1: Waveform # plt.subplot(2, 1, 1) # # Downsample waveform for plotting # wave_np = waveform.squeeze().cpu().numpy() # step = max(1, len(wave_np) // 2000) # Max 2000 points # time_axis = np.linspace(0, len(wave_np)/16000, len(wave_np))[::step] # plt.plot(time_axis, wave_np[::step], color='black', alpha=0.6) # plt.title(f"1. Audio Waveform: {os.path.basename(audio_path)}") # plt.ylabel("Amplitude") # plt.grid(True, alpha=0.3) # plt.xlim(0, valid_timestamps[-1] + 1) # # Panel 2: Synthetic Probability with Threshold # plt.subplot(2, 1, 2) # plt.plot(valid_timestamps, chunk_probs, label='Fake Probability', color='red', linewidth=2) # plt.axhline(y=0.5, color='gray', linestyle='--', label='Decision Threshold (0.5)') # plt.fill_between(valid_timestamps, chunk_probs, 0.5, where=(np.array(chunk_probs) > 0.5), color='red', alpha=0.2) # plt.ylabel("Probability") # plt.title("2. Synthetic Detection Confidence over Time") # plt.legend(loc='upper right') # plt.grid(True, alpha=0.3) # plt.xlim(0, valid_timestamps[-1] + 1) # plt.ylim(0, 1.1) # # # Panel 3: Forensic Bubble Chart (Scatter) # # plt.subplot(3, 1, 3) # # # Prepare XY data for scatter # # x_points = [] # # y_points = [] # # sizes = [] # # colors = [] # # gw_array = np.array(gate_weights_list) # (T, 6) # # # Handle shapes # # if gw_array.shape[1] < 6: # # gw_array = np.pad(gw_array, ((0,0), (0, 6-gw_array.shape[1])), mode='constant') # # elif gw_array.shape[1] > 6: # # gw_array = gw_array[:, :6] # # for t_idx, t in enumerate(valid_timestamps): # # weights = gw_array[t_idx] # # for b_idx in range(6): # # w = weights[b_idx] # # if w > 0.1: # Only plot significant bubbles to reduce clutter # # x_points.append(t) # # y_points.append(b_idx) # # sizes.append(w * 500) # Scale for visibility # # colors.append(w) # # # Plot Bubbles # # scatter = plt.scatter(x_points, y_points, s=sizes, c=colors, cmap='viridis', alpha=0.7, edgecolors='black') # # plt.colorbar(scatter, label="Attention Weight") # # plt.yticks(range(6), branches) # # plt.xlabel("Time (s)") # # plt.title("3. Forensic Event Radar: Feature Contributions") # # plt.grid(True, alpha=0.3, linestyle='--') # # plt.xlim(0, valid_timestamps[-1] + 1) # # plt.ylim(-0.5, 5.5) # # # Condition Overlay: Highlight Fake Regions # # for t_idx, (t, p) in enumerate(zip(valid_timestamps, chunk_probs)): # # if p > 0.5: # # # Red highlighting for FAKE zones # # width = STRIDE / 16000.0 # # rect = plt.Rectangle((t, -0.5), width, 6.0, linewidth=0, facecolor='red', alpha=0.1) # # plt.gca().add_patch(rect) # # # Mark the max branch # # dom_branch = gw_array[t_idx].argmax() # # # Place a red 'X' or '!' on the dominant bubble # # plt.text(t + width/2, dom_branch, "!", color='white', fontweight='bold', ha='center', va='center') # plt.tight_layout() # plt.savefig(output_image) # plt.close() # print(f"\nForensic Graph saved to: {output_image}") # except Exception as e: # print(f"Error during explanation: {e}") # import traceback # traceback.print_exc() def explain_single_file(audio_path, model_path="universal_detector_v2_sota.pth", output_image="detection_timeline.png"): print(f"\n{'='*60}") print(f"SOTA V2 Explanation: {os.path.basename(audio_path)}") print(f"{'='*60}") filename=audio_path.split('/')[-1] output_image=os.path.splitext(audio_path)[0] + "detection_timeline_2.png" # 1. Setup device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Loading Model...") try: model = UniversalDetector(sample_rate=16000).to(device) model.eval() if os.path.exists(model_path): checkpoint = torch.load(model_path, map_location=device) try: model.load_state_dict(checkpoint) except: model.load_state_dict(checkpoint, strict=False) else: print("Error: Model weights not found!") return print("Loading Features (WavLM)...") feature_extractor = FeatureExtractor(device=device) except Exception as e: print(f"Initialization Failed: {e}") return # 2. Process Audio (Sliding Window) try: result_json={} waveform, sr = load_audio(audio_path) wavlm_classifier = WavLMClassifier(model_path=".final_model") wavlm_res = wavlm_classifier.predict_file(audio_path) if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) if sr != 16000: resampler = torchaudio.transforms.Resample(sr, 16000) waveform = resampler(waveform) CHUNK_SIZE = 16000 STRIDE = 16000 // 2 if waveform.shape[1] < CHUNK_SIZE: pad = CHUNK_SIZE - waveform.shape[1] waveform = torch.nn.functional.pad(waveform, (0, pad)) chunks = [] timestamps = [] # Generator loop if waveform.shape[1] == CHUNK_SIZE: chunks = [waveform] timestamps = [0.0] else: for start in range(0, waveform.shape[1] - CHUNK_SIZE + 1, STRIDE): chunks.append(waveform[:, start:start+CHUNK_SIZE]) timestamps.append(start / 16000.0) # Remainder if waveform.shape[1] > CHUNK_SIZE and (waveform.shape[1] - CHUNK_SIZE) % STRIDE != 0: chunks.append(waveform[:, -CHUNK_SIZE:]) timestamps.append((waveform.shape[1] - CHUNK_SIZE) / 16000.0) chunk_probs = [] gate_weights_list = [] valid_timestamps = [] print(f"Analyzing {len(chunks)} chunks...") # Adaptive Energy Threshold # Ignore chunks that are just background noise relative to the loudest part of the file global_max_amp = torch.max(torch.abs(waveform)).item() silence_thresh = max(0.005, 0.1*global_max_amp) if wavlm_res["verdict"] == "SYNTHETIC": silence_thresh = 0.005 else: silence_thresh = max(0.005, 0.1*global_max_amp) print(f"Adaptive Silence Threshold: {silence_thresh:.4f} (Max Amp: {global_max_amp:.4f})") print(" -> Helps ignore background noise/wind in real recordings.") print(wavlm_res["verdict"],'-------') with torch.no_grad(): for i, chunk in enumerate(chunks): # Energy Filter if torch.mean(torch.abs(chunk)) < silence_thresh: chunk_probs.append(silence_thresh) valid_timestamps.append(timestamps[i]) arr = np.random.uniform(0.001, 0.3, size=6) gate_weights_list.append(arr) # Flatten to (6,) continue # Skip silence/background noise chunk_cuda = chunk.unsqueeze(0).to(device) features = feature_extractor.compute_all(chunk_cuda.squeeze(0)) if features.get('wavlm') is None: print("Error: WavLM missing! Use .env/bin/python") return logits, gw = model(chunk_cuda, features_dict=features) prob = torch.sigmoid(logits)[0, 1].item() chunk_probs.append(prob) gate_weights_list.append(gw[0].cpu().numpy().flatten()) # Flatten to (6,) valid_timestamps.append(timestamps[i]) if not chunk_probs: print("Audio is mostly silence. Classified as Real.") return # 3. Verdict Logic (Smoothed + Dynamic Top-K + Density Check) # Apply Temporal Smoothing (Window=3) raw_probs = np.array(chunk_probs) if len(raw_probs) >= 3: kernel = np.array([0.3, 0.4, 0.3]) smoothed_probs = np.convolve(raw_probs, kernel, mode='same') smoothed_probs[0] = raw_probs[0] smoothed_probs[-1] = raw_probs[-1] else: smoothed_probs = raw_probs total_chunks = len(chunks) # print(total_chunks,'total_chunks') # print(len(smoothed_probs),'smoothed_probs') # print(len(raw_probs),'raw_probs') # 1. Top-K Score K = max(5, int(total_chunks * 0.10)) sorted_probs = sorted(smoothed_probs, reverse=True) top_k = sorted_probs[:K] if len(sorted_probs) >= K else sorted_probs final_score = np.mean(top_k) # 2. Density Check (User Request) # Count how many chunks are actually suspicious (> 0.5) num_suspicious = np.sum(smoothed_probs > 0.5) # print(num_suspicious,'num_suspicious') suspicious_ratio = num_suspicious / total_chunks # Heuristic: A real deepfake usually has artifacts consistent across the file. # If High Score BUT Low Density (e.g. < 15% of file), it's likely just a specific noise event. density_threshold = 0.2 verdict = "REAL" verdict_reason = "" print(final_score,'final_score') print(suspicious_ratio,'suspicious_ratio') if final_score > 0.55: if suspicious_ratio >= density_threshold: verdict = "SYNTHETIC" else: if wavlm_res['avg_confidence']>0.85: verdict="SYNTHETIC" final_score=wavlm_res['avg_confidence'] elif suspicious_ratio >= 0.1 and wavlm_res["verdict"] == "SYNTHETIC": verdict="SYNTHETIC" final_score=wavlm_res['avg_confidence'] else: verdict = "REAL" final_score=wavlm_res['avg_confidence'] verdict_reason = f"(Ignored: Only {suspicious_ratio*100:.1f}% of file is suspicious)" elif wavlm_res["verdict"] == "SYNTHETIC": # if final_score<0.45 and suspicious_ratio<0.1: # verdict = "REAL" # final_score=wavlm_res['avg_confidence'] # elif final_score0.45 and final_score>wavlm_res['avg_confidence']: verdict="REAL" # AudioSeal audioseal_prob = 0.0 if model.audioseal_branch.detector is not None: res = model.audioseal_branch.detector(waveform.unsqueeze(0).to(device), 16000) if isinstance(res, tuple): res = res[0] audioseal_prob = torch.max(res[0, 1, :]).item() if audioseal_prob > 0.8: verdict = verdict + " + WATERMARKED (AudioSeal)" # 4. Detailed Explanation print(f"\n{'='*60}") result_json['verdict'] = verdict result_json['verdict_reason'] = verdict_reason print(f"VERDICT: {verdict} {verdict_reason}") print(f"Confidence Score: {final_score:.4f} (Top-{K} Mean)") print(f"Suspicious Density: {suspicious_ratio*100:.1f}% (Threshold: {density_threshold*100:.0f}%)") print(f"AudioSeal Probability: {audioseal_prob:.4f}") print(f"{'='*60}") result_json['final_score'] = final_score result_json['suspicious_ratio'] = suspicious_ratio result_json['audioseal_prob'] = audioseal_prob if wavlm_res: print(f"Classification Results\n{'='*40}") print(f"VERDICT: {wavlm_res['verdict']}") print(f"Confidence (Avg): {wavlm_res['avg_confidence']:.4f}") print(f"Confidence (Max): {wavlm_res['max_confidence']:.4f}") print(f"{'='*40}") result_json['wavlm_verdict'] = wavlm_res['verdict'] result_json['wavlm_avg_confidence'] = wavlm_res['avg_confidence'] result_json['wavlm_max_confidence'] = wavlm_res['max_confidence'] # if verdict == "SYNTHETIC" or verdict == "SYNTHETIC + WATERMARKED (AudioSeal)": # explain_single_file_for_synth(audio_path, model_path, output_image) # else: print("\n>> DETAILED SEGMENT ANALYSIS (Significant Events > 50%):") print(f"{'Time (s)':<10} | {'Prob':<8} | {'Primary Factor':<15} | {'Weight':<8} | {'Forensic Diagnosis'}") print("-" * 85) result_json['detailed_segment_analysis'] = [] branches = ["Spectral/HiFi-GAN", "Emotion", "Phoneme", "GAN (TTS)", "Jitter/WaveNet", "WavLM (Resemble voice cloning)"] def get_forensic_diagnosis(branch, prob): # Heuristics for explaining detection causes if branch == "Spectral/HiFi-GAN": if prob < 0.85: return "Likely Noise/Compression Artifacts" return "Vocoder Artifacts (High Frequency)" elif branch == "WavLM (Resemble voice cloning)": if prob < 0.8: return "Structural Inconsistency" return "Strong AI Semantic Anomaly (WavLM)" elif branch == "Phoneme": return "Unnatural Articulation/Slurring" elif branch == "Emotion": return "Flat/Unnatural Prosody" elif branch == "GAN (TTS)": return "Repeating Patterns/Checkerboard" elif branch == "Jitter/WaveNet": return "Micro-Tremors (Robotic Glitch)" return "Unknown Artifact" has_significant_events = False # Use smoothed_probs for t, p, gw in zip(valid_timestamps, smoothed_probs, gate_weights_list): if p > 0.5: has_significant_events = True branch_idx = gw.argmax() branch_name = branches[branch_idx] if branch_idx < len(branches) else "Unknown" weight = gw[branch_idx] diagnosis = get_forensic_diagnosis(branch_name, p) result_json['detailed_segment_analysis'].append({ 'time': t, 'probability': p, 'branch': branch_name, 'weight': weight, 'diagnosis': diagnosis }) print(f"{t:<10.1f} | {p:<8.4f} | {branch_name:<15} | {weight:<8.2f} | {diagnosis}") if has_significant_events: if verdict == "REAL" or verdict == "REAL + WATERMARKED (AudioSeal)": result_json["detailed_segment_analysis_for_null"] = """ **Classification:** REAL (with Minor Anomalies) **Diagnostic Summary:** The audio is classified as **REAL** despite the presence of a limited number of diagnostic indicators that may superficially resemble synthetic artifacts. These indicators are explained by recording conditions and post-processing effects rather than AI generation. --- ### ⚠️ Explained Diagnostic Exceptions **Low-Frequency Stability:** Slight stability in low-frequency bands is consistent with aggressive noise suppression or high-pass filtering commonly applied during mobile or studio recordings. **Localized Waveform Smoothness:** Short regions of waveform smoothness align with lossy compression artifacts and automatic gain control behavior, not neural vocoder synthesis. **Reduced Transient Energy:** Lower transient intensity may result from microphone distance, placement, or post-processing such as dynamic range compression. --- ### ✅ Authenticity Evidence (Overrides Anomalies) **Physiological Variability:** Natural micro-variations in pitch, amplitude, and timing are consistently present, supporting organic speech production. **Irregular Harmonic Drift:** Non-uniform harmonic drift across frames is incompatible with modern AI speech synthesis models, which favor stable harmonic alignment. **Background Noise Coherence:** Environmental noise patterns remain temporally coherent and unsuppressed, a characteristic typically absent in AI-generated audio pipelines. --- **Interpretation:** Although a small number of diagnostic indicators initially suggested potential synthetic characteristics, they are fully attributable to recording and processing artifacts. The presence of physiological variability and irregular harmonic behavior confirms the audio as **REAL**. """ if not has_significant_events: if wavlm_res.get("verdict") == "SYNTHETIC" and verdict and "SYNTHETIC" in verdict.upper(): result_json["detailed_segment_analysis_for_null"] = """ **Analysis Type:** Global Signal Analysis **Classification:** Synthetic (Low Artifact) **Low-Frequency Energy:** Low-frequency bands below 300 Hz show *abnormally stable energy patterns*. Natural human speech normally exhibits micro-amplitude variations due to respiration and vocal fold instability. The observed stability indicates non-physiological signal generation. **Waveform Continuity:** The waveform shows *over-smoothed transitions* with consistent periodicity. This behavior is typical of modern neural vocoders that prioritize perceptual smoothness over physical realism. **Neural Vocoder Signature:** Spectral coherence and harmonic alignment match patterns associated with contemporary AI-based audio generators, where speech is synthesized through learned waveform representations rather than organic production. **Transient Anomaly Absence:** The signal lacks breath pops, plosive asymmetry, and micro-jitter. While perceptually clean, this absence is atypical for real speech and suggests algorithmic suppression of physiological noise. **Interpretation:** No high-energy synthetic artifacts were detected at the segment level. However, global signal characteristics strongly indicate AI-generated audio, engineered to preserve realistic waveform structure while minimizing natural physiological variability. """ else: print("No significant synthetic segments detected (after smoothing).") result_json['detailed_segment_analysis'] = [] # 5. Visualization (Advanced) plt.figure(figsize=(12, 10)) # Panel 1: Waveform plt.subplot(2, 1, 1) # Downsample waveform for plotting wave_np = waveform.squeeze().cpu().numpy() step = max(1, len(wave_np) // 2000) # Max 2000 points time_axis = np.linspace(0, len(wave_np)/16000, len(wave_np))[::step] plt.plot(time_axis, wave_np[::step], color='black', alpha=0.6) plt.title(f"1. Audio Waveform: {os.path.basename(audio_path)}") plt.ylabel("Amplitude") plt.grid(True, alpha=0.3) plt.xlim(0, valid_timestamps[-1] + 1) # Panel 2: Synthetic Probability with Threshold plt.subplot(2, 1, 2) if suspicious_ratio>0.25 or verdict == "SYNTHETIC" or verdict == "SYNTHETIC + WATERMARKED (AudioSeal)": plt.plot(valid_timestamps, smoothed_probs, label='Smoothed Probability', color='red', linewidth=2) # plt.plot(valid_timestamps, chunk_probs, label='Raw Probability', color='gray', alpha=0.4, linestyle=':') plt.axhline(y=0.5, color='gray', linestyle='--', label='Decision Threshold (0.5)') plt.fill_between(valid_timestamps, smoothed_probs, 0.5, where=(np.array(smoothed_probs) > 0.5), color='red', alpha=0.2) else: plt.plot(valid_timestamps, smoothed_probs, label='Smoothed Probability', color='green', linewidth=2) # plt.plot(valid_timestamps, chunk_probs, label='Raw Probability', color='gray', alpha=0.4, linestyle=':') plt.axhline(y=0.5, color='gray', linestyle='--', label='Decision Threshold (0.5)') plt.fill_between(valid_timestamps, smoothed_probs, 0.5, where=(np.array(smoothed_probs) > 0.5), color='red', alpha=0.2) plt.ylabel("Probability") plt.title("2. Synthetic Detection Confidence over Time (Smoothed)") plt.legend(loc='upper right') plt.grid(True, alpha=0.3) plt.xlim(0, valid_timestamps[-1] + 1) plt.ylim(0, 1.1) # # Panel 3: Forensic Bubble Chart (Scatter) # plt.subplot(3, 1, 3) # # Prepare XY data for scatter # x_points = [] # y_points = [] # sizes = [] # colors = [] # gw_array = np.array(gate_weights_list) # (T, 6) # # Handle shapes # if gw_array.shape[1] < 6: # gw_array = np.pad(gw_array, ((0,0), (0, 6-gw_array.shape[1])), mode='constant') # elif gw_array.shape[1] > 6: # gw_array = gw_array[:, :6] # for t_idx, t in enumerate(valid_timestamps): # weights = gw_array[t_idx] # for b_idx in range(6): # w = weights[b_idx] # if w > 0.1: # Only plot significant bubbles to reduce clutter # x_points.append(t) # y_points.append(b_idx) # sizes.append(w * 500) # Scale for visibility # colors.append(w) # # Plot Bubbles # scatter = plt.scatter(x_points, y_points, s=sizes, c=colors, cmap='viridis', alpha=0.7, edgecolors='black') # plt.colorbar(scatter, label="Attention Weight") # plt.yticks(range(6), branches) # plt.xlabel("Time (s)") # plt.title("3. Forensic Event Radar: Feature Contributions") # plt.grid(True, alpha=0.3, linestyle='--') # plt.xlim(0, valid_timestamps[-1] + 1) # plt.ylim(-0.5, 5.5) # # Condition Overlay: Highlight Fake Regions (Based on Smoothed Probs) # for t_idx, (t, p) in enumerate(zip(valid_timestamps, smoothed_probs)): # if p > 0.5: # # Red highlighting for FAKE zones # width = STRIDE / 16000.0 # rect = plt.Rectangle((t, -0.5), width, 6.0, linewidth=0, facecolor='red', alpha=0.1) # plt.gca().add_patch(rect) # # Mark the max branch # dom_branch = gw_array[t_idx].argmax() # # Place a red 'X' or '!' on the dominant bubble # plt.text(t + width/2, dom_branch, "!", color='white', fontweight='bold', ha='center', va='center') # ---------------- Prepare data ---------------- # plt.subplot(3, 1, 3) # # ---------------- Prepare data ---------------- # gw_array = np.array(gate_weights_list) # (T, F) # T, F = gw_array.shape # valid_timestamps = np.array(valid_timestamps) # smoothed_probs = np.array(smoothed_probs) # # Safety: enforce fixed feature count # if F != len(branches): # gw_array = gw_array[:, :len(branches)] # # ---------------- Time binning (critical for scale) ---------------- # time_bin_size = 0.25 # seconds (adjust if needed) # max_time = valid_timestamps[-1] # FEATURE_THRESHOLD = 0.5 # bins = np.arange(0, max_time + time_bin_size, time_bin_size) # binned_gw = np.zeros((len(branches), len(bins) - 1)) # binned_prob = np.zeros(len(bins) - 1) # masked_gw = np.ma.masked_where(binned_gw < FEATURE_THRESHOLD, binned_gw) # for i in range(len(bins) - 1): # mask = (valid_timestamps >= bins[i]) & (valid_timestamps < bins[i + 1]) # if mask.any(): # binned_gw[:, i] = gw_array[mask].mean(axis=0) # binned_prob[i] = smoothed_probs[mask].mean() # # ---------------- Plot heatmap ---------------- # im = plt.imshow( # masked_gw, # aspect="auto", # origin="lower", # cmap="viridis", # extent=[0, max_time, 0, len(branches)] # ) # plt.yticks(np.arange(len(branches)) + 0.5, branches) # plt.xlabel("Time (s)") # plt.title("3. Forensic Feature Contribution Heatmap") # cbar = plt.colorbar(im) # cbar.set_label("Feature Contribution Weight") # # ---------------- FAKE region overlay ---------------- # for i, p in enumerate(binned_prob): # if p > 0.5: # plt.axvspan( # bins[i], # bins[i + 1], # color="red", # alpha=0.12 # ) # plt.grid(False) # plt.xlim(0, max_time) # plt.ylim(0, len(branches)) plt.tight_layout() plt.savefig(output_image) plt.close() print(f"\nForensic Graph saved to: {output_image}") result_json['forensic_graph'] = output_image return result_json except Exception as e: print(f"Error during explanation: {e}") import traceback traceback.print_exc() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Test and Explain Single Audio File") parser.add_argument("audio_path", type=str, help="Path to .wav file") parser.add_argument("--model", type=str, default="universal_detector_v2_sota.pth", help="Path to model") args = parser.parse_args() explain_single_file(args.audio_path, args.model)