Speaker Proxy Network (RVQ → Speaker Embedding)

A lightweight differentiable surrogate that maps Qwen3-TTS RVQ embeddings directly to speaker embeddings, bypassing the expensive audio-decoding → feature-extraction pipeline during voice-conversion training.

⚠️ Note: This repository contains only the Speaker Proxy. The full RVQ proxy (speaker + wav2vec + mel) is a separate effort. This checkpoint is the standalone speaker branch, trained with a pure contrastive objective on real speaker labels.


Why a Speaker Proxy?

During voice-conversion training, the standard pipeline is:

model logits → argmax → RVQ tokens → decoder → waveform → ECAPA-TDNN → speaker embedding

This pipeline is non-differentiable because of argmax and the audio decoder. The Speaker Proxy replaces it with:

model logits → softmax → RVQ sum embedding → SpeakerProxyECAPA → L2-normalized speaker embedding

Every step from the softmax onward is differentiable, so the voice-conversion objective can be backpropagated end to end, all the way to the model logits.
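The softmax → RVQ sum embedding step can be sketched as a probability-weighted average over each codebook's embedding table, summed across codebooks. This is a minimal illustration, not the repository's implementation: the function name `soft_rvq_sum_embedding`, the `[B, T, K, V]` logits layout, and the toy sizes are assumptions.

```python
import torch
import torch.nn.functional as F


def soft_rvq_sum_embedding(logits: torch.Tensor, tables: torch.Tensor) -> torch.Tensor:
    """Differentiable stand-in for argmax + embedding lookup (hypothetical helper).

    logits: [B, T, K, V] per-codebook logits (K codebooks, V codes each).
    tables: [K, V, D] one embedding table per codebook.
    Returns the [B, T, D] expected embedding, summed over codebooks.
    """
    probs = F.softmax(logits, dim=-1)  # [B, T, K, V]
    # Probability-weighted average of each table's rows, then sum over codebooks.
    return torch.einsum("btkv,kvd->btd", probs, tables)


# Toy sizes for illustration only (the card uses 16 codebooks and D = 2048).
B, T, K, V, D = 2, 5, 4, 32, 8
torch.manual_seed(0)
tables = torch.randn(K, V, D)
logits = torch.randn(B, T, K, V, requires_grad=True)

E_soft = soft_rvq_sum_embedding(logits, tables)
E_soft.sum().backward()  # gradients reach the logits

# Sharply peaked logits recover the hard lookup-and-sum result.
tokens = torch.randint(0, V, (B, T, K))
peaked = 50.0 * F.one_hot(tokens, V).float()
E_hard = sum(tables[k][tokens[..., k]] for k in range(K))
```

When the distribution over codes is nearly one-hot, the soft embedding coincides with the embedding of the argmax token, which is why the proxy can be trained on hard-token embeddings and still be applied to soft ones.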


Architecture

SpeakerProxyECAPA: an ECAPA-TDNN-style network adapted for RVQ-sum inputs.

| Component | Details |
| --- | --- |
| Input | [B, T, 2048] RVQ sum embedding (sum of 16 learned codebook embeddings) |
| Front-end | Conv1d projection + SE-Res2Blocks (dilations 2, 3, 4) |
| Pooling | Attentive Statistics Pooling (mean + std, attention-weighted) |
| Bottleneck | FC → 192-dim |
| Output | L2-normalized 192-dim speaker embedding |
| Parameters | ~4.6M |

The architecture mirrors the original SpeechBrain ECAPA-TDNN but is trained end-to-end on RVQ inputs rather than on spectrogram features computed from audio.
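To make the pooling stage concrete, here is a simplified sketch of attentive statistics pooling: attention weights over time yield a weighted mean and standard deviation per channel. It is not the `SpeakerProxyECAPA` code. The class name, `attn_dim`, and the toy sizes are assumptions, and the original ECAPA-TDNN additionally feeds global context into the attention, which is omitted here.

```python
import torch
import torch.nn as nn


class AttentiveStatsPooling(nn.Module):
    """Collapse [B, C, T] frame features into a [B, 2C] utterance vector."""

    def __init__(self, channels: int, attn_dim: int = 128):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Conv1d(channels, attn_dim, kernel_size=1),
            nn.Tanh(),
            nn.Conv1d(attn_dim, channels, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = torch.softmax(self.attn(x), dim=-1)  # weights over time, per channel
        mean = (w * x).sum(dim=-1)
        var = (w * x.pow(2)).sum(dim=-1) - mean.pow(2)
        std = var.clamp(min=1e-6).sqrt()  # clamp guards against tiny negative variance
        return torch.cat([mean, std], dim=1)


# Toy sizes for illustration only.
pool = AttentiveStatsPooling(channels=16)
frames = torch.randn(3, 16, 50)
pooled = pool(frames)  # [3, 32]: weighted mean and std for each channel
```

Because the output size depends only on the channel count, sequences of any length T map to a fixed-size vector, which the bottleneck layer then projects to the 192-dim embedding.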


Training

| Detail | Value |
| --- | --- |
| Dataset | lonesamurai/emilia_clean_10k (10,000 clips, 200 speakers) |
| Train / Val split | 8,000 / 2,000 clips |
| Epochs | ~200 |
| Loss | Pure contrastive: (1 − cos)² alignment + λ · ReLU(cos − margin)² repulsion |
| λ (repel) | 5.0 |
| Optimizer | AdamW, lr = 1e-4, weight_decay = 1e-5 |
| Best val separation | 0.8141 |
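The loss listed above can be sketched as follows. This is an illustration under stated assumptions, not the training code: it assumes the alignment term applies to same-speaker pairs within a batch and the repulsion term to different-speaker pairs. The function name and the `margin=0.0` default are placeholders, since the card does not state the margin. Only λ = 5.0 comes from the table.

```python
import torch
import torch.nn.functional as F


def contrastive_speaker_loss(emb, labels, margin=0.0, lam=5.0):
    """emb: [N, D] embeddings, labels: [N] integer speaker ids.

    margin is a hypothetical placeholder; lam = 5.0 follows the training table.
    """
    emb = F.normalize(emb, dim=-1)
    cos = emb @ emb.T  # [N, N] pairwise cosine similarity
    same = labels[:, None] == labels[None, :]
    eye = torch.eye(len(labels), dtype=torch.bool, device=emb.device)
    pos = same & ~eye  # same speaker, excluding self-pairs
    neg = ~same        # different speakers

    zero = cos.new_zeros(())
    align = (1.0 - cos[pos]).pow(2).mean() if pos.any() else zero
    repel = F.relu(cos[neg] - margin).pow(2).mean() if neg.any() else zero
    return align + lam * repel


labels = torch.tensor([0, 0, 1, 1])
good = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
bad = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
loss_good = contrastive_speaker_loss(good, labels)  # 0.0: speakers perfectly separated
loss_bad = contrastive_speaker_loss(bad, labels)    # 3.5 = 1.0 (align) + 5.0 * 0.5 (repel)
```

Squaring both terms keeps the gradient small for pairs that are already nearly aligned or only barely above the margin, so the strongest updates go to the worst-placed pairs.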

Validation performance (contrastive separation metric)

  • Best checkpoint: epoch ~140, separation = 0.8141
  • Final checkpoint: epoch ~197, separation ≈ 0.73 (plateaued)

Comparison with Original ECAPA-TDNN

Tested on 5 seen + 5 unseen speakers from EMILIA:

| Metric | SpeakerProxy (Ours) | Original ECAPA-TDNN |
| --- | --- | --- |
| Seen-Seen off-diag mean | 0.050 | 0.094 |
| Unseen-Unseen off-diag mean | −0.026 | 0.060 |
| Seen-Unseen off-diag mean | −0.026 | 0.033 |
| All off-diag mean | −0.009 | 0.053 |
| Off-diag std | 0.156 | 0.098 |
| Worst confusion (max) | 0.420 | 0.327 |
| Per-speaker separation (seen avg) | 0.992 | 0.940 |
| Per-speaker separation (unseen avg) | 1.024 | 0.955 |

Takeaway: Our proxy achieves stronger average separation than the original audio-based ECAPA-TDNN, especially on unseen speakers (mean off-diagonal similarity of −0.026 vs. 0.060). The trade-off is higher variance (off-diagonal std of 0.156 vs. 0.098): a few outlier pairs show stronger confusion (maximum of 0.420 vs. 0.327), while most speaker pairs are pushed farther apart.
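For readers who want to compute comparable statistics on their own data, here is one way off-diagonal similarity figures could be derived. The card does not say how its numbers were aggregated, so the per-speaker centroid approach and the function name `off_diagonal_stats` are assumptions.

```python
import torch
import torch.nn.functional as F


def off_diagonal_stats(emb, labels):
    """emb: [N, D] utterance embeddings, labels: [N] integer speaker ids."""
    speakers = torch.unique(labels)
    # One L2-normalized centroid per speaker (an assumed aggregation).
    centroids = torch.stack([emb[labels == s].mean(dim=0) for s in speakers])
    centroids = F.normalize(centroids, dim=-1)
    sim = centroids @ centroids.T  # [S, S] speaker-by-speaker cosine matrix
    off = sim[~torch.eye(len(speakers), dtype=torch.bool)]  # drop self-similarity
    return {"mean": off.mean().item(), "std": off.std().item(), "max": off.max().item()}


# Toy data: three speakers whose centroids are e1, e2 and -e1.
emb = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0],
                    [0.0, 1.0], [-1.0, 0.0], [-1.0, 0.0]])
labels = torch.tensor([0, 0, 1, 1, 2, 2])
stats = off_diagonal_stats(emb, labels)
```

A lower off-diagonal mean indicates that different speakers sit farther apart on average, while the maximum exposes the single most confusable pair.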


Checkpoints

| File | Description |
| --- | --- |
| speaker_proxy_10k_best.pt | Best checkpoint (val separation = 0.8141, ~epoch 140) |

The checkpoint contains:

  • model_state_dict: full network weights
  • config: architecture hyperparameters
  • epoch: training epoch at save time
  • val_separation: best validation metric

Usage

import torch
from exiv.components.models.qwen3_tts.sern.speaker_proxy_ecapa import SpeakerProxyECAPA

# Load checkpoint
checkpoint = torch.load("speaker_proxy_10k_best.pt", map_location="cpu")
config = checkpoint["config"]

# Build model
proxy = SpeakerProxyECAPA(
    input_dim=config["input_dim"],      # 2048
    embed_dim=config["embed_dim"],      # 192
    channels=config["channels"],        # 512
    num_blocks=config["num_blocks"],    # 3
)
proxy.load_state_dict(checkpoint["model_state_dict"])
proxy.eval().cuda()

# Forward pass: E_rvq is the sum of the embeddings looked up in the 16 RVQ embedding tables
# E_rvq: [B, T, 2048] from Qwen3-TTS RVQ tokens
speaker_embedding = proxy(E_rvq)  # [B, 192], L2-normalized

Computing RVQ sum embeddings from Qwen3-TTS tokens

# Extract the 16 embedding tables from Qwen3-TTS
embedding_tables = [
    model.model.embed_tokens[i].weight for i in range(16)
]

# tokens: [B, T, 16] integer RVQ indices
E_rvq = torch.stack([
    embedding_tables[i][tokens[..., i]] for i in range(16)
], dim=-1).sum(dim=-1)  # [B, T, 2048]
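During voice-conversion training, the proxy would typically be frozen and used only to score speaker similarity. The sketch below shows that pattern with a hypothetical `StandInProxy` that mimics only the input/output contract of `SpeakerProxyECAPA`. The `1 − cos` loss and the random tensors are assumptions for illustration, not the training recipe used in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StandInProxy(nn.Module):
    """Hypothetical placeholder with the same I/O contract as SpeakerProxyECAPA:
    [B, T, input_dim] in, L2-normalized [B, embed_dim] out."""

    def __init__(self, input_dim: int = 2048, embed_dim: int = 192):
        super().__init__()
        self.fc = nn.Linear(input_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.fc(x.mean(dim=1)), dim=-1)


proxy = StandInProxy().eval()
for p in proxy.parameters():
    p.requires_grad_(False)  # the proxy stays frozen; only the upstream model learns

torch.manual_seed(0)
# Stand-ins for the (soft) RVQ sum embedding and the target speaker embeddings.
E_rvq = torch.randn(2, 30, 2048, requires_grad=True)
target = F.normalize(torch.randn(2, 192), dim=-1)

# Assumed speaker-similarity loss: 1 - cosine similarity, averaged over the batch.
loss = (1.0 - F.cosine_similarity(proxy(E_rvq), target, dim=-1)).mean()
loss.backward()  # gradients flow through the frozen proxy into E_rvq
```

Freezing the proxy's parameters still lets gradients pass through it to its input, so the upstream model receives a speaker-similarity signal without the proxy itself drifting.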

Requirements

  • PyTorch ≥ 2.0
  • See Exiv for full integration with Qwen3-TTS SERN adapter

License

MIT
