alex committed on
Commit
25889c7
·
1 Parent(s): 43067da

now with audio support

Browse files
app.py CHANGED
@@ -2,6 +2,80 @@ import sys
2
  from pathlib import Path
3
  import uuid
4
  import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # Add packages to Python path
7
  current_dir = Path(__file__).parent
@@ -17,9 +91,11 @@ import random
17
  import torch
18
  from typing import Optional
19
  from pathlib import Path
 
20
  from huggingface_hub import hf_hub_download, snapshot_download
21
  from ltx_pipelines.distilled import DistilledPipeline
22
  from ltx_core.model.video_vae import TilingConfig
 
23
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
24
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
25
  from ltx_pipelines.utils.constants import (
@@ -31,6 +107,8 @@ from ltx_pipelines.utils.constants import (
31
  DEFAULT_LORA_STRENGTH,
32
  )
33
  from ltx_core.loader.single_gpu_model_builder import enable_only_lora
 
 
34
  from PIL import Image
35
 
36
  MAX_SEED = np.iinfo(np.int32).max
@@ -38,6 +116,11 @@ MAX_SEED = np.iinfo(np.int32).max
38
  # Install with: pip install git+https://github.com/Lightricks/LTX-2.git
39
  from ltx_pipelines.utils import ModelLedger
40
  from ltx_pipelines.utils.helpers import generate_enhanced_prompt
 
 
 
 
 
41
 
42
  # HuggingFace Hub defaults
43
  DEFAULT_REPO_ID = "Lightricks/LTX-2"
@@ -82,6 +165,8 @@ model_ledger = ModelLedger(
82
  local_files_only=False
83
  )
84
 
 
 
85
 
86
  # Load text encoder once and keep it in memory
87
  text_encoder = model_ledger.text_encoder()
@@ -90,6 +175,109 @@ print("=" * 80)
90
  print("Text encoder loaded and ready!")
91
  print("=" * 80)
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def encode_text_simple(text_encoder, prompt: str):
94
  """Simple text encoding without using pipeline_utils."""
95
  v_context, a_context, _ = text_encoder(prompt)
@@ -262,6 +450,7 @@ RUNTIME_LORA_CHOICES = [
262
  ("Slide Right", 5),
263
  ("Slide Down", 6),
264
  ("Slide Up", 7),
 
265
  ]
266
 
267
  # Initialize pipeline WITHOUT text encoder (gemma_root=None)
@@ -580,7 +769,7 @@ def generate_video_example(input_image, prompt, camera_lora, resolution, progres
580
 
581
  w, h = apply_resolution(resolution)
582
 
583
- output_video, seed = generate_video(
584
  input_image,
585
  prompt,
586
  10, # duration seconds
@@ -589,18 +778,18 @@ def generate_video_example(input_image, prompt, camera_lora, resolution, progres
589
  True, # randomize_seed
590
  h, # height
591
  w, # width
592
- camera_lora,
 
593
  progress
594
  )
595
 
596
  return output_video
597
-
598
-
599
  def generate_video_example_t2v(prompt, camera_lora, resolution, progress=gr.Progress(track_tqdm=True)):
600
 
601
  w, h = apply_resolution(resolution)
602
 
603
- output_video, seed = generate_video(
604
  None,
605
  prompt,
606
  15, # duration seconds
@@ -609,11 +798,32 @@ def generate_video_example_t2v(prompt, camera_lora, resolution, progress=gr.Prog
609
  True, # randomize_seed
610
  h, # height
611
  w, # width
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  camera_lora,
 
613
  progress
614
  )
 
615
  return output_video
616
-
617
  def get_duration(
618
  input_image,
619
  prompt,
@@ -624,14 +834,20 @@ def get_duration(
624
  height,
625
  width,
626
  camera_lora,
 
627
  progress
628
  ):
 
 
 
 
 
629
  if duration <= 5:
630
- return 80
631
  elif duration <= 10:
632
- return 120
633
  else:
634
- return 180
635
 
636
  @spaces.GPU(duration=get_duration)
637
  def generate_video(
@@ -644,6 +860,7 @@ def generate_video(
644
  height: int = DEFAULT_1_STAGE_HEIGHT,
645
  width: int = DEFAULT_1_STAGE_WIDTH,
646
  camera_lora: str = "No LoRA",
 
647
  progress=gr.Progress(track_tqdm=True),
648
  ):
649
  """
@@ -705,10 +922,21 @@ def generate_video(
705
  audio_context = embeddings["audio_context"].to("cuda", non_blocking=True)
706
  print("✓ Embeddings loaded successfully")
707
 
 
708
  # free prompt enhancer / encoder temps ASAP
709
  del embeddings, final_prompt, status
710
  torch.cuda.empty_cache()
711
 
 
 
 
 
 
 
 
 
 
 
712
 
713
  # Map dropdown name -> adapter index
714
  name_to_idx = {name: idx for name, idx in RUNTIME_LORA_CHOICES}
@@ -717,6 +945,22 @@ def generate_video(
717
  enable_only_lora(pipeline._transformer, selected_idx)
718
  torch.cuda.empty_cache()
719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720
  # Run inference - progress automatically tracks tqdm from pipeline
721
  with torch.inference_mode():
722
  pipeline(
@@ -731,12 +975,14 @@ def generate_video(
731
  tiling_config=TilingConfig.default(),
732
  video_context=video_context,
733
  audio_context=audio_context,
 
 
734
  )
735
  del video_context, audio_context
736
  torch.cuda.empty_cache()
737
  print("successful generation")
738
 
739
- return str(output_path), current_seed
740
 
741
 
742
 
@@ -1160,12 +1406,13 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
1160
  height=512
1161
  )
1162
 
1163
-
1164
  prompt_ui = PromptBox(
1165
  value="Make this image come alive with cinematic motion, smooth animation",
1166
  elem_id="prompt_ui",
1167
  )
1168
 
 
 
1169
  prompt = gr.Textbox(
1170
  label="Prompt",
1171
  value="Make this image come alive with cinematic motion, smooth animation",
@@ -1302,11 +1549,13 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
1302
  height,
1303
  width,
1304
  camera_lora,
 
1305
  ],
1306
- outputs=[output_video,seed]
1307
  )
1308
 
1309
 
 
1310
  timestep_prompt = """Style: Realistic live-action, cinematic, shallow depth of field, 24 fps, natural and dramatic lighting
1311
 
1312
  Environment: Interior of a space station module or realistic mock-up, metal panels, blinking lights, Earth visible through a large window
@@ -1331,6 +1580,24 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
1331
 
1332
  Music: subtle cinematic synth or ambient pad, futuristic and minimal, emphasizing awe and solitude"""
1333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1334
 
1335
  gr.Examples(
1336
  examples=[
@@ -1402,6 +1669,5 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
1402
  )
1403
 
1404
 
1405
-
1406
  if __name__ == "__main__":
1407
  demo.launch(ssr_mode=False, mcp_server=True, css=css)
 
2
  from pathlib import Path
3
  import uuid
4
  import tempfile
5
+ import subprocess
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+ import os
10
+ from typing import Any
11
+
12
+ def _coerce_audio_path(audio_path: Any) -> str:
13
+ # Common Gradio case: tuple where first item is the filepath
14
+ if isinstance(audio_path, tuple) and len(audio_path) > 0:
15
+ audio_path = audio_path[0]
16
+
17
+ # Some gradio versions pass a dict-like object
18
+ if isinstance(audio_path, dict):
19
+ # common keys: "name", "path"
20
+ audio_path = audio_path.get("name") or audio_path.get("path")
21
+
22
+ # pathlib.Path etc.
23
+ if not isinstance(audio_path, (str, bytes, os.PathLike)):
24
+ raise TypeError(f"audio_path must be a path-like, got {type(audio_path)}: {audio_path}")
25
+
26
+ return os.fspath(audio_path)
27
+
28
+
29
+
30
def match_audio_to_duration(
    audio_path: str,
    target_seconds: float,
    target_sr: int = 48000,
    to_mono: bool = True,
    pad_mode: str = "silence",  # "silence" | "repeat"
    device: str = "cuda",
):
    """Load audio, resample, optionally downmix to mono, and force an exact duration.

    Returns:
        (waveform, sample_rate) — waveform is [1, T] (or [C, T] when
        to_mono=False), moved to `device`.
    """
    resolved = _coerce_audio_path(audio_path)

    waveform, sample_rate = torchaudio.load(resolved)  # [C, T] float32 CPU

    # Resample so the duration arithmetic below is exact at target_sr.
    if sample_rate != target_sr:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sr)
        sample_rate = target_sr

    # Downmix to mono unless the caller's model wants stereo.
    if to_mono and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # [1, T]

    # Trim or pad to exactly target_seconds worth of samples.
    wanted = int(round(target_seconds * sample_rate))
    have = waveform.shape[-1]

    if have > wanted:
        waveform = waveform[..., :wanted]
    elif have < wanted:
        if pad_mode == "repeat" and have > 0:
            # Tile the clip, then cut to the exact length.
            copies = (wanted + have - 1) // have
            waveform = waveform.repeat(1, copies)[..., :wanted]
        else:
            # Pad the tail with silence.
            waveform = F.pad(waveform, (0, wanted - have))

    return waveform.to(device, non_blocking=True), sample_rate
74
+
75
+
76
# Run a shell command string, raising CalledProcessError on nonzero exit.
# NOTE(review): shell=True executes through the shell — keep `cmd` hard-coded
# (as below) and never interpolate user-supplied text into it.
def sh(cmd): subprocess.check_call(cmd, shell=True)

# Install easy_dwpose at import time; --no-deps avoids disturbing the
# already-pinned torch/vision stack in the runtime image.
sh("pip install --no-deps easy_dwpose")
79
 
80
  # Add packages to Python path
81
  current_dir = Path(__file__).parent
 
91
  import torch
92
  from typing import Optional
93
  from pathlib import Path
94
+ import torchaudio
95
  from huggingface_hub import hf_hub_download, snapshot_download
96
  from ltx_pipelines.distilled import DistilledPipeline
97
  from ltx_core.model.video_vae import TilingConfig
98
+ from ltx_core.model.audio_vae.ops import AudioProcessor
99
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
100
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
101
  from ltx_pipelines.utils.constants import (
 
107
  DEFAULT_LORA_STRENGTH,
108
  )
109
  from ltx_core.loader.single_gpu_model_builder import enable_only_lora
110
+ from ltx_core.model.audio_vae import decode_audio
111
+ from ltx_core.model.audio_vae import encode_audio
112
  from PIL import Image
113
 
114
  MAX_SEED = np.iinfo(np.int32).max
 
116
  # Install with: pip install git+https://github.com/Lightricks/LTX-2.git
117
  from ltx_pipelines.utils import ModelLedger
118
  from ltx_pipelines.utils.helpers import generate_enhanced_prompt
119
+ import imageio
120
+ import cv2
121
+ from controlnet_aux import CannyDetector
122
+ from easy_dwpose import DWposeDetector
123
+
124
 
125
  # HuggingFace Hub defaults
126
  DEFAULT_REPO_ID = "Lightricks/LTX-2"
 
165
  local_files_only=False
166
  )
167
 
168
+ canny_processor = CannyDetector()
169
+
170
 
171
  # Load text encoder once and keep it in memory
172
  text_encoder = model_ledger.text_encoder()
 
175
  print("Text encoder loaded and ready!")
176
  print("=" * 80)
177
 
178
def on_lora_change(selected: str):
    """Sync the media inputs with the chosen LoRA.

    Control LoRAs ("Pose", "Canny", "Detailer") take a video input; every
    other choice takes an image input.

    Returns:
        (selected, image_input_update, video_input_update) for Gradio outputs.
    """
    needs_video = selected in {"Pose", "Canny", "Detailer"}
    # Clear both inputs on every switch so stale media never carries over to
    # the other mode; only the relevant widget is made visible.
    # (The original `None if needs_video else None` was always None.)
    return (
        selected,
        gr.update(visible=not needs_video, value=None),
        gr.update(visible=needs_video, value=None),
    )
185
+
186
+
187
def process_video_for_pose(frames, width: int, height: int):
    """Render a DWpose skeleton image for each input frame.

    Args:
        frames: iterable of (H, W, 3) uint8 numpy frames (imageio output).
        width: conditioning width the pose frames are resized to.
        height: conditioning height the pose frames are resized to.

    Returns:
        List of (height, width, 3) float32 arrays in [0..1].
    """
    # Bail out before paying the cost of loading the DWpose model.
    if not frames:
        return []

    pose_processor = DWposeDetector("cuda")

    pose_frames = []
    for frame in frames:
        # imageio frame -> PIL
        pil = Image.fromarray(frame.astype(np.uint8)).convert("RGB")

        # Do NOT pass width/height here — easy_dwpose handles drawing sizes internally.
        pose_img = pose_processor(pil)

        # Some easy_dwpose versions return a numpy array instead of PIL.
        if not isinstance(pose_img, Image.Image):
            pose_img = Image.fromarray(pose_img.astype(np.uint8))

        pose_img = pose_img.convert("RGB").resize((width, height), Image.BILINEAR)

        pose_frames.append(np.array(pose_img).astype(np.float32) / 255.0)

    return pose_frames
213
+
214
+
215
def preprocess_video_to_pose_mp4(video_path: str, width: int, height: int, fps: float):
    """End-to-end: read video -> DWpose frames -> temp mp4 -> return its path."""
    raw_frames = load_video_frames(video_path)
    rendered = process_video_for_pose(raw_frames, width=width, height=height)
    out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    out.close()
    return write_video_mp4(rendered, fps=fps, out_path=out.name)
221
+
222
+
223
def load_video_frames(video_path: str):
    """Return list of frames as numpy arrays (H,W,3) uint8."""
    with imageio.get_reader(video_path) as reader:
        return [frame for frame in reader]
230
+
231
+
232
def process_video_for_canny(frames, width: int, height: int,
                            low_threshold=50, high_threshold=200):
    """
    Convert RGB frames -> canny edge frames.
    Returns list of np arrays (H,W,3) in float [0..1] (like controlnet_aux output).
    """
    if not frames:
        return []

    # Resolutions are derived once from the first frame / target size.
    first = frames[0]
    detect_res = max(first.shape[0], first.shape[1])
    out_res = max(width, height)

    # controlnet_aux CannyDetector returns a float image in [0..1] for output_type="np".
    return [
        canny_processor(
            frame,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            detect_resolution=detect_res,
            image_resolution=out_res,
            output_type="np",
        )
        for frame in frames
    ]
258
+
259
+
260
def write_video_mp4(frames_float_01, fps: float, out_path: str):
    """Write frames in float [0..1] to mp4 as uint8; returns out_path."""
    # PyAV backend doesn't support `quality=...`, so only fps/macro_block_size are set.
    with imageio.get_writer(out_path, fps=fps, macro_block_size=1) as writer:
        for frame in frames_float_01:
            writer.append_data((frame * 255).astype(np.uint8))
    return out_path
269
+
270
+
271
+
272
def preprocess_video_to_canny_mp4(video_path: str, width: int, height: int, fps: float):
    """End-to-end: read video -> canny -> write temp mp4 -> return path."""
    edge_frames = process_video_for_canny(
        load_video_frames(video_path), width=width, height=height
    )
    out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    out.close()
    return write_video_mp4(edge_frames, fps=fps, out_path=out.name)
+
280
+
281
  def encode_text_simple(text_encoder, prompt: str):
282
  """Simple text encoding without using pipeline_utils."""
283
  v_context, a_context, _ = text_encoder(prompt)
 
450
  ("Slide Right", 5),
451
  ("Slide Down", 6),
452
  ("Slide Up", 7),
453
+
454
  ]
455
 
456
  # Initialize pipeline WITHOUT text encoder (gemma_root=None)
 
769
 
770
  w, h = apply_resolution(resolution)
771
 
772
+ output_video = generate_video(
773
  input_image,
774
  prompt,
775
  10, # duration seconds
 
778
  True, # randomize_seed
779
  h, # height
780
  w, # width
781
+ camera_lora,
782
+ None,
783
  progress
784
  )
785
 
786
  return output_video
787
+
 
788
  def generate_video_example_t2v(prompt, camera_lora, resolution, progress=gr.Progress(track_tqdm=True)):
789
 
790
  w, h = apply_resolution(resolution)
791
 
792
+ output_video = generate_video(
793
  None,
794
  prompt,
795
  15, # duration seconds
 
798
  True, # randomize_seed
799
  h, # height
800
  w, # width
801
+ camera_lora,
802
+ None,
803
+ progress
804
+ )
805
+ return output_video
806
+
807
def generate_video_example_s2v(input_image, prompt, camera_lora, resolution, audio_path, progress=gr.Progress(track_tqdm=True)):
    """Cached-example wrapper: image + audio -> 10 s video via generate_video."""
    width, height = apply_resolution(resolution)

    return generate_video(
        input_image,
        prompt,
        10,    # duration seconds
        True,  # enhance_prompt
        42,    # seed
        True,  # randomize_seed
        height,
        width,
        camera_lora,
        audio_path,
        progress,
    )
826
+
827
def get_duration(
    input_image,
    prompt,
    duration,
    enhance_prompt,
    seed,
    randomize_seed,
    height,
    width,
    camera_lora,
    audio_path,
    progress
):
    """Estimate the GPU-seconds budget for the @spaces.GPU decorator.

    Mirrors generate_video's signature — spaces forwards the same arguments.
    Tiers by requested clip duration (80/120/180 s) with a flat +10 s when an
    audio track must be encoded or when running text-to-video (no input image),
    since both add extra work before/after denoising.
    """
    extra_time = 10 if (audio_path is not None or input_image is None) else 0

    if duration <= 5:
        return 80 + extra_time
    if duration <= 10:
        return 120 + extra_time
    return 180 + extra_time
851
 
852
  @spaces.GPU(duration=get_duration)
853
  def generate_video(
 
860
  height: int = DEFAULT_1_STAGE_HEIGHT,
861
  width: int = DEFAULT_1_STAGE_WIDTH,
862
  camera_lora: str = "No LoRA",
863
+ audio_path = None,
864
  progress=gr.Progress(track_tqdm=True),
865
  ):
866
  """
 
922
  audio_context = embeddings["audio_context"].to("cuda", non_blocking=True)
923
  print("✓ Embeddings loaded successfully")
924
 
925
+
926
  # free prompt enhancer / encoder temps ASAP
927
  del embeddings, final_prompt, status
928
  torch.cuda.empty_cache()
929
 
930
+ # ✅ if user provided audio, use a neutral audio_context
931
+ n_audio_context = None
932
+
933
+ if audio_path is not None:
934
+ with torch.inference_mode():
935
+ _, n_audio_context = encode_text_simple(text_encoder, "") # returns tensors on GPU already
936
+ del audio_context
937
+ audio_context = n_audio_context
938
+
939
+ torch.cuda.empty_cache()
940
 
941
  # Map dropdown name -> adapter index
942
  name_to_idx = {name: idx for name, idx in RUNTIME_LORA_CHOICES}
 
945
  enable_only_lora(pipeline._transformer, selected_idx)
946
  torch.cuda.empty_cache()
947
 
948
+ # True video duration in seconds based on your rounding
949
+ video_seconds = (num_frames - 1) / frame_rate
950
+
951
+ if audio_path is not None:
952
+ input_waveform, input_waveform_sample_rate = match_audio_to_duration(
953
+ audio_path=audio_path,
954
+ target_seconds=video_seconds,
955
+ target_sr=48000, # pick what your model expects; 48k is common for AV models
956
+ to_mono=True, # set False if your model wants stereo
957
+ pad_mode="silence", # or "repeat" if you prefer looping over silence
958
+ device="cuda",
959
+ )
960
+ else:
961
+ input_waveform = None
962
+ input_waveform_sample_rate = None
963
+
964
  # Run inference - progress automatically tracks tqdm from pipeline
965
  with torch.inference_mode():
966
  pipeline(
 
975
  tiling_config=TilingConfig.default(),
976
  video_context=video_context,
977
  audio_context=audio_context,
978
+ input_waveform=input_waveform,
979
+ input_waveform_sample_rate=input_waveform_sample_rate,
980
  )
981
  del video_context, audio_context
982
  torch.cuda.empty_cache()
983
  print("successful generation")
984
 
985
+ return str(output_path)
986
 
987
 
988
 
 
1406
  height=512
1407
  )
1408
 
 
1409
  prompt_ui = PromptBox(
1410
  value="Make this image come alive with cinematic motion, smooth animation",
1411
  elem_id="prompt_ui",
1412
  )
1413
 
1414
+ audio_input = gr.Audio(label="Audio (Optional)", type="filepath")
1415
+
1416
  prompt = gr.Textbox(
1417
  label="Prompt",
1418
  value="Make this image come alive with cinematic motion, smooth animation",
 
1549
  height,
1550
  width,
1551
  camera_lora,
1552
+ audio_input
1553
  ],
1554
+ outputs=[output_video]
1555
  )
1556
 
1557
 
1558
+
1559
  timestep_prompt = """Style: Realistic live-action, cinematic, shallow depth of field, 24 fps, natural and dramatic lighting
1560
 
1561
  Environment: Interior of a space station module or realistic mock-up, metal panels, blinking lights, Earth visible through a large window
 
1580
 
1581
  Music: subtle cinematic synth or ambient pad, futuristic and minimal, emphasizing awe and solitude"""
1582
 
1583
+ gr.Examples(
1584
+ examples=[
1585
+
1586
+ [
1587
+ "supergirl-2.png",
1588
+ "A fuzzy puppet superhero character resembling a female puppet with blonde hair and a blue superhero suit sleeping in bed and just waking up, she gradually gets up, rubbing her eyes and looking at her dog that just popped on the bed. the scene feels chaotic, comedic, and emotional with expressive puppet reactions, cinematic lighting, smooth camera motion, shallow depth of field, and high-quality puppet-style animation",
1589
+ "Static",
1590
+ "16:9",
1591
+ "supergirl.m4a"
1592
+ ],
1593
+
1594
+ ],
1595
+ fn=generate_video_example_s2v,
1596
+ inputs=[input_image, prompt_ui, camera_lora_ui, radioanimated_resolution, audio_input],
1597
+ outputs = [output_video],
1598
+ label="S2V Example",
1599
+ cache_examples=True,
1600
+ )
1601
 
1602
  gr.Examples(
1603
  examples=[
 
1669
  )
1670
 
1671
 
 
1672
  if __name__ == "__main__":
1673
  demo.launch(ssr_mode=False, mcp_server=True, css=css)
packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
  """Audio VAE model components."""
2
 
3
- from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio
4
  from ltx_core.model.audio_vae.model_configurator import (
5
  AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
6
  AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
@@ -24,4 +24,5 @@ __all__ = [
24
  "Vocoder",
25
  "VocoderConfigurator",
26
  "decode_audio",
 
27
  ]
 
1
  """Audio VAE model components."""
2
 
3
+ from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio, encode_audio
4
  from ltx_core.model.audio_vae.model_configurator import (
5
  AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
6
  AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
 
24
  "Vocoder",
25
  "VocoderConfigurator",
26
  "decode_audio",
27
+ "encode_audio",
28
  ]
packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py CHANGED
@@ -8,7 +8,7 @@ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
8
  from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
9
  from ltx_core.model.audio_vae.causality_axis import CausalityAxis
10
  from ltx_core.model.audio_vae.downsample import build_downsampling_path
11
- from ltx_core.model.audio_vae.ops import PerChannelStatistics
12
  from ltx_core.model.audio_vae.resnet import ResnetBlock
13
  from ltx_core.model.audio_vae.upsample import build_upsampling_path
14
  from ltx_core.model.audio_vae.vocoder import Vocoder
@@ -464,6 +464,57 @@ class AudioDecoder(torch.nn.Module):
464
  h = self.conv_out(h)
465
  return torch.tanh(h) if self.tanh_out else h
466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
  def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> torch.Tensor:
469
  """
 
8
  from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
9
  from ltx_core.model.audio_vae.causality_axis import CausalityAxis
10
  from ltx_core.model.audio_vae.downsample import build_downsampling_path
11
+ from ltx_core.model.audio_vae.ops import PerChannelStatistics, AudioProcessor
12
  from ltx_core.model.audio_vae.resnet import ResnetBlock
13
  from ltx_core.model.audio_vae.upsample import build_upsampling_path
14
  from ltx_core.model.audio_vae.vocoder import Vocoder
 
464
  h = self.conv_out(h)
465
  return torch.tanh(h) if self.tanh_out else h
466
 
467
@torch.no_grad()
def encode_audio(
    waveform: torch.Tensor,
    waveform_sample_rate: int,
    *,
    audio_encoder: "AudioEncoder",
    audio_processor: "AudioProcessor",
    return_mean_only: bool = False,
) -> torch.Tensor:
    """
    Encode a waveform into an audio latent representation.

    Args:
        waveform: Audio waveform tensor.
            Expected shapes:
            - (T,) -> treated as (1,1,T)
            - (B,T) -> treated as (B,1,T)
            - (B,C,T) -> used as-is
        waveform_sample_rate: Sample rate of the provided waveform.
        audio_encoder: AudioEncoder that consumes (B, C, frames, mel_bins).
        audio_processor: AudioProcessor from ops.py that produces log-mel features.
        return_mean_only: If True and encoder outputs double_z, return only the mean half.

    Returns:
        Latent tensor from AudioEncoder.
        If return_mean_only=True and double_z=True: returns (B, z_channels, frames, mel_bins).
        Otherwise returns the raw encoder output (often (B, 2*z_channels, frames, mel_bins)).

    Raises:
        ValueError: If the waveform is not 1-, 2-, or 3-dimensional.
    """
    # --- normalize waveform shape to (B, C, T) ---
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0).unsqueeze(0)
    elif waveform.dim() == 2:
        waveform = waveform.unsqueeze(1)
    elif waveform.dim() != 3:
        raise ValueError(f"Unexpected waveform shape: {tuple(waveform.shape)}")

    waveform = waveform.float()

    # --- waveform -> log-mel spectrogram (B, C, frames, mel_bins) ---
    mel = audio_processor.waveform_to_mel(waveform, waveform_sample_rate)

    # --- mel -> latent ---
    latent = audio_encoder(mel)

    # double_z encoders emit [mean | logvar] stacked along channels; keep the mean half.
    if return_mean_only and getattr(audio_encoder, "double_z", False):
        latent = torch.chunk(latent, 2, dim=1)[0]

    return latent
517
+
518
 
519
  def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> torch.Tensor:
520
  """
packages/ltx-pipelines/src/ltx_pipelines/distilled.py CHANGED
@@ -8,7 +8,7 @@ from ltx_core.components.diffusion_steps import EulerDiffusionStep
8
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
9
  from ltx_core.components.noisers import GaussianNoiser
10
  from ltx_core.components.protocols import DiffusionStepProtocol
11
- from ltx_core.conditioning import ConditioningItem, VideoConditionByKeyframeIndex
12
  from ltx_core.loader import LoraPathStrengthAndSDOps
13
  from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
14
  from ltx_core.model.upsampler import upsample_video
@@ -16,6 +16,7 @@ from ltx_core.model.video_vae import TilingConfig, VideoEncoder, get_video_chunk
16
  from ltx_core.model.video_vae import decode_video as vae_decode_video
17
  from ltx_core.text_encoders.gemma import encode_text
18
  from ltx_core.types import LatentState, VideoPixelShape
 
19
  from ltx_pipelines import utils
20
  from ltx_pipelines.utils import ModelLedger
21
  from ltx_pipelines.utils.args import default_2_stage_distilled_arg_parser
@@ -38,6 +39,42 @@ from ltx_pipelines.utils.helpers import (
38
  from ltx_pipelines.utils.media_io import encode_video, load_video_conditioning
39
  from ltx_pipelines.utils.types import PipelineComponents
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  device = get_device()
42
 
43
 
@@ -74,6 +111,151 @@ class DistilledPipeline:
74
  # Cached models to avoid reloading
75
  self._video_encoder = None
76
  self._transformer = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  @torch.inference_mode()
79
  def __call__(
@@ -92,12 +274,27 @@ class DistilledPipeline:
92
  tiling_config: TilingConfig | None = None,
93
  video_context: torch.Tensor | None = None,
94
  audio_context: torch.Tensor | None = None,
 
 
 
95
  ) -> None:
96
  generator = torch.Generator(device=self.device).manual_seed(seed)
97
  noiser = GaussianNoiser(generator=generator)
98
  stepper = EulerDiffusionStep()
99
  dtype = torch.bfloat16
100
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # Use pre-computed embeddings if provided, otherwise encode text
102
  if video_context is None or audio_context is None:
103
  text_encoder = self.model_ledger.text_encoder()
@@ -153,6 +350,7 @@ class DistilledPipeline:
153
  video_state, audio_state = denoise_audio_video(
154
  output_shape=stage_1_output_shape,
155
  conditionings=stage_1_conditionings,
 
156
  noiser=noiser,
157
  sigmas=stage_1_sigmas,
158
  stepper=stepper,
@@ -197,6 +395,7 @@ class DistilledPipeline:
197
  video_state, audio_state = denoise_audio_video(
198
  output_shape=stage_2_output_shape,
199
  conditionings=stage_2_conditionings,
 
200
  noiser=noiser,
201
  sigmas=stage_2_sigmas,
202
  stepper=stepper,
@@ -227,8 +426,6 @@ class DistilledPipeline:
227
  )
228
 
229
 
230
-
231
-
232
  def _create_conditionings(
233
  self,
234
  images: list[tuple[str, int, float]],
@@ -275,4 +472,4 @@ class DistilledPipeline:
275
  )
276
  )
277
 
278
- return conditionings
 
8
  from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
9
  from ltx_core.components.noisers import GaussianNoiser
10
  from ltx_core.components.protocols import DiffusionStepProtocol
11
+ from ltx_core.conditioning import ConditioningItem, VideoConditionByKeyframeIndex, ConditioningError
12
  from ltx_core.loader import LoraPathStrengthAndSDOps
13
  from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
14
  from ltx_core.model.upsampler import upsample_video
 
16
  from ltx_core.model.video_vae import decode_video as vae_decode_video
17
  from ltx_core.text_encoders.gemma import encode_text
18
  from ltx_core.types import LatentState, VideoPixelShape
19
+ from ltx_core.tools import LatentTools
20
  from ltx_pipelines import utils
21
  from ltx_pipelines.utils import ModelLedger
22
  from ltx_pipelines.utils.args import default_2_stage_distilled_arg_parser
 
39
  from ltx_pipelines.utils.media_io import encode_video, load_video_conditioning
40
  from ltx_pipelines.utils.types import PipelineComponents
41
 
42
+ import torchaudio
43
+ from ltx_core.model.audio_vae import AudioProcessor
44
+ from ltx_core.types import AudioLatentShape, VideoPixelShape
45
+
46
class AudioConditionByLatent(ConditioningItem):
    """Inject a pre-computed audio latent sequence as conditioning.

    The supplied latents overwrite the matching tokens of the latent state
    (both noisy and clean copies), and the denoise mask for those tokens is
    set to ``1 - strength`` so the diffusion process preserves them in
    proportion to ``strength``.
    """

    def __init__(self, latent: torch.Tensor, strength: float):
        # latent: audio latent shaped (batch, channels, frames, mel_bins)
        # strength: conditioning strength; 1.0 keeps the latent fully fixed
        self.latent = latent
        self.strength = strength

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """Return a copy of ``latent_state`` with the audio latents applied."""
        if not isinstance(latent_tools.target_shape, AudioLatentShape):
            raise ConditioningError("Audio conditioning requires an audio latent target shape.")

        # Unpacking enforces 4-D tensors on both sides before comparing.
        src_b, src_c, src_f, src_m = self.latent.shape
        tgt_batch, tgt_channels, tgt_frames, tgt_bins = latent_tools.target_shape.to_torch_shape()
        if (src_b, src_c, src_f, src_m) != (tgt_batch, tgt_channels, tgt_frames, tgt_bins):
            raise ConditioningError(
                f"Can't apply audio conditioning item to latent with shape {latent_tools.target_shape}, expected "
                f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_bins})."
            )

        tokens = latent_tools.patchifier.patchify(self.latent)
        token_count = tokens.shape[1]

        conditioned = latent_state.clone()
        conditioned.latent[:, :token_count] = tokens
        conditioned.clean_latent[:, :token_count] = tokens
        conditioned.denoise_mask[:, :token_count] = 1.0 - self.strength
        return conditioned
77
+
78
  device = get_device()
79
 
80
 
 
111
  # Cached models to avoid reloading
112
  self._video_encoder = None
113
  self._transformer = None
114
+
115
+ def _build_audio_conditionings_from_waveform(
116
+ self,
117
+ input_waveform: torch.Tensor,
118
+ input_sample_rate: int,
119
+ num_frames: int,
120
+ fps: float,
121
+ strength: float,
122
+ ) -> list[AudioConditionByLatent] | None:
123
+ strength = float(strength)
124
+ if strength <= 0.0:
125
+ return None
126
+
127
+ # Expect waveform as:
128
+ # - (T,) or (C,T) or (B,C,T). Convert to (B,C,T)
129
+ waveform = input_waveform
130
+ if waveform.ndim == 1:
131
+ waveform = waveform.unsqueeze(0).unsqueeze(0)
132
+ elif waveform.ndim == 2:
133
+ waveform = waveform.unsqueeze(0)
134
+ elif waveform.ndim != 3:
135
+ raise ValueError(f"input_waveform must be 1D/2D/3D, got shape {tuple(waveform.shape)}")
136
+
137
+ # Get audio encoder + its config
138
+ audio_encoder = self.model_ledger.audio_encoder() # assumes ledger exposes it
139
+ # If you want to cache it like video_encoder/transformer, you can.
140
+ target_sr = int(getattr(audio_encoder, "sample_rate"))
141
+ target_channels = int(getattr(audio_encoder, "in_channels", waveform.shape[1]))
142
+ mel_bins = int(getattr(audio_encoder, "mel_bins"))
143
+ mel_hop = int(getattr(audio_encoder, "mel_hop_length"))
144
+ n_fft = int(getattr(audio_encoder, "n_fft"))
145
+
146
+ # Match channels
147
+ if waveform.shape[1] != target_channels:
148
+ if waveform.shape[1] == 1 and target_channels > 1:
149
+ waveform = waveform.repeat(1, target_channels, 1)
150
+ elif target_channels == 1:
151
+ waveform = waveform.mean(dim=1, keepdim=True)
152
+ else:
153
+ waveform = waveform[:, :target_channels, :]
154
+ if waveform.shape[1] < target_channels:
155
+ pad_ch = target_channels - waveform.shape[1]
156
+ pad = torch.zeros((waveform.shape[0], pad_ch, waveform.shape[2]), dtype=waveform.dtype)
157
+ waveform = torch.cat([waveform, pad], dim=1)
158
+
159
+ # Resample if needed (CPU float32 is safest for torchaudio)
160
+ waveform = waveform.to(device="cpu", dtype=torch.float32)
161
+ if int(input_sample_rate) != target_sr:
162
+ waveform = torchaudio.functional.resample(waveform, int(input_sample_rate), target_sr)
163
+
164
+ # Waveform -> Mel
165
+ audio_processor = AudioProcessor(
166
+ sample_rate=target_sr,
167
+ mel_bins=mel_bins,
168
+ mel_hop_length=mel_hop,
169
+ n_fft=n_fft,
170
+ ).to(waveform.device)
171
+
172
+ mel = audio_processor.waveform_to_mel(waveform, target_sr)
173
+
174
+ # Mel -> latent (run encoder on its own device/dtype)
175
+ audio_params = next(audio_encoder.parameters(), None)
176
+ enc_device = audio_params.device if audio_params is not None else self.device
177
+ enc_dtype = audio_params.dtype if audio_params is not None else self.dtype
178
+
179
+ mel = mel.to(device=enc_device, dtype=enc_dtype)
180
+ with torch.inference_mode():
181
+ audio_latent = audio_encoder(mel)
182
+
183
+ # Pad/trim latent to match the target video duration
184
+ audio_downsample = getattr(getattr(audio_encoder, "patchifier", None), "audio_latent_downsample_factor", 4)
185
+ target_shape = AudioLatentShape.from_video_pixel_shape(
186
+ VideoPixelShape(batch=audio_latent.shape[0], frames=int(num_frames), width=1, height=1, fps=float(fps)),
187
+ channels=audio_latent.shape[1],
188
+ mel_bins=audio_latent.shape[3],
189
+ sample_rate=target_sr,
190
+ hop_length=mel_hop,
191
+ audio_latent_downsample_factor=audio_downsample,
192
+ )
193
+ target_frames = int(target_shape.frames)
194
+
195
+ if audio_latent.shape[2] < target_frames:
196
+ pad_frames = target_frames - audio_latent.shape[2]
197
+ pad = torch.zeros(
198
+ (audio_latent.shape[0], audio_latent.shape[1], pad_frames, audio_latent.shape[3]),
199
+ device=audio_latent.device,
200
+ dtype=audio_latent.dtype,
201
+ )
202
+ audio_latent = torch.cat([audio_latent, pad], dim=2)
203
+ elif audio_latent.shape[2] > target_frames:
204
+ audio_latent = audio_latent[:, :, :target_frames, :]
205
+
206
+ # Move latent to pipeline device/dtype for conditioning object
207
+ audio_latent = audio_latent.to(device=self.device, dtype=self.dtype)
208
+
209
+ return [AudioConditionByLatent(audio_latent, strength)]
210
+
211
+ def _prepare_output_waveform(
212
+ self,
213
+ input_waveform: torch.Tensor,
214
+ input_sample_rate: int,
215
+ target_sample_rate: int,
216
+ num_frames: int,
217
+ fps: float,
218
+ ) -> torch.Tensor:
219
+ """
220
+ Returns waveform on CPU, float32, resampled to target_sample_rate and
221
+ trimmed/padded to match video duration.
222
+ Output shape: (T,) for mono or (C, T) for multi-channel.
223
+ """
224
+ wav = input_waveform
225
+
226
+ # Accept (T,), (C,T), (B,C,T)
227
+ if wav.ndim == 3:
228
+ wav = wav[0]
229
+ elif wav.ndim == 2:
230
+ pass
231
+ elif wav.ndim == 1:
232
+ wav = wav.unsqueeze(0)
233
+ else:
234
+ raise ValueError(f"input_waveform must be 1D/2D/3D, got {tuple(wav.shape)}")
235
+
236
+ # Now wav is (C, T)
237
+ wav = wav.detach().to("cpu", dtype=torch.float32)
238
+
239
+ # Resample if needed
240
+ if int(input_sample_rate) != int(target_sample_rate):
241
+ wav = torchaudio.functional.resample(wav, int(input_sample_rate), int(target_sample_rate))
242
+
243
+ # Match video duration
244
+ duration_sec = float(num_frames) / float(fps)
245
+ target_len = int(round(duration_sec * float(target_sample_rate)))
246
+
247
+ cur_len = int(wav.shape[-1])
248
+ if cur_len > target_len:
249
+ wav = wav[..., :target_len]
250
+ elif cur_len < target_len:
251
+ pad = target_len - cur_len
252
+ wav = torch.nn.functional.pad(wav, (0, pad))
253
+
254
+ # If mono, return (T,) for convenience
255
+ if wav.shape[0] == 1:
256
+ return wav[0]
257
+ return wav
258
+
259
 
260
  @torch.inference_mode()
261
  def __call__(
 
274
  tiling_config: TilingConfig | None = None,
275
  video_context: torch.Tensor | None = None,
276
  audio_context: torch.Tensor | None = None,
277
+ input_waveform: torch.Tensor | None = None,
278
+ input_waveform_sample_rate: int | None = None,
279
+ audio_strength: float = 1.0, # or audio_scale, your naming
280
  ) -> None:
281
  generator = torch.Generator(device=self.device).manual_seed(seed)
282
  noiser = GaussianNoiser(generator=generator)
283
  stepper = EulerDiffusionStep()
284
  dtype = torch.bfloat16
285
 
286
+ audio_conditionings = None
287
+ if input_waveform is not None:
288
+ if input_waveform_sample_rate is None:
289
+ raise ValueError("input_waveform_sample_rate must be provided when input_waveform is set.")
290
+ audio_conditionings = self._build_audio_conditionings_from_waveform(
291
+ input_waveform=input_waveform,
292
+ input_sample_rate=int(input_waveform_sample_rate),
293
+ num_frames=num_frames,
294
+ fps=frame_rate,
295
+ strength=audio_strength,
296
+ )
297
+
298
  # Use pre-computed embeddings if provided, otherwise encode text
299
  if video_context is None or audio_context is None:
300
  text_encoder = self.model_ledger.text_encoder()
 
350
  video_state, audio_state = denoise_audio_video(
351
  output_shape=stage_1_output_shape,
352
  conditionings=stage_1_conditionings,
353
+ audio_conditionings=audio_conditionings,
354
  noiser=noiser,
355
  sigmas=stage_1_sigmas,
356
  stepper=stepper,
 
395
  video_state, audio_state = denoise_audio_video(
396
  output_shape=stage_2_output_shape,
397
  conditionings=stage_2_conditionings,
398
+ audio_conditionings=audio_conditionings,
399
  noiser=noiser,
400
  sigmas=stage_2_sigmas,
401
  stepper=stepper,
 
426
  )
427
 
428
 
 
 
429
  def _create_conditionings(
430
  self,
431
  images: list[tuple[str, int, float]],
 
472
  )
473
  )
474
 
475
+ return conditionings
packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py CHANGED
@@ -245,6 +245,7 @@ def noise_audio_state(
245
  device: torch.device,
246
  noise_scale: float = 1.0,
247
  initial_latent: torch.Tensor | None = None,
 
248
  ) -> tuple[LatentState, AudioLatentTools]:
249
  """Initialize and noise an audio latent state for the diffusion pipeline.
250
  Creates an audio latent state from the output shape, applies conditionings,
@@ -262,6 +263,7 @@ def noise_audio_state(
262
  device=device,
263
  noise_scale=noise_scale,
264
  initial_latent=initial_latent,
 
265
  )
266
 
267
  return audio_state, audio_tools
@@ -275,18 +277,35 @@ def create_noised_state(
275
  device: torch.device,
276
  noise_scale: float = 1.0,
277
  initial_latent: torch.Tensor | None = None,
 
278
  ) -> LatentState:
279
- """Create a noised latent state from empty state, conditionings, and noiser.
280
- Creates an empty latent state, applies conditionings, and then adds noise
281
- using the provided noiser. Returns the final noised state ready for diffusion.
282
- """
283
  state = tools.create_initial_state(device, dtype, initial_latent)
284
  state = state_with_conditionings(state, conditionings, tools)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  state = noiser(state, noise_scale)
286
 
 
 
 
 
 
287
  return state
288
 
289
 
 
290
  def state_with_conditionings(
291
  latent_state: LatentState, conditioning_items: list[ConditioningItem], latent_tools: LatentTools
292
  ) -> LatentState:
@@ -302,7 +321,9 @@ def state_with_conditionings(
302
 
303
  def post_process_latent(denoised: torch.Tensor, denoise_mask: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
304
  """Blend denoised output with clean state based on mask."""
305
- return (denoised * denoise_mask + clean.float() * (1 - denoise_mask)).to(denoised.dtype)
 
 
306
 
307
 
308
  def modality_from_latent_state(
@@ -386,10 +407,12 @@ def denoise_audio_video( # noqa: PLR0913
386
  components: PipelineComponents,
387
  dtype: torch.dtype,
388
  device: torch.device,
 
389
  noise_scale: float = 1.0,
390
  initial_video_latent: torch.Tensor | None = None,
391
  initial_audio_latent: torch.Tensor | None = None,
392
- ) -> tuple[LatentState, LatentState]:
 
393
  video_state, video_tools = noise_video_state(
394
  output_shape=output_shape,
395
  noiser=noiser,
@@ -403,7 +426,7 @@ def denoise_audio_video( # noqa: PLR0913
403
  audio_state, audio_tools = noise_audio_state(
404
  output_shape=output_shape,
405
  noiser=noiser,
406
- conditionings=[],
407
  components=components,
408
  dtype=dtype,
409
  device=device,
@@ -411,13 +434,22 @@ def denoise_audio_video( # noqa: PLR0913
411
  initial_latent=initial_audio_latent,
412
  )
413
 
 
 
 
 
 
414
  video_state, audio_state = denoising_loop_fn(
415
  sigmas,
416
  video_state,
417
  audio_state,
418
  stepper,
 
419
  )
420
 
 
 
 
421
  video_state = video_tools.clear_conditioning(video_state)
422
  video_state = video_tools.unpatchify(video_state)
423
  audio_state = audio_tools.clear_conditioning(audio_state)
@@ -426,6 +458,7 @@ def denoise_audio_video( # noqa: PLR0913
426
  return video_state, audio_state
427
 
428
 
 
429
  _UNICODE_REPLACEMENTS = str.maketrans("\u2018\u2019\u201c\u201d\u2014\u2013\u00a0\u2032\u2212", "''\"\"-- '-")
430
 
431
 
 
245
  device: torch.device,
246
  noise_scale: float = 1.0,
247
  initial_latent: torch.Tensor | None = None,
248
+ denoise_mask: torch.Tensor | None = None
249
  ) -> tuple[LatentState, AudioLatentTools]:
250
  """Initialize and noise an audio latent state for the diffusion pipeline.
251
  Creates an audio latent state from the output shape, applies conditionings,
 
263
  device=device,
264
  noise_scale=noise_scale,
265
  initial_latent=initial_latent,
266
+ denoise_mask=denoise_mask,
267
  )
268
 
269
  return audio_state, audio_tools
 
277
  device: torch.device,
278
  noise_scale: float = 1.0,
279
  initial_latent: torch.Tensor | None = None,
280
+ denoise_mask: torch.Tensor | None = None, # <-- add
281
  ) -> LatentState:
 
 
 
 
282
  state = tools.create_initial_state(device, dtype, initial_latent)
283
  state = state_with_conditionings(state, conditionings, tools)
284
+
285
+ if denoise_mask is not None:
286
+ # Convert any tensor mask into a single scalar (solid mask behavior)
287
+ if isinstance(denoise_mask, torch.Tensor):
288
+ mask_value = float(denoise_mask.mean().item())
289
+ else:
290
+ mask_value = float(denoise_mask)
291
+
292
+ state = replace(
293
+ state,
294
+ clean_latent=state.latent.clone(),
295
+ denoise_mask=torch.full_like(state.denoise_mask, mask_value), # <- matches internal shape
296
+ )
297
+
298
  state = noiser(state, noise_scale)
299
 
300
+ if denoise_mask is not None:
301
+ m = state.denoise_mask.to(dtype=state.latent.dtype, device=state.latent.device)
302
+ clean = state.clean_latent.to(dtype=state.latent.dtype, device=state.latent.device)
303
+ state = replace(state, latent=state.latent * m + clean * (1 - m))
304
+
305
  return state
306
 
307
 
308
+
309
  def state_with_conditionings(
310
  latent_state: LatentState, conditioning_items: list[ConditioningItem], latent_tools: LatentTools
311
  ) -> LatentState:
 
321
 
322
def post_process_latent(denoised: torch.Tensor, denoise_mask: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
    """Blend the denoised latent with the clean latent according to the mask.

    A mask value of 1 keeps the denoised output, 0 keeps the clean
    (conditioned) latent. Both operands are cast to the denoised tensor's
    dtype before blending so the result stays in that dtype.
    """
    mask = denoise_mask.to(dtype=denoised.dtype)
    keep = clean.to(dtype=denoised.dtype)
    return denoised * mask + keep * (1 - mask)
327
 
328
 
329
  def modality_from_latent_state(
 
407
  components: PipelineComponents,
408
  dtype: torch.dtype,
409
  device: torch.device,
410
+ audio_conditionings: list[ConditioningItem] | None = None,
411
  noise_scale: float = 1.0,
412
  initial_video_latent: torch.Tensor | None = None,
413
  initial_audio_latent: torch.Tensor | None = None,
414
+ # mask_context: MaskInjection | None = None,
415
+ ) -> tuple[LatentState | None, LatentState | None]:
416
  video_state, video_tools = noise_video_state(
417
  output_shape=output_shape,
418
  noiser=noiser,
 
426
  audio_state, audio_tools = noise_audio_state(
427
  output_shape=output_shape,
428
  noiser=noiser,
429
+ conditionings=audio_conditionings or [],
430
  components=components,
431
  dtype=dtype,
432
  device=device,
 
434
  initial_latent=initial_audio_latent,
435
  )
436
 
437
+ loop_kwargs = {}
438
+ # if "preview_tools" in inspect.signature(denoising_loop_fn).parameters:
439
+ # loop_kwargs["preview_tools"] = video_tools
440
+ # if "mask_context" in inspect.signature(denoising_loop_fn).parameters:
441
+ # loop_kwargs["mask_context"] = mask_context
442
  video_state, audio_state = denoising_loop_fn(
443
  sigmas,
444
  video_state,
445
  audio_state,
446
  stepper,
447
+ **loop_kwargs,
448
  )
449
 
450
+ if video_state is None or audio_state is None:
451
+ return None, None
452
+
453
  video_state = video_tools.clear_conditioning(video_state)
454
  video_state = video_tools.unpatchify(video_state)
455
  audio_state = audio_tools.clear_conditioning(audio_state)
 
458
  return video_state, audio_state
459
 
460
 
461
+
462
  _UNICODE_REPLACEMENTS = str.maketrans("\u2018\u2019\u201c\u201d\u2014\u2013\u00a0\u2032\u2212", "''\"\"-- '-")
463
 
464
 
packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py CHANGED
@@ -36,10 +36,24 @@ from ltx_core.text_encoders.gemma import (
36
  module_ops_from_gemma_root,
37
  )
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  class ModelLedger:
41
  """
42
- Central coordinator for loading and building models used in an LTX pipeline.
43
  The ledger wires together multiple model builders (transformer, video VAE encoder/decoder,
44
  audio VAE decoder, vocoder, text encoder, and optional latent upsampler) and exposes
45
  factory methods for constructing model instances.
@@ -144,6 +158,14 @@ class ModelLedger:
144
  registry=self.registry,
145
  )
146
 
 
 
 
 
 
 
 
 
147
  if self.gemma_root_path is not None:
148
  self.text_encoder_builder = Builder(
149
  model_path=self.checkpoint_path,
@@ -197,6 +219,14 @@ class ModelLedger:
197
  .eval()
198
  )
199
 
 
 
 
 
 
 
 
 
200
  def video_decoder(self) -> VideoDecoder:
201
  if not hasattr(self, "vae_decoder_builder"):
202
  raise ValueError(
 
36
  module_ops_from_gemma_root,
37
  )
38
 
39
+ from ltx_core.model.audio_vae import (
40
+ AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
41
+ VOCODER_COMFY_KEYS_FILTER,
42
+ AudioDecoder,
43
+ AudioDecoderConfigurator,
44
+ Vocoder,
45
+ VocoderConfigurator,
46
+ AudioEncoder,
47
+ )
48
+ from ltx_core.model.audio_vae.model_configurator import (
49
+ AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
50
+ AudioEncoderConfigurator,
51
+ )
52
+
53
 
54
  class ModelLedger:
55
  """
56
+ Central coordinator for loading and building models used in an LTX pipeline.
57
  The ledger wires together multiple model builders (transformer, video VAE encoder/decoder,
58
  audio VAE decoder, vocoder, text encoder, and optional latent upsampler) and exposes
59
  factory methods for constructing model instances.
 
158
  registry=self.registry,
159
  )
160
 
161
+ self.audio_encoder_builder = Builder(
162
+ model_path=self.checkpoint_path,
163
+ model_class_configurator=AudioEncoderConfigurator,
164
+ model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
165
+ registry=self.registry,
166
+ )
167
+
168
+
169
  if self.gemma_root_path is not None:
170
  self.text_encoder_builder = Builder(
171
  model_path=self.checkpoint_path,
 
219
  .eval()
220
  )
221
 
222
+ def audio_encoder(self) -> AudioEncoder:
223
+ if not hasattr(self, "audio_encoder_builder"):
224
+ raise ValueError(
225
+ "Audio encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
226
+ )
227
+ return self.audio_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
228
+
229
+
230
  def video_decoder(self) -> VideoDecoder:
231
  if not hasattr(self, "vae_decoder_builder"):
232
  raise ValueError(
requirements.txt CHANGED
@@ -6,9 +6,15 @@ safetensors
6
  accelerate
7
  flashpack==0.1.2
8
  scikit-image>=0.25.2
 
 
9
  av
10
  tqdm
11
  pillow
12
  scipy>=1.14
13
  flash-attn-3 @ https://huggingface.co/alexnasa/flash-attn-3/resolve/main/128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
14
- bitsandbytes
 
 
 
 
 
6
  accelerate
7
  flashpack==0.1.2
8
  scikit-image>=0.25.2
9
+ imageio
10
+ imageio-ffmpeg
11
  av
12
  tqdm
13
  pillow
14
  scipy>=1.14
15
  flash-attn-3 @ https://huggingface.co/alexnasa/flash-attn-3/resolve/main/128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
16
+ bitsandbytes
17
+ opencv-python
18
+ controlnet_aux
19
+ onnxruntime-gpu
20
+ matplotlib