Yeah. The speed difference between those two approaches (strictly speaking, it is mostly a difference in when the processing happens) comes up in discussions fairly often.
The difference is real, and in your specific code it is expected.
Two big causes dominate:
- set_transform forces you onto the “formatting transform” path, which often hands you Python objects or NumPy arrays, not Torch tensors.
- Your set_transform implementation does extra full-tensor allocations and (likely) double-batching.
Below is the “why”, mapped to your exact code.
0) What set_transform is doing under the hood
HF documents set_transform/with_transform as a formatting transform that:
- replaces set_format()
- takes a batch as a dict and returns a batch
- is applied “right before returning the objects in __getitem__”
- can be restricted to specific columns (columns=[...]) (Hugging Face)
So if you call dataset.set_transform(transform) you are telling HF: “ignore the normal formatting pipeline, I will produce the final returned objects myself.” (Hugging Face)
That detail matters because in your wrapper approach you do:
self._dataset.set_format("torch")
and then you operate on already-created torch tensors.
With set_transform, you did not set torch formatting, and even if you did, the transform is documented as replacing that formatting layer. (Hugging Face)
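To make the two setups concrete, here is a minimal, self-contained sketch of both paths. The toy dataset and the column names "wav"/"act" are stand-ins based on your code; the identity transform is only there to show what a formatting transform receives and returns.

```python
import numpy as np
from datasets import Dataset

# Toy stand-in for your dataset; "wav"/"act" match your column names.
ds = Dataset.from_dict({
    "wav": [np.random.randn(2, 160).tolist() for _ in range(8)],
    "act": [np.random.randn(2, 10).tolist() for _ in range(8)],
})

# Path A: HF's formatting layer hands you torch tensors directly.
ds_torch = ds.with_format("torch")
print(type(ds_torch[0]["wav"]))   # torch.Tensor

# Path B: a transform replaces that formatting layer; you get the raw
# batch dict and must produce the final objects yourself.
def identity_transform(batch):
    return batch

ds_tr = ds.with_transform(identity_transform)
print(type(ds_tr[0]["wav"]))      # typically list / NumPy, not torch.Tensor
```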
1) The #1 time sink: torch.tensor(...) always copies
In your set_transform version you do:
batch_wav = torch.tensor(batch["wav"])
batch_act = torch.tensor(batch["act"])
torch.tensor(x) always makes a copy. PyTorch explicitly states torch.tensor(tensor) is equivalent to tensor.detach().clone(), and points you to torch.as_tensor() (avoid copies where possible) and torch.from_numpy() (shares storage with NumPy) as alternatives. (PyTorch Documentation)
So every dataset[0] in the set_transform path likely does:
- Arrow / Python / NumPy object materialization (HF side)
- then a full copy into a new torch tensor (torch.tensor)
- then later another full allocation (see next section)
In the wrapper path you do:
item = self._dataset[index] # already torch because set_format('torch')
wav = item["wav"]
act = item["act"]
So you skip at least one huge copy per sample.
Why it can be much slower than you expect
If batch["wav"] is a Python nested list (possible depending on your dataset feature types), torch.tensor(list_of_lists) is particularly slow because it must walk Python objects and infer dtype/shape. If it is a NumPy array, it is still a copy (by definition). (PyTorch Documentation)
2) The #2 time sink: you allocate again with torch.stack
Inside your transform you do:
transformed_wav.append(wav)
...
batch_wav = torch.stack(transformed_wav)
torch.stack allocates a brand new tensor and copies data into it.
So in the set_transform path you likely allocate at least:
- Copy 1: torch.tensor(batch["wav"])
- Copy 2: torch.stack(transformed_wav)
In the wrapper path, you allocate less:
- You slice (often views) and permute (advanced indexing makes a copy, but you do that in both implementations)
- Then the DataLoader collates and stacks once
So the wrapper path tends to incur only the “necessary” copies (permutation + final collation), while the set_transform path can add extra copies on top.
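For completeness, the extra allocation from torch.stack is easy to verify: the result never shares storage with its inputs (generic illustration).

```python
import torch

parts = [torch.randn(4, 16000) for _ in range(3)]
stacked = torch.stack(parts)   # fresh (3, 4, 16000) buffer; every part is copied in

# The stacked rows live in new storage, not in the original tensors:
print(any(stacked[i].data_ptr() == parts[i].data_ptr() for i in range(3)))  # False
```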
3) You may be “batching twice” with set_transform + DataLoader
This is the most common hidden footgun with set_transform.
Your set_transform returns batched tensors:
dataset[0] returns {"wav": (1, …), "act": (1, …)} because you stack, even for one item.
Then PyTorch DataLoader with automatic batching does:
fetch a list of samples, then call collate_fn([dataset[i] for i in indices]) (PyTorch Documentation)
So if each dataset[i] already contains a leading batch dimension of 1, the default collate will typically produce something shaped like (batch_size, 1, …) and will stack/copy again.
Even if you later squeeze that dimension, you already paid the cost.
In contrast, your wrapper returns per-sample wav: (mic, time) and act: (spk, frames), and the DataLoader stacks once into (batch_size, mic, time).
This single structural mismatch can easily turn into a multi-x slowdown at training time.
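The effect is easy to reproduce with a pure-PyTorch mock, independent of HF (the dataset classes below are illustrative, not your code):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class AlreadyBatched(Dataset):
    """Items come back with a leading batch dimension of 1."""
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return {"wav": torch.randn(1, 4, 160)}

class PerSample(Dataset):
    """Items come back as plain per-sample tensors."""
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return {"wav": torch.randn(4, 160)}

print(next(iter(DataLoader(AlreadyBatched(), batch_size=4)))["wav"].shape)  # (4, 1, 4, 160)
print(next(iter(DataLoader(PerSample(), batch_size=4)))["wav"].shape)       # (4, 4, 160)
```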
4) HF set_transform has known perf issues even for trivial transforms
Independent of your code, there are real reports that set_transform can drastically slow iteration even when the transform itself is trivial. (GitHub)
So you have:
- genuine framework overhead (HF formatting transform path) (GitHub)
- plus your own extra allocations (torch.tensor, torch.stack)
- plus possibly double batching (extra stacking in DataLoader) (PyTorch Documentation)
A 10× delta on dataset[0] and 3× on full training is plausible under that combination.
5) Why your wrapper is “fundamentally faster”
Because it lines up with the normal PyTorch data pipeline:
- __getitem__ returns one sample
- DataLoader batches once (collation) (PyTorch Documentation)
- You avoid unnecessary tensor construction and stacking
Also, set_format("torch") keeps the conversion logic inside HF’s optimized formatting code rather than redoing it manually via torch.tensor on every call. (And set_transform is explicitly a replacement for that formatting layer.) (Hugging Face)
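For reference, the wrapper pattern being endorsed here looks roughly like this. It is a sketch based on the snippets you posted, with the augmentation step left as a placeholder:

```python
import torch
from torch.utils.data import Dataset

class WavActDataset(Dataset):
    """Thin wrapper: HF converts Arrow -> torch, the DataLoader batches once."""

    def __init__(self, hf_dataset):
        hf_dataset.set_format("torch")   # conversion stays inside HF's formatting code
        self._dataset = hf_dataset

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, index):
        item = self._dataset[index]      # already torch tensors
        wav = item["wav"]                # (mic, time)
        act = item["act"]                # (spk, frames)
        # per-sample augmentation (e.g. your channel permutation) goes here
        return wav, act
```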
6) How to confirm exactly which factor dominates (quick checks)
Do these prints/timings inside your WavActTransform.__call__:
- Check input types
print(type(batch["wav"]), type(batch["act"]))
If you see Python lists, that explains a lot.
- Time just the conversion
Time these separately:
torch.tensor(batch["wav"])
torch.as_tensor(batch["wav"]) (if it is NumPy)
PyTorch explicitly says as_tensor tries to avoid copies and from_numpy shares storage. (PyTorch Documentation)
- Check output shapes
Print batch_wav.shape returned by transform and the shape your training step receives from DataLoader. If you see (B, 1, …) anywhere, you are double batching.
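Put together, the checks could look like this inside your transform. The diagnose helper below is hypothetical; the perf_counter timings are crude but enough to see which step dominates.

```python
import time
import torch

def diagnose(batch):
    # 1) Input types: Python lists here already explain a lot.
    print(type(batch["wav"]), type(batch["act"]))

    # 2) Time the two conversion paths separately.
    t0 = time.perf_counter()
    wav_copy = torch.tensor(batch["wav"])        # always copies
    t1 = time.perf_counter()
    wav_cheap = torch.as_tensor(batch["wav"])    # no copy if it is NumPy
    t2 = time.perf_counter()
    print(f"torch.tensor: {t1 - t0:.4f}s  torch.as_tensor: {t2 - t1:.4f}s")

    # 3) Output shape: a leading (1, ...) means you are double batching.
    print("transform output shape:", wav_copy.shape)
    return {"wav": wav_copy, "act": torch.as_tensor(batch["act"])}

# dataset.set_transform(diagnose)
# dataset[0]                      # then also inspect one DataLoader batch
```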
7) What to do going forward (practical recommendations)
Best “clean + fast” option
Keep your wrapper dataset (your current fast solution). It is idiomatic PyTorch and avoids the set_transform formatting overhead.
If you want to stay purely in HF Dataset without a wrapper
Do augmentation in collate_fn instead:
- Keep dataset.set_format("torch")
- DataLoader gives you a list of torch tensors
- You augment each item and then stack once
This matches PyTorch’s model: DataLoader fetches samples then collates to a batch. (PyTorch Documentation)
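A sketch of that collate_fn approach (the augmentation placeholder and the batch size are assumptions; "wav"/"act" are your column names):

```python
import torch
from torch.utils.data import DataLoader

dataset.set_format("torch")   # each item arrives as torch tensors

def collate_with_augmentation(items):
    wavs, acts = [], []
    for item in items:
        wav = item["wav"]     # (mic, time), already torch
        act = item["act"]     # (spk, frames)
        # per-sample augmentation (e.g. your channel permutation) goes here
        wavs.append(wav)
        acts.append(act)
    # stack exactly once per batch
    return {"wav": torch.stack(wavs), "act": torch.stack(acts)}

loader = DataLoader(dataset, batch_size=16, collate_fn=collate_with_augmentation)
```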
If you insist on set_transform
Then rewrite the transform to avoid the biggest traps:
- Avoid torch.tensor(...) on every call. It always copies. Prefer torch.as_tensor / torch.from_numpy when possible. (PyTorch Documentation)
- Avoid producing an extra batch dimension (avoid torch.stack for the single-item __getitem__ case), or disable DataLoader automatic batching and let the dataset return batches (more complex).
- Use columns=["wav","act"] in set_transform to minimize what HF passes into your transform. (Hugging Face)
Even then, be aware of the known set_transform slowdown reports. (GitHub)
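A sketch of such a leaner transform. It assumes the batch values arrive as NumPy arrays; if they arrive as nested Python lists, the copy is unavoidable.

```python
import torch

def leaner_transform(batch):
    # Convert each sample without copying when it is already NumPy,
    # and do NOT stack: leave batching to the DataLoader's collate.
    batch["wav"] = [torch.as_tensor(w) for w in batch["wav"]]
    batch["act"] = [torch.as_tensor(a) for a in batch["act"]]
    return batch

dataset.set_transform(leaner_transform, columns=["wav", "act"])
```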
Good references (directly relevant)
- HF docs: set_transform replaces set_format, batch-dict signature, applied in __getitem__, columns option (Hugging Face)
- PyTorch docs: DataLoader batches by fetching samples then calling collate_fn([dataset[i]...]) (PyTorch Documentation)
- PyTorch docs: torch.tensor(t) is clone-like (copies); as_tensor avoids copies where possible; from_numpy shares storage (PyTorch Documentation)
- HF issue: trivial set_transform causing ~10× slower iteration (GitHub)
Summary
- set_transform is a formatting hook that replaces set_format, so you often lose the fast “return torch tensors” path. (Hugging Face)
- Your set_transform code does extra full copies: torch.tensor (always copies) plus torch.stack (allocates again). (PyTorch Documentation)
- You may also be batching twice when DataLoader stacks samples that already have a leading batch dimension. (PyTorch Documentation)
- HF has known reports of set_transform iteration slowdowns even for trivial transforms. (GitHub)