Speaker Proxy Network (RVQ → Speaker Embedding)
A lightweight differentiable surrogate that maps Qwen3-TTS RVQ embeddings directly to speaker embeddings, bypassing the expensive audio-decoding → feature-extraction pipeline during voice-conversion training.
⚠️ Note: This repository contains only the Speaker Proxy. The full RVQ proxy (speaker + wav2vec + mel) is a separate effort. This checkpoint is the standalone speaker branch, trained with a pure contrastive objective on real speaker labels.
Why a Speaker Proxy?
During voice-conversion training, the standard pipeline is:
model logits → argmax → RVQ tokens → decoder → waveform → ECAPA-TDNN → speaker embedding
This pipeline is non-differentiable because of argmax and the audio decoder. The Speaker Proxy replaces it with:
model logits → softmax → RVQ sum embedding → SpeakerProxyECAPA → L2-normalized speaker embedding
Every step from the logits onward is now differentiable, enabling end-to-end backpropagation through the entire voice-conversion objective.
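The softmax relaxation can be sketched as follows. This is a minimal illustration, not the repository's code: the helper name `soft_rvq_sum`, the toy codebook size `V = 32`, and the random embedding tables are assumptions made for the example. Only the 16 codebooks and the 2048-dim embedding come from this README.

```python
import torch

def soft_rvq_sum(logits: torch.Tensor, tables: torch.Tensor) -> torch.Tensor:
    """Expected RVQ sum embedding under the softmax distribution.

    logits: [B, T, K, V] per-codebook logits (K codebooks, V codes each)
    tables: [K, V, D]    one embedding table per codebook
    returns [B, T, D]
    """
    probs = logits.softmax(dim=-1)  # [B, T, K, V]
    # Probability-weighted embedding per codebook, summed over all codebooks.
    return torch.einsum("btkv,kvd->btd", probs, tables)

B, T, K, V, D = 2, 5, 16, 32, 2048  # V is a toy value chosen for the demo
logits = torch.randn(B, T, K, V, requires_grad=True)
tables = torch.randn(K, V, D)       # random stand-ins for the real tables

e_soft = soft_rvq_sum(logits, tables)  # [B, T, D]
e_soft.pow(2).mean().backward()        # gradients reach the logits
```

If `probs` is replaced by one-hot vectors of the argmax tokens, the same expression reduces to the ordinary embedding-lookup sum, so the soft version is a relaxation of the hard token path.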
Architecture
SpeakerProxyECAPA: an ECAPA-TDNN-style network adapted for RVQ-sum inputs.
| Component | Details |
|---|---|
| Input | [B, T, 2048] RVQ sum embedding (sum of 16 learned codebook embeddings) |
| Front-end | Conv1d projection + SE-Res2Blocks (dilations 2, 3, 4) |
| Pooling | Attentive Statistics Pooling (mean + std, attention-weighted) |
| Bottleneck | FC → 192-dim |
| Output | L2-normalized 192-dim speaker embedding |
| Parameters | ~4.6M |
The architecture mirrors the original SpeechBrain ECAPA-TDNN but is trained end-to-end on RVQ inputs rather than on acoustic features (e.g. mel filterbanks) extracted from audio.
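For readers unfamiliar with the pooling stage, here is a simplified sketch of attentive statistics pooling followed by the 192-dim bottleneck. It is an assumption-laden illustration rather than the `SpeakerProxyECAPA` implementation: the class name `AttentiveStatsPooling`, the attention width `attn_dim = 128`, and pooling directly over 512 channels are choices made for the example, and the global-context features used in full ECAPA-TDNN attention are omitted.

```python
import torch
import torch.nn as nn

class AttentiveStatsPooling(nn.Module):
    """Attention-weighted mean and std over time (simplified sketch)."""

    def __init__(self, channels: int, attn_dim: int = 128):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Conv1d(channels, attn_dim, kernel_size=1),
            nn.Tanh(),
            nn.Conv1d(attn_dim, channels, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, T] frame-level features
        w = self.attn(x).softmax(dim=-1)          # per-channel weights, sum to 1 over T
        mean = (w * x).sum(dim=-1)                # [B, C]
        var = (w * x.pow(2)).sum(dim=-1) - mean.pow(2)
        std = var.clamp(min=1e-6).sqrt()          # [B, C]
        return torch.cat([mean, std], dim=-1)     # [B, 2C]

pool = AttentiveStatsPooling(channels=512)
frames = torch.randn(4, 512, 50)                  # [B, C, T] dummy features
pooled = pool(frames)                             # [B, 1024]
# Bottleneck FC to 192 dims, then L2 normalization as in the table above.
embedding = nn.functional.normalize(nn.Linear(1024, 192)(pooled), dim=-1)
```

Because the weights are a softmax over the time axis, the output size is independent of `T`, which is what lets the network handle variable-length RVQ sequences.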
Training
| Detail | Value |
|---|---|
| Dataset | lonesamurai/emilia_clean_10k (10,000 clips, 200 speakers) |
| Train / Val split | 8,000 / 2,000 clips |
| Epochs | ~200 |
| Loss | Pure contrastive: (1 − cos)² alignment + λ·ReLU(cos − margin)² repulsion |
| λ (repel) | 5.0 |
| Optimizer | AdamW, lr = 1e-4, weight_decay = 1e-5 |
| Best val separation | 0.8141 |
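
The loss row above can be read as the following in-batch sketch. Several details are assumptions, since the README does not state them: how pairs are formed (here, all same-speaker pairs in the batch are aligned and all different-speaker pairs are repelled), the averaging over pairs, the function name `contrastive_speaker_loss`, and the `margin` value of 0.2, which is purely hypothetical. Only the two squared terms and λ = 5.0 come from the table.

```python
import torch

def contrastive_speaker_loss(emb, labels, lam=5.0, margin=0.2):
    """emb: [B, D] L2-normalized embeddings; labels: [B] integer speaker ids.

    margin=0.2 is a hypothetical value; the README does not state the margin.
    """
    cos = emb @ emb.T                                    # pairwise cosine, [B, B]
    same = labels[:, None] == labels[None, :]
    eye = torch.eye(len(labels), device=emb.device).bool()
    pos, neg = same & ~eye, ~same                        # self-pairs are excluded
    zero = cos.new_zeros(())
    # Alignment: pull same-speaker pairs toward cos = 1.
    align = (1 - cos[pos]).pow(2).mean() if pos.any() else zero
    # Repulsion: penalize different-speaker pairs whose cos exceeds the margin.
    repel = torch.relu(cos[neg] - margin).pow(2).mean() if neg.any() else zero
    return align + lam * repel

emb = torch.nn.functional.normalize(torch.randn(8, 192), dim=-1)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])          # two clips per speaker
loss = contrastive_speaker_loss(emb, labels)
```

Under this reading, the loss is zero when same-speaker embeddings coincide and different-speaker similarities stay at or below the margin.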
Validation performance (contrastive separation metric)
- Best checkpoint: epoch ~140, separation = 0.8141
- Final checkpoint: epoch ~197, separation ≈ 0.73 (plateaued)
Comparison with Original ECAPA-TDNN
Tested on 5 seen + 5 unseen speakers from EMILIA:
| Metric | SpeakerProxy (Ours) | Original ECAPA-TDNN |
|---|---|---|
| Seen-Seen off-diag mean | 0.050 | 0.094 |
| Unseen-Unseen off-diag mean | −0.026 | 0.060 |
| Seen-Unseen off-diag mean | −0.026 | 0.033 |
| All off-diag mean | −0.009 | 0.053 |
| Off-diag std | 0.156 | 0.098 |
| Worst confusion (max) | 0.420 | 0.327 |
| Per-speaker separation (seen avg) | 0.992 | 0.940 |
| Per-speaker separation (unseen avg) | 1.024 | 0.955 |
Takeaway: Our proxy achieves lower average cross-speaker similarity than the original audio-based ECAPA, especially on unseen speakers (mean similarity −0.026 vs. 0.060). The trade-off is higher variance (off-diagonal std 0.156 vs. 0.098): a few outlier pairs show stronger confusion (worst-case similarity 0.420 vs. 0.327), but on average speaker pairs are pushed farther apart.
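The off-diagonal rows of the table can be reproduced with a computation along these lines. This is a sketch under assumptions: the helper name `off_diagonal_stats` is made up for the example, and it assumes one embedding per speaker (for instance a mean over that speaker's utterances), which the README does not specify. The per-speaker separation metric is not defined here and is therefore not computed.

```python
import torch

def off_diagonal_stats(centroids: torch.Tensor) -> dict:
    """centroids: [S, D], one embedding per speaker (assumed aggregation)."""
    c = torch.nn.functional.normalize(centroids, dim=-1)
    sim = c @ c.T                                         # [S, S] cosine similarities
    mask = ~torch.eye(sim.shape[0], device=sim.device).bool()
    off = sim[mask]                                       # all speaker pairs with i != j
    return {
        "mean": off.mean().item(),
        "std": off.std().item(),
        "max": off.max().item(),                          # worst confusion
    }

# Random demo data standing in for 10 speakers' embeddings.
sim_stats = off_diagonal_stats(torch.randn(10, 192))
```

Restricting `sim` to the seen or unseen rows and columns before masking would give the seen-seen, unseen-unseen, and seen-unseen variants.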
Checkpoints
| File | Description |
|---|---|
| speaker_proxy_10k_best.pt | Best checkpoint (val separation = 0.8141, ~epoch 140) |
The checkpoint contains:
- model_state_dict: full network weights
- config: architecture hyperparameters
- epoch: training epoch at save time
- val_separation: best validation metric
Usage
import torch
from exiv.components.models.qwen3_tts.sern.speaker_proxy_ecapa import SpeakerProxyECAPA
# Load checkpoint
checkpoint = torch.load("speaker_proxy_10k_best.pt", map_location="cpu")
config = checkpoint["config"]
# Build model
proxy = SpeakerProxyECAPA(
    input_dim=config["input_dim"],    # 2048
    embed_dim=config["embed_dim"],    # 192
    channels=config["channels"],      # 512
    num_blocks=config["num_blocks"],  # 3
)
proxy.load_state_dict(checkpoint["model_state_dict"])
proxy.eval().cuda()
# Forward pass: E_rvq is the sum of the embeddings looked up in the 16 RVQ codebook tables
# E_rvq: [B, T, 2048] from Qwen3-TTS RVQ tokens
speaker_embedding = proxy(E_rvq) # [B, 192], L2-normalized
Computing RVQ sum embeddings from Qwen3-TTS tokens
# Extract the 16 embedding tables from Qwen3-TTS
embedding_tables = [
    model.model.embed_tokens[i].weight for i in range(16)
]
# tokens: [B, T, 16] integer RVQ indices
E_rvq = torch.stack([
    embedding_tables[i][tokens[..., i]] for i in range(16)
], dim=-1).sum(dim=-1)  # [B, T, 2048]
Requirements
- PyTorch ≥ 2.0
- See Exiv for full integration with the Qwen3-TTS SERN adapter
License
MIT