Spaces:
Sleeping
Sleeping
alex committed on
Commit ·
25889c7
1
Parent(s): 43067da
now with audio support
Browse files- app.py +279 -13
- packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py +2 -1
- packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py +52 -1
- packages/ltx-pipelines/src/ltx_pipelines/distilled.py +201 -4
- packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py +40 -7
- packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py +31 -1
- requirements.txt +7 -1
app.py
CHANGED
|
@@ -2,6 +2,80 @@ import sys
|
|
| 2 |
from pathlib import Path
|
| 3 |
import uuid
|
| 4 |
import tempfile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
# Add packages to Python path
|
| 7 |
current_dir = Path(__file__).parent
|
|
@@ -17,9 +91,11 @@ import random
|
|
| 17 |
import torch
|
| 18 |
from typing import Optional
|
| 19 |
from pathlib import Path
|
|
|
|
| 20 |
from huggingface_hub import hf_hub_download, snapshot_download
|
| 21 |
from ltx_pipelines.distilled import DistilledPipeline
|
| 22 |
from ltx_core.model.video_vae import TilingConfig
|
|
|
|
| 23 |
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
| 24 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 25 |
from ltx_pipelines.utils.constants import (
|
|
@@ -31,6 +107,8 @@ from ltx_pipelines.utils.constants import (
|
|
| 31 |
DEFAULT_LORA_STRENGTH,
|
| 32 |
)
|
| 33 |
from ltx_core.loader.single_gpu_model_builder import enable_only_lora
|
|
|
|
|
|
|
| 34 |
from PIL import Image
|
| 35 |
|
| 36 |
MAX_SEED = np.iinfo(np.int32).max
|
|
@@ -38,6 +116,11 @@ MAX_SEED = np.iinfo(np.int32).max
|
|
| 38 |
# Install with: pip install git+https://github.com/Lightricks/LTX-2.git
|
| 39 |
from ltx_pipelines.utils import ModelLedger
|
| 40 |
from ltx_pipelines.utils.helpers import generate_enhanced_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
# HuggingFace Hub defaults
|
| 43 |
DEFAULT_REPO_ID = "Lightricks/LTX-2"
|
|
@@ -82,6 +165,8 @@ model_ledger = ModelLedger(
|
|
| 82 |
local_files_only=False
|
| 83 |
)
|
| 84 |
|
|
|
|
|
|
|
| 85 |
|
| 86 |
# Load text encoder once and keep it in memory
|
| 87 |
text_encoder = model_ledger.text_encoder()
|
|
@@ -90,6 +175,109 @@ print("=" * 80)
|
|
| 90 |
print("Text encoder loaded and ready!")
|
| 91 |
print("=" * 80)
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
def encode_text_simple(text_encoder, prompt: str):
|
| 94 |
"""Simple text encoding without using pipeline_utils."""
|
| 95 |
v_context, a_context, _ = text_encoder(prompt)
|
|
@@ -262,6 +450,7 @@ RUNTIME_LORA_CHOICES = [
|
|
| 262 |
("Slide Right", 5),
|
| 263 |
("Slide Down", 6),
|
| 264 |
("Slide Up", 7),
|
|
|
|
| 265 |
]
|
| 266 |
|
| 267 |
# Initialize pipeline WITHOUT text encoder (gemma_root=None)
|
|
@@ -580,7 +769,7 @@ def generate_video_example(input_image, prompt, camera_lora, resolution, progres
|
|
| 580 |
|
| 581 |
w, h = apply_resolution(resolution)
|
| 582 |
|
| 583 |
-
output_video
|
| 584 |
input_image,
|
| 585 |
prompt,
|
| 586 |
10, # duration seconds
|
|
@@ -589,18 +778,18 @@ def generate_video_example(input_image, prompt, camera_lora, resolution, progres
|
|
| 589 |
True, # randomize_seed
|
| 590 |
h, # height
|
| 591 |
w, # width
|
| 592 |
-
camera_lora,
|
|
|
|
| 593 |
progress
|
| 594 |
)
|
| 595 |
|
| 596 |
return output_video
|
| 597 |
-
|
| 598 |
-
|
| 599 |
def generate_video_example_t2v(prompt, camera_lora, resolution, progress=gr.Progress(track_tqdm=True)):
|
| 600 |
|
| 601 |
w, h = apply_resolution(resolution)
|
| 602 |
|
| 603 |
-
output_video
|
| 604 |
None,
|
| 605 |
prompt,
|
| 606 |
15, # duration seconds
|
|
@@ -609,11 +798,32 @@ def generate_video_example_t2v(prompt, camera_lora, resolution, progress=gr.Prog
|
|
| 609 |
True, # randomize_seed
|
| 610 |
h, # height
|
| 611 |
w, # width
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
camera_lora,
|
|
|
|
| 613 |
progress
|
| 614 |
)
|
|
|
|
| 615 |
return output_video
|
| 616 |
-
|
| 617 |
def get_duration(
|
| 618 |
input_image,
|
| 619 |
prompt,
|
|
@@ -624,14 +834,20 @@ def get_duration(
|
|
| 624 |
height,
|
| 625 |
width,
|
| 626 |
camera_lora,
|
|
|
|
| 627 |
progress
|
| 628 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 629 |
if duration <= 5:
|
| 630 |
-
return 80
|
| 631 |
elif duration <= 10:
|
| 632 |
-
return 120
|
| 633 |
else:
|
| 634 |
-
return 180
|
| 635 |
|
| 636 |
@spaces.GPU(duration=get_duration)
|
| 637 |
def generate_video(
|
|
@@ -644,6 +860,7 @@ def generate_video(
|
|
| 644 |
height: int = DEFAULT_1_STAGE_HEIGHT,
|
| 645 |
width: int = DEFAULT_1_STAGE_WIDTH,
|
| 646 |
camera_lora: str = "No LoRA",
|
|
|
|
| 647 |
progress=gr.Progress(track_tqdm=True),
|
| 648 |
):
|
| 649 |
"""
|
|
@@ -705,10 +922,21 @@ def generate_video(
|
|
| 705 |
audio_context = embeddings["audio_context"].to("cuda", non_blocking=True)
|
| 706 |
print("✓ Embeddings loaded successfully")
|
| 707 |
|
|
|
|
| 708 |
# free prompt enhancer / encoder temps ASAP
|
| 709 |
del embeddings, final_prompt, status
|
| 710 |
torch.cuda.empty_cache()
|
| 711 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 712 |
|
| 713 |
# Map dropdown name -> adapter index
|
| 714 |
name_to_idx = {name: idx for name, idx in RUNTIME_LORA_CHOICES}
|
|
@@ -717,6 +945,22 @@ def generate_video(
|
|
| 717 |
enable_only_lora(pipeline._transformer, selected_idx)
|
| 718 |
torch.cuda.empty_cache()
|
| 719 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
# Run inference - progress automatically tracks tqdm from pipeline
|
| 721 |
with torch.inference_mode():
|
| 722 |
pipeline(
|
|
@@ -731,12 +975,14 @@ def generate_video(
|
|
| 731 |
tiling_config=TilingConfig.default(),
|
| 732 |
video_context=video_context,
|
| 733 |
audio_context=audio_context,
|
|
|
|
|
|
|
| 734 |
)
|
| 735 |
del video_context, audio_context
|
| 736 |
torch.cuda.empty_cache()
|
| 737 |
print("successful generation")
|
| 738 |
|
| 739 |
-
return str(output_path)
|
| 740 |
|
| 741 |
|
| 742 |
|
|
@@ -1160,12 +1406,13 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 1160 |
height=512
|
| 1161 |
)
|
| 1162 |
|
| 1163 |
-
|
| 1164 |
prompt_ui = PromptBox(
|
| 1165 |
value="Make this image come alive with cinematic motion, smooth animation",
|
| 1166 |
elem_id="prompt_ui",
|
| 1167 |
)
|
| 1168 |
|
|
|
|
|
|
|
| 1169 |
prompt = gr.Textbox(
|
| 1170 |
label="Prompt",
|
| 1171 |
value="Make this image come alive with cinematic motion, smooth animation",
|
|
@@ -1302,11 +1549,13 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 1302 |
height,
|
| 1303 |
width,
|
| 1304 |
camera_lora,
|
|
|
|
| 1305 |
],
|
| 1306 |
-
outputs=[output_video
|
| 1307 |
)
|
| 1308 |
|
| 1309 |
|
|
|
|
| 1310 |
timestep_prompt = """Style: Realistic live-action, cinematic, shallow depth of field, 24 fps, natural and dramatic lighting
|
| 1311 |
|
| 1312 |
Environment: Interior of a space station module or realistic mock-up, metal panels, blinking lights, Earth visible through a large window
|
|
@@ -1331,6 +1580,24 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 1331 |
|
| 1332 |
Music: subtle cinematic synth or ambient pad, futuristic and minimal, emphasizing awe and solitude"""
|
| 1333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1334 |
|
| 1335 |
gr.Examples(
|
| 1336 |
examples=[
|
|
@@ -1402,6 +1669,5 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 1402 |
)
|
| 1403 |
|
| 1404 |
|
| 1405 |
-
|
| 1406 |
if __name__ == "__main__":
|
| 1407 |
demo.launch(ssr_mode=False, mcp_server=True, css=css)
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
import uuid
|
| 4 |
import tempfile
|
| 5 |
+
import subprocess
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torchaudio
|
| 9 |
+
import os
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
def _coerce_audio_path(audio_path: Any) -> str:
|
| 13 |
+
# Common Gradio case: tuple where first item is the filepath
|
| 14 |
+
if isinstance(audio_path, tuple) and len(audio_path) > 0:
|
| 15 |
+
audio_path = audio_path[0]
|
| 16 |
+
|
| 17 |
+
# Some gradio versions pass a dict-like object
|
| 18 |
+
if isinstance(audio_path, dict):
|
| 19 |
+
# common keys: "name", "path"
|
| 20 |
+
audio_path = audio_path.get("name") or audio_path.get("path")
|
| 21 |
+
|
| 22 |
+
# pathlib.Path etc.
|
| 23 |
+
if not isinstance(audio_path, (str, bytes, os.PathLike)):
|
| 24 |
+
raise TypeError(f"audio_path must be a path-like, got {type(audio_path)}: {audio_path}")
|
| 25 |
+
|
| 26 |
+
return os.fspath(audio_path)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def match_audio_to_duration(
    audio_path: str,
    target_seconds: float,
    target_sr: int = 48000,
    to_mono: bool = True,
    pad_mode: str = "silence",  # "silence" | "repeat"
    device: str = "cuda",
):
    """Load an audio file and fit it to exactly ``target_seconds``.

    Pipeline: load -> resample to ``target_sr`` -> optional mono down-mix ->
    trim or pad along the time axis -> move to ``device``.

    Args:
        audio_path: Path-like value; normalized through ``_coerce_audio_path``.
        target_seconds: Desired duration in seconds; must be >= 0.
        target_sr: Sample rate the audio is resampled to.
        to_mono: Average all channels into one when the file has several.
        pad_mode: ``"repeat"`` loops the clip to fill the gap (when it has at
            least one sample); any other value pads with zeros (silence).
        device: Device the resulting tensor is moved to.

    Returns:
        ``(waveform, sample_rate)``. The waveform keeps the channel axis produced
        by ``torchaudio.load`` (2-D ``[C, T]``, i.e. ``[1, T]`` after the mono
        down-mix) and has ``round(target_seconds * target_sr)`` samples.

    Raises:
        ValueError: If ``target_seconds`` is negative.
    """
    # A negative duration would become a negative slice bound below and
    # silently cut samples off the END of the clip, so reject it up front.
    if target_seconds < 0:
        raise ValueError(f"target_seconds must be >= 0, got {target_seconds}")

    audio_path = _coerce_audio_path(audio_path)

    wav, sr = torchaudio.load(audio_path)  # [C, T] float32 CPU

    # Resample first so the duration arithmetic below uses a single rate.
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
        sr = target_sr

    # Mono down-mix (skip when the model consuming the audio supports stereo).
    if to_mono and wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # [1, T]

    # Exact target length in samples
    target_len = int(round(target_seconds * sr))
    cur_len = wav.shape[-1]

    if cur_len > target_len:
        wav = wav[..., :target_len]
    elif cur_len < target_len:
        if pad_mode == "repeat" and cur_len > 0:
            # Ceil-divide to get enough repetitions, then cut to the exact length.
            reps = (target_len + cur_len - 1) // cur_len
            wav = wav.repeat(1, reps)[..., :target_len]
        else:
            # Zero-pad on the right of the time axis.
            wav = F.pad(wav, (0, target_len - cur_len))

    wav = wav.to(device, non_blocking=True)
    return wav, sr
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def sh(cmd):
    """Run ``cmd`` through the shell; raises ``CalledProcessError`` on a non-zero exit.

    NOTE: ``shell=True`` lets the shell interpret ``cmd``, so only pass trusted,
    hard-coded command strings.
    """
    subprocess.check_call(cmd, shell=True)


def _pip_install(*args):
    """Run ``pip install`` with the pip that belongs to the running interpreter."""
    import sys  # local import keeps this block self-contained

    # An argument list (no shell) avoids quoting issues, and ``sys.executable -m pip``
    # installs into the environment this process imports from, which a bare
    # ``pip`` resolved through PATH does not guarantee.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])


_pip_install("--no-deps", "easy_dwpose")
|
| 79 |
|
| 80 |
# Add packages to Python path
|
| 81 |
current_dir = Path(__file__).parent
|
|
|
|
| 91 |
import torch
|
| 92 |
from typing import Optional
|
| 93 |
from pathlib import Path
|
| 94 |
+
import torchaudio
|
| 95 |
from huggingface_hub import hf_hub_download, snapshot_download
|
| 96 |
from ltx_pipelines.distilled import DistilledPipeline
|
| 97 |
from ltx_core.model.video_vae import TilingConfig
|
| 98 |
+
from ltx_core.model.audio_vae.ops import AudioProcessor
|
| 99 |
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
| 100 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 101 |
from ltx_pipelines.utils.constants import (
|
|
|
|
| 107 |
DEFAULT_LORA_STRENGTH,
|
| 108 |
)
|
| 109 |
from ltx_core.loader.single_gpu_model_builder import enable_only_lora
|
| 110 |
+
from ltx_core.model.audio_vae import decode_audio
|
| 111 |
+
from ltx_core.model.audio_vae import encode_audio
|
| 112 |
from PIL import Image
|
| 113 |
|
| 114 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
| 116 |
# Install with: pip install git+https://github.com/Lightricks/LTX-2.git
|
| 117 |
from ltx_pipelines.utils import ModelLedger
|
| 118 |
from ltx_pipelines.utils.helpers import generate_enhanced_prompt
|
| 119 |
+
import imageio
|
| 120 |
+
import cv2
|
| 121 |
+
from controlnet_aux import CannyDetector
|
| 122 |
+
from easy_dwpose import DWposeDetector
|
| 123 |
+
|
| 124 |
|
| 125 |
# HuggingFace Hub defaults
|
| 126 |
DEFAULT_REPO_ID = "Lightricks/LTX-2"
|
|
|
|
| 165 |
local_files_only=False
|
| 166 |
)
|
| 167 |
|
| 168 |
+
canny_processor = CannyDetector()
|
| 169 |
+
|
| 170 |
|
| 171 |
# Load text encoder once and keep it in memory
|
| 172 |
text_encoder = model_ledger.text_encoder()
|
|
|
|
| 175 |
print("Text encoder loaded and ready!")
|
| 176 |
print("=" * 80)
|
| 177 |
|
| 178 |
+
def on_lora_change(selected: str):
    """Build the UI updates for a LoRA selection change.

    Returns a 3-tuple: the selection itself, then two ``gr.update`` payloads
    with complementary visibility. The first is visible when the selection is
    NOT one of "Pose" / "Canny" / "Detailer", the second when it is
    (presumably the image input and the video input - confirm at the call site).
    """
    video_required = selected in {"Pose", "Canny", "Detailer"}

    # Both payloads reset the component value to None: the conditional
    # expressions this replaces evaluated to None on either branch.
    when_no_video = gr.update(visible=not video_required, value=None)
    when_video = gr.update(visible=video_required, value=None)
    return selected, when_no_video, when_video
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def process_video_for_pose(frames, width: int, height: int):
    """Render a pose image for every frame with ``DWposeDetector``.

    Args:
        frames: Sequence of frames; each is cast to uint8 and wrapped with
            ``Image.fromarray``.
        width: Width every pose image is resized to.
        height: Height every pose image is resized to.

    Returns:
        A list of float32 arrays of shape ``(height, width, 3)`` scaled to
        [0, 1]; an empty list when ``frames`` is empty.
    """
    # Bail out BEFORE building the detector: constructing it for an empty clip
    # is wasted work.
    if not frames:
        return []

    pose_processor = DWposeDetector("cuda")

    pose_frames = []
    for frame in frames:
        # imageio frame -> PIL
        pil = Image.fromarray(frame.astype(np.uint8)).convert("RGB")

        # Width/height are deliberately not passed to the detector; the
        # result is resized to the conditioning size below instead.
        pose_img = pose_processor(pil)

        # Some versions might return a numpy array instead of a PIL image.
        if not isinstance(pose_img, Image.Image):
            pose_img = Image.fromarray(pose_img.astype(np.uint8))

        pose_img = pose_img.convert("RGB").resize((width, height), Image.BILINEAR)
        pose_frames.append(np.array(pose_img).astype(np.float32) / 255.0)

    return pose_frames
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def preprocess_video_to_pose_mp4(video_path: str, width: int, height: int, fps: float):
    """Read a video, convert every frame to a pose image, and write a temp .mp4.

    Returns the value of ``write_video_mp4`` for the freshly created temp file.
    The file is created with ``delete=False``, so removing it is the caller's job.
    """
    pose_frames = process_video_for_pose(
        load_video_frames(video_path), width=width, height=height
    )

    # Only the unique name is needed; leaving the ``with`` block closes the
    # handle while delete=False keeps the file on disk for the writer.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
        target = handle.name

    return write_video_mp4(pose_frames, fps=fps, out_path=target)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def load_video_frames(video_path: str):
    """Decode every frame of ``video_path`` into an in-memory list.

    Frames are whatever the imageio reader yields (presumably (H, W, 3) uint8
    numpy arrays - confirm for the backend in use).
    """
    with imageio.get_reader(video_path) as reader:
        # Plain iteration on purpose: list(reader) would also query the
        # reader's length, which the original loop never did.
        return [frame for frame in reader]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def process_video_for_canny(frames, width: int, height: int,
                            low_threshold=50, high_threshold=200):
    """Run the module-level ``canny_processor`` over every frame.

    The detection resolution is the longer side of the first frame and the
    output resolution is the longer side of the requested ``width``/``height``.

    Returns:
        A list with one detector output per frame (requested as
        ``output_type="np"``); an empty list when ``frames`` is empty.
        NOTE(review): the original note says the arrays are float [0..1] -
        confirm the dtype/range against the installed controlnet_aux.
    """
    if not frames:
        return []

    first = frames[0]
    detector_kwargs = dict(
        low_threshold=low_threshold,
        high_threshold=high_threshold,
        detect_resolution=max(first.shape[0], first.shape[1]),
        image_resolution=max(width, height),
        output_type="np",
    )
    return [canny_processor(frame, **detector_kwargs) for frame in frames]
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def write_video_mp4(frames_float_01, fps: float, out_path: str):
    """Write frames to ``out_path`` as a uint8 video and return ``out_path``.

    Args:
        frames_float_01: Iterable of frames. Non-uint8 frames are expected in
            [0, 1]; they are clipped to that range and scaled by 255. Frames
            that are already uint8 are written unchanged.
        fps: Frame rate handed to the imageio writer.
        out_path: Destination file path.
    """
    def _as_uint8(frame):
        arr = np.asarray(frame)
        if arr.dtype == np.uint8:
            # Already 0..255: multiplying by 255 again would overflow and wrap.
            return arr
        # Clip first so values slightly outside [0, 1] saturate instead of
        # wrapping around in the uint8 cast.
        return (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)

    # Convert everything before opening the writer so a bad frame does not
    # leave a half-written file behind.
    frames_uint8 = [_as_uint8(f) for f in frames_float_01]

    # PyAV backend doesn't support `quality=...`
    with imageio.get_writer(out_path, fps=fps, macro_block_size=1) as writer:
        for fr in frames_uint8:
            writer.append_data(fr)
    return out_path
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def preprocess_video_to_canny_mp4(video_path: str, width: int, height: int, fps: float):
    """Read a video, run canny detection on every frame, and write a temp .mp4.

    Returns the value of ``write_video_mp4`` for the freshly created temp file.
    The file is created with ``delete=False``, so removing it is the caller's job.
    """
    edge_frames = process_video_for_canny(
        load_video_frames(video_path), width=width, height=height
    )

    # Only the unique name is needed; leaving the ``with`` block closes the
    # handle while delete=False keeps the file on disk for the writer.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
        target = handle.name

    return write_video_mp4(edge_frames, fps=fps, out_path=target)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
def encode_text_simple(text_encoder, prompt: str):
|
| 282 |
"""Simple text encoding without using pipeline_utils."""
|
| 283 |
v_context, a_context, _ = text_encoder(prompt)
|
|
|
|
| 450 |
("Slide Right", 5),
|
| 451 |
("Slide Down", 6),
|
| 452 |
("Slide Up", 7),
|
| 453 |
+
|
| 454 |
]
|
| 455 |
|
| 456 |
# Initialize pipeline WITHOUT text encoder (gemma_root=None)
|
|
|
|
| 769 |
|
| 770 |
w, h = apply_resolution(resolution)
|
| 771 |
|
| 772 |
+
output_video = generate_video(
|
| 773 |
input_image,
|
| 774 |
prompt,
|
| 775 |
10, # duration seconds
|
|
|
|
| 778 |
True, # randomize_seed
|
| 779 |
h, # height
|
| 780 |
w, # width
|
| 781 |
+
camera_lora,
|
| 782 |
+
None,
|
| 783 |
progress
|
| 784 |
)
|
| 785 |
|
| 786 |
return output_video
|
| 787 |
+
|
|
|
|
| 788 |
def generate_video_example_t2v(prompt, camera_lora, resolution, progress=gr.Progress(track_tqdm=True)):
|
| 789 |
|
| 790 |
w, h = apply_resolution(resolution)
|
| 791 |
|
| 792 |
+
output_video = generate_video(
|
| 793 |
None,
|
| 794 |
prompt,
|
| 795 |
15, # duration seconds
|
|
|
|
| 798 |
True, # randomize_seed
|
| 799 |
h, # height
|
| 800 |
w, # width
|
| 801 |
+
camera_lora,
|
| 802 |
+
None,
|
| 803 |
+
progress
|
| 804 |
+
)
|
| 805 |
+
return output_video
|
| 806 |
+
|
| 807 |
+
def generate_video_example_s2v(input_image, prompt, camera_lora, resolution, audio_path, progress=gr.Progress(track_tqdm=True)):
    """Example handler: image + prompt + audio file -> generated video.

    Resolves ``resolution`` through ``apply_resolution`` and forwards fixed
    example settings to ``generate_video``, returning its result.
    """
    width, height = apply_resolution(resolution)

    # Passed positionally, so the order must match generate_video's signature.
    example_args = (
        input_image,
        prompt,
        10,    # duration seconds
        True,  # enhance_prompt
        42,    # seed
        True,  # randomize_seed
        height,
        width,
        camera_lora,
        audio_path,
        progress,
    )
    return generate_video(*example_args)
|
| 826 |
+
|
| 827 |
def get_duration(
|
| 828 |
input_image,
|
| 829 |
prompt,
|
|
|
|
| 834 |
height,
|
| 835 |
width,
|
| 836 |
camera_lora,
|
| 837 |
+
audio_path,
|
| 838 |
progress
|
| 839 |
):
|
| 840 |
+
extra_time = 0
|
| 841 |
+
|
| 842 |
+
if audio_path is not None or input_image is None:
|
| 843 |
+
extra_time += 10
|
| 844 |
+
|
| 845 |
if duration <= 5:
|
| 846 |
+
return 80 + extra_time
|
| 847 |
elif duration <= 10:
|
| 848 |
+
return 120 + extra_time
|
| 849 |
else:
|
| 850 |
+
return 180 + extra_time
|
| 851 |
|
| 852 |
@spaces.GPU(duration=get_duration)
|
| 853 |
def generate_video(
|
|
|
|
| 860 |
height: int = DEFAULT_1_STAGE_HEIGHT,
|
| 861 |
width: int = DEFAULT_1_STAGE_WIDTH,
|
| 862 |
camera_lora: str = "No LoRA",
|
| 863 |
+
audio_path = None,
|
| 864 |
progress=gr.Progress(track_tqdm=True),
|
| 865 |
):
|
| 866 |
"""
|
|
|
|
| 922 |
audio_context = embeddings["audio_context"].to("cuda", non_blocking=True)
|
| 923 |
print("✓ Embeddings loaded successfully")
|
| 924 |
|
| 925 |
+
|
| 926 |
# free prompt enhancer / encoder temps ASAP
|
| 927 |
del embeddings, final_prompt, status
|
| 928 |
torch.cuda.empty_cache()
|
| 929 |
|
| 930 |
+
# ✅ if user provided audio, use a neutral audio_context
|
| 931 |
+
n_audio_context = None
|
| 932 |
+
|
| 933 |
+
if audio_path is not None:
|
| 934 |
+
with torch.inference_mode():
|
| 935 |
+
_, n_audio_context = encode_text_simple(text_encoder, "") # returns tensors on GPU already
|
| 936 |
+
del audio_context
|
| 937 |
+
audio_context = n_audio_context
|
| 938 |
+
|
| 939 |
+
torch.cuda.empty_cache()
|
| 940 |
|
| 941 |
# Map dropdown name -> adapter index
|
| 942 |
name_to_idx = {name: idx for name, idx in RUNTIME_LORA_CHOICES}
|
|
|
|
| 945 |
enable_only_lora(pipeline._transformer, selected_idx)
|
| 946 |
torch.cuda.empty_cache()
|
| 947 |
|
| 948 |
+
# True video duration in seconds based on your rounding
|
| 949 |
+
video_seconds = (num_frames - 1) / frame_rate
|
| 950 |
+
|
| 951 |
+
if audio_path is not None:
|
| 952 |
+
input_waveform, input_waveform_sample_rate = match_audio_to_duration(
|
| 953 |
+
audio_path=audio_path,
|
| 954 |
+
target_seconds=video_seconds,
|
| 955 |
+
target_sr=48000, # pick what your model expects; 48k is common for AV models
|
| 956 |
+
to_mono=True, # set False if your model wants stereo
|
| 957 |
+
pad_mode="silence", # or "repeat" if you prefer looping over silence
|
| 958 |
+
device="cuda",
|
| 959 |
+
)
|
| 960 |
+
else:
|
| 961 |
+
input_waveform = None
|
| 962 |
+
input_waveform_sample_rate = None
|
| 963 |
+
|
| 964 |
# Run inference - progress automatically tracks tqdm from pipeline
|
| 965 |
with torch.inference_mode():
|
| 966 |
pipeline(
|
|
|
|
| 975 |
tiling_config=TilingConfig.default(),
|
| 976 |
video_context=video_context,
|
| 977 |
audio_context=audio_context,
|
| 978 |
+
input_waveform=input_waveform,
|
| 979 |
+
input_waveform_sample_rate=input_waveform_sample_rate,
|
| 980 |
)
|
| 981 |
del video_context, audio_context
|
| 982 |
torch.cuda.empty_cache()
|
| 983 |
print("successful generation")
|
| 984 |
|
| 985 |
+
return str(output_path)
|
| 986 |
|
| 987 |
|
| 988 |
|
|
|
|
| 1406 |
height=512
|
| 1407 |
)
|
| 1408 |
|
|
|
|
| 1409 |
prompt_ui = PromptBox(
|
| 1410 |
value="Make this image come alive with cinematic motion, smooth animation",
|
| 1411 |
elem_id="prompt_ui",
|
| 1412 |
)
|
| 1413 |
|
| 1414 |
+
audio_input = gr.Audio(label="Audio (Optional)", type="filepath")
|
| 1415 |
+
|
| 1416 |
prompt = gr.Textbox(
|
| 1417 |
label="Prompt",
|
| 1418 |
value="Make this image come alive with cinematic motion, smooth animation",
|
|
|
|
| 1549 |
height,
|
| 1550 |
width,
|
| 1551 |
camera_lora,
|
| 1552 |
+
audio_input
|
| 1553 |
],
|
| 1554 |
+
outputs=[output_video]
|
| 1555 |
)
|
| 1556 |
|
| 1557 |
|
| 1558 |
+
|
| 1559 |
timestep_prompt = """Style: Realistic live-action, cinematic, shallow depth of field, 24 fps, natural and dramatic lighting
|
| 1560 |
|
| 1561 |
Environment: Interior of a space station module or realistic mock-up, metal panels, blinking lights, Earth visible through a large window
|
|
|
|
| 1580 |
|
| 1581 |
Music: subtle cinematic synth or ambient pad, futuristic and minimal, emphasizing awe and solitude"""
|
| 1582 |
|
| 1583 |
+
gr.Examples(
|
| 1584 |
+
examples=[
|
| 1585 |
+
|
| 1586 |
+
[
|
| 1587 |
+
"supergirl-2.png",
|
| 1588 |
+
"A fuzzy puppet superhero character resembling a female puppet with blonde hair and a blue superhero suit sleeping in bed and just waking up, she gradually gets up, rubbing her eyes and looking at her dog that just popped on the bed. the scene feels chaotic, comedic, and emotional with expressive puppet reactions, cinematic lighting, smooth camera motion, shallow depth of field, and high-quality puppet-style animation",
|
| 1589 |
+
"Static",
|
| 1590 |
+
"16:9",
|
| 1591 |
+
"supergirl.m4a"
|
| 1592 |
+
],
|
| 1593 |
+
|
| 1594 |
+
],
|
| 1595 |
+
fn=generate_video_example_s2v,
|
| 1596 |
+
inputs=[input_image, prompt_ui, camera_lora_ui, radioanimated_resolution, audio_input],
|
| 1597 |
+
outputs = [output_video],
|
| 1598 |
+
label="S2V Example",
|
| 1599 |
+
cache_examples=True,
|
| 1600 |
+
)
|
| 1601 |
|
| 1602 |
gr.Examples(
|
| 1603 |
examples=[
|
|
|
|
| 1669 |
)
|
| 1670 |
|
| 1671 |
|
|
|
|
| 1672 |
if __name__ == "__main__":
|
| 1673 |
demo.launch(ssr_mode=False, mcp_server=True, css=css)
|
packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""Audio VAE model components."""
|
| 2 |
|
| 3 |
-
from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
| 4 |
from ltx_core.model.audio_vae.model_configurator import (
|
| 5 |
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
|
| 6 |
AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
|
|
@@ -24,4 +24,5 @@ __all__ = [
|
|
| 24 |
"Vocoder",
|
| 25 |
"VocoderConfigurator",
|
| 26 |
"decode_audio",
|
|
|
|
| 27 |
]
|
|
|
|
| 1 |
"""Audio VAE model components."""
|
| 2 |
|
| 3 |
+
from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio, encode_audio
|
| 4 |
from ltx_core.model.audio_vae.model_configurator import (
|
| 5 |
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
|
| 6 |
AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
|
|
|
|
| 24 |
"Vocoder",
|
| 25 |
"VocoderConfigurator",
|
| 26 |
"decode_audio",
|
| 27 |
+
"encode_audio",
|
| 28 |
]
|
packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py
CHANGED
|
@@ -8,7 +8,7 @@ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
|
|
| 8 |
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
|
| 9 |
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 10 |
from ltx_core.model.audio_vae.downsample import build_downsampling_path
|
| 11 |
-
from ltx_core.model.audio_vae.ops import PerChannelStatistics
|
| 12 |
from ltx_core.model.audio_vae.resnet import ResnetBlock
|
| 13 |
from ltx_core.model.audio_vae.upsample import build_upsampling_path
|
| 14 |
from ltx_core.model.audio_vae.vocoder import Vocoder
|
|
@@ -464,6 +464,57 @@ class AudioDecoder(torch.nn.Module):
|
|
| 464 |
h = self.conv_out(h)
|
| 465 |
return torch.tanh(h) if self.tanh_out else h
|
| 466 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
|
| 468 |
def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> torch.Tensor:
|
| 469 |
"""
|
|
|
|
| 8 |
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
|
| 9 |
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 10 |
from ltx_core.model.audio_vae.downsample import build_downsampling_path
|
| 11 |
+
from ltx_core.model.audio_vae.ops import PerChannelStatistics, AudioProcessor
|
| 12 |
from ltx_core.model.audio_vae.resnet import ResnetBlock
|
| 13 |
from ltx_core.model.audio_vae.upsample import build_upsampling_path
|
| 14 |
from ltx_core.model.audio_vae.vocoder import Vocoder
|
|
|
|
| 464 |
h = self.conv_out(h)
|
| 465 |
return torch.tanh(h) if self.tanh_out else h
|
| 466 |
|
| 467 |
+
@torch.no_grad()
def encode_audio(
    waveform: torch.Tensor,
    waveform_sample_rate: int,
    *,
    audio_encoder: "AudioEncoder",
    audio_processor: "AudioProcessor",
    return_mean_only: bool = False,
) -> torch.Tensor:
    """
    Turn a waveform into an audio latent: waveform -> mel features -> encoder.

    Args:
        waveform: Audio tensor of rank 1, 2 or 3:
            - (T,)     is expanded to (1, 1, T)
            - (B, T)   is expanded to (B, 1, T)
            - (B, C, T) is used unchanged
        waveform_sample_rate: Sample rate of ``waveform``; forwarded to the processor.
        audio_encoder: Callable applied to the mel features.
        audio_processor: Object whose ``waveform_to_mel(waveform, sample_rate)``
            produces the encoder input.
        return_mean_only: When True and ``audio_encoder.double_z`` is truthy,
            only the first of two chunks along dim 1 is returned.

    Returns:
        The encoder output, or its first half along dim 1 as described above.

    Raises:
        ValueError: If ``waveform`` is not rank 1, 2 or 3.
    """
    rank = waveform.dim()
    if rank not in (1, 2, 3):
        raise ValueError(f"Unexpected waveform shape: {tuple(waveform.shape)}")

    # Bring the input to (B, C, T) by inserting the missing leading axes.
    if rank == 1:
        waveform = waveform.unsqueeze(0).unsqueeze(0)
    elif rank == 2:
        waveform = waveform.unsqueeze(1)

    mel = audio_processor.waveform_to_mel(waveform.float(), waveform_sample_rate)
    latent = audio_encoder(mel)

    keep_mean_half = return_mean_only and getattr(audio_encoder, "double_z", False)
    if not keep_mean_half:
        return latent

    # Encoders flagged double_z stack two halves along dim 1; keep the first.
    return torch.chunk(latent, 2, dim=1)[0]
|
| 517 |
+
|
| 518 |
|
| 519 |
def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> torch.Tensor:
|
| 520 |
"""
|
packages/ltx-pipelines/src/ltx_pipelines/distilled.py
CHANGED
|
@@ -8,7 +8,7 @@ from ltx_core.components.diffusion_steps import EulerDiffusionStep
|
|
| 8 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 9 |
from ltx_core.components.noisers import GaussianNoiser
|
| 10 |
from ltx_core.components.protocols import DiffusionStepProtocol
|
| 11 |
-
from ltx_core.conditioning import ConditioningItem, VideoConditionByKeyframeIndex
|
| 12 |
from ltx_core.loader import LoraPathStrengthAndSDOps
|
| 13 |
from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
|
| 14 |
from ltx_core.model.upsampler import upsample_video
|
|
@@ -16,6 +16,7 @@ from ltx_core.model.video_vae import TilingConfig, VideoEncoder, get_video_chunk
|
|
| 16 |
from ltx_core.model.video_vae import decode_video as vae_decode_video
|
| 17 |
from ltx_core.text_encoders.gemma import encode_text
|
| 18 |
from ltx_core.types import LatentState, VideoPixelShape
|
|
|
|
| 19 |
from ltx_pipelines import utils
|
| 20 |
from ltx_pipelines.utils import ModelLedger
|
| 21 |
from ltx_pipelines.utils.args import default_2_stage_distilled_arg_parser
|
|
@@ -38,6 +39,42 @@ from ltx_pipelines.utils.helpers import (
|
|
| 38 |
from ltx_pipelines.utils.media_io import encode_video, load_video_conditioning
|
| 39 |
from ltx_pipelines.utils.types import PipelineComponents
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
device = get_device()
|
| 42 |
|
| 43 |
|
|
@@ -74,6 +111,151 @@ class DistilledPipeline:
|
|
| 74 |
# Cached models to avoid reloading
|
| 75 |
self._video_encoder = None
|
| 76 |
self._transformer = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
@torch.inference_mode()
|
| 79 |
def __call__(
|
|
@@ -92,12 +274,27 @@ class DistilledPipeline:
|
|
| 92 |
tiling_config: TilingConfig | None = None,
|
| 93 |
video_context: torch.Tensor | None = None,
|
| 94 |
audio_context: torch.Tensor | None = None,
|
|
|
|
|
|
|
|
|
|
| 95 |
) -> None:
|
| 96 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 97 |
noiser = GaussianNoiser(generator=generator)
|
| 98 |
stepper = EulerDiffusionStep()
|
| 99 |
dtype = torch.bfloat16
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
# Use pre-computed embeddings if provided, otherwise encode text
|
| 102 |
if video_context is None or audio_context is None:
|
| 103 |
text_encoder = self.model_ledger.text_encoder()
|
|
@@ -153,6 +350,7 @@ class DistilledPipeline:
|
|
| 153 |
video_state, audio_state = denoise_audio_video(
|
| 154 |
output_shape=stage_1_output_shape,
|
| 155 |
conditionings=stage_1_conditionings,
|
|
|
|
| 156 |
noiser=noiser,
|
| 157 |
sigmas=stage_1_sigmas,
|
| 158 |
stepper=stepper,
|
|
@@ -197,6 +395,7 @@ class DistilledPipeline:
|
|
| 197 |
video_state, audio_state = denoise_audio_video(
|
| 198 |
output_shape=stage_2_output_shape,
|
| 199 |
conditionings=stage_2_conditionings,
|
|
|
|
| 200 |
noiser=noiser,
|
| 201 |
sigmas=stage_2_sigmas,
|
| 202 |
stepper=stepper,
|
|
@@ -227,8 +426,6 @@ class DistilledPipeline:
|
|
| 227 |
)
|
| 228 |
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
def _create_conditionings(
|
| 233 |
self,
|
| 234 |
images: list[tuple[str, int, float]],
|
|
@@ -275,4 +472,4 @@ class DistilledPipeline:
|
|
| 275 |
)
|
| 276 |
)
|
| 277 |
|
| 278 |
-
return conditionings
|
|
|
|
| 8 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 9 |
from ltx_core.components.noisers import GaussianNoiser
|
| 10 |
from ltx_core.components.protocols import DiffusionStepProtocol
|
| 11 |
+
from ltx_core.conditioning import ConditioningItem, VideoConditionByKeyframeIndex, ConditioningError
|
| 12 |
from ltx_core.loader import LoraPathStrengthAndSDOps
|
| 13 |
from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
|
| 14 |
from ltx_core.model.upsampler import upsample_video
|
|
|
|
| 16 |
from ltx_core.model.video_vae import decode_video as vae_decode_video
|
| 17 |
from ltx_core.text_encoders.gemma import encode_text
|
| 18 |
from ltx_core.types import LatentState, VideoPixelShape
|
| 19 |
+
from ltx_core.tools import LatentTools
|
| 20 |
from ltx_pipelines import utils
|
| 21 |
from ltx_pipelines.utils import ModelLedger
|
| 22 |
from ltx_pipelines.utils.args import default_2_stage_distilled_arg_parser
|
|
|
|
| 39 |
from ltx_pipelines.utils.media_io import encode_video, load_video_conditioning
|
| 40 |
from ltx_pipelines.utils.types import PipelineComponents
|
| 41 |
|
| 42 |
+
import torchaudio
|
| 43 |
+
from ltx_core.model.audio_vae import AudioProcessor
|
| 44 |
+
from ltx_core.types import AudioLatentShape, VideoPixelShape
|
| 45 |
+
|
| 46 |
+
class AudioConditionByLatent(ConditioningItem):
    """Condition audio generation on a complete, pre-encoded audio latent.

    When applied, the patchified latent overwrites the leading tokens of the
    audio latent state (both ``latent`` and ``clean_latent``), and the matching
    ``denoise_mask`` entries are set to ``1.0 - strength``.

    Args:
        latent: Audio latent whose shape must equal the target audio latent
            shape reported by ``latent_tools.target_shape.to_torch_shape()``.
        strength: Conditioning strength; stored as-is and converted to a
            denoise-mask value of ``1.0 - strength`` in :meth:`apply_to`.
    """

    def __init__(self, latent: torch.Tensor, strength: float):
        self.latent = latent
        self.strength = strength

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """Return a copy of ``latent_state`` with the conditioning latent injected.

        Raises:
            ConditioningError: If the target shape is not an ``AudioLatentShape``
                or the conditioning latent's shape differs from the target shape.
        """
        if not isinstance(latent_tools.target_shape, AudioLatentShape):
            raise ConditioningError("Audio conditioning requires an audio latent target shape.")

        # Compare whole shapes instead of unpacking four dimensions, so that a
        # latent of the wrong rank is reported as a ConditioningError too.
        cond_shape = tuple(self.latent.shape)
        expected_shape = tuple(latent_tools.target_shape.to_torch_shape())
        if cond_shape != expected_shape:
            # Report the offending conditioning shape next to the expected one.
            raise ConditioningError(
                f"Can't apply audio conditioning item with latent shape {cond_shape}, expected "
                f"shape is {expected_shape}."
            )

        tokens = latent_tools.patchifier.patchify(self.latent)
        num_tokens = tokens.shape[1]

        # Work on a clone so the caller's state object is left untouched.
        latent_state = latent_state.clone()
        latent_state.latent[:, :num_tokens] = tokens
        latent_state.clean_latent[:, :num_tokens] = tokens
        latent_state.denoise_mask[:, :num_tokens] = 1.0 - self.strength

        return latent_state
|
| 77 |
+
|
| 78 |
device = get_device()
|
| 79 |
|
| 80 |
|
|
|
|
| 111 |
# Cached models to avoid reloading
|
| 112 |
self._video_encoder = None
|
| 113 |
self._transformer = None
|
| 114 |
+
|
| 115 |
+
def _build_audio_conditionings_from_waveform(
    self,
    input_waveform: torch.Tensor,
    input_sample_rate: int,
    num_frames: int,
    fps: float,
    strength: float,
) -> list[AudioConditionByLatent] | None:
    """Encode a waveform into an audio latent and wrap it as a conditioning item.

    The waveform is brought to ``(B, C, T)``, matched to the encoder's channel
    count, resampled to the encoder's sample rate, converted to a mel
    spectrogram, encoded, and finally padded or trimmed along the latent frame
    axis so it matches the audio latent length derived from ``num_frames`` and
    ``fps``.

    Args:
        input_waveform: Tensor shaped ``(T,)``, ``(C, T)`` or ``(B, C, T)``.
        input_sample_rate: Sample rate of ``input_waveform``.
        num_frames: Number of video frames the audio has to cover.
        fps: Video frame rate.
        strength: Conditioning strength forwarded to ``AudioConditionByLatent``.

    Returns:
        A single-element list holding the conditioning item, or ``None`` when
        ``strength`` is not positive.

    Raises:
        ValueError: If ``input_waveform`` is not 1D, 2D or 3D.
    """
    strength = float(strength)
    if strength <= 0.0:
        return None

    # Expect waveform as (T,), (C, T) or (B, C, T); convert to (B, C, T).
    waveform = input_waveform
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0).unsqueeze(0)
    elif waveform.ndim == 2:
        waveform = waveform.unsqueeze(0)
    elif waveform.ndim != 3:
        raise ValueError(f"input_waveform must be 1D/2D/3D, got shape {tuple(waveform.shape)}")

    # Get the audio encoder and its config (assumes the ledger exposes it).
    # The encoder is rebuilt on every call; it could be cached like the video
    # encoder / transformer if that becomes a bottleneck.
    audio_encoder = self.model_ledger.audio_encoder()
    target_sr = int(audio_encoder.sample_rate)
    target_channels = int(getattr(audio_encoder, "in_channels", waveform.shape[1]))
    mel_bins = int(audio_encoder.mel_bins)
    mel_hop = int(audio_encoder.mel_hop_length)
    n_fft = int(audio_encoder.n_fft)

    # Match channels.
    if waveform.shape[1] != target_channels:
        if waveform.shape[1] == 1 and target_channels > 1:
            waveform = waveform.repeat(1, target_channels, 1)
        elif target_channels == 1:
            waveform = waveform.mean(dim=1, keepdim=True)
        else:
            waveform = waveform[:, :target_channels, :]
            if waveform.shape[1] < target_channels:
                pad_ch = target_channels - waveform.shape[1]
                # new_zeros keeps the padding on the waveform's device and dtype;
                # a bare torch.zeros lands on the default device and makes
                # torch.cat fail when the waveform is on an accelerator.
                pad = waveform.new_zeros((waveform.shape[0], pad_ch, waveform.shape[2]))
                waveform = torch.cat([waveform, pad], dim=1)

    # Resample if needed (CPU float32 is safest for torchaudio).
    waveform = waveform.to(device="cpu", dtype=torch.float32)
    if int(input_sample_rate) != target_sr:
        waveform = torchaudio.functional.resample(waveform, int(input_sample_rate), target_sr)

    # Waveform -> mel.
    audio_processor = AudioProcessor(
        sample_rate=target_sr,
        mel_bins=mel_bins,
        mel_hop_length=mel_hop,
        n_fft=n_fft,
    ).to(waveform.device)

    mel = audio_processor.waveform_to_mel(waveform, target_sr)

    # Mel -> latent (run the encoder on its own device/dtype).
    # NOTE(review): the fallbacks read self.device / self.dtype - confirm the
    # pipeline defines both attributes.
    audio_params = next(audio_encoder.parameters(), None)
    enc_device = audio_params.device if audio_params is not None else self.device
    enc_dtype = audio_params.dtype if audio_params is not None else self.dtype

    mel = mel.to(device=enc_device, dtype=enc_dtype)
    with torch.inference_mode():
        # NOTE(review): the raw encoder output is used as the latent; if the
        # encoder emits mean and log-variance stacked on the channel axis
        # ("double_z"), it may need to be split first - confirm.
        audio_latent = audio_encoder(mel)

    # Pad/trim the latent to match the target video duration.
    audio_downsample = getattr(getattr(audio_encoder, "patchifier", None), "audio_latent_downsample_factor", 4)
    target_shape = AudioLatentShape.from_video_pixel_shape(
        VideoPixelShape(batch=audio_latent.shape[0], frames=int(num_frames), width=1, height=1, fps=float(fps)),
        channels=audio_latent.shape[1],
        mel_bins=audio_latent.shape[3],
        sample_rate=target_sr,
        hop_length=mel_hop,
        audio_latent_downsample_factor=audio_downsample,
    )
    target_frames = int(target_shape.frames)

    if audio_latent.shape[2] < target_frames:
        pad_frames = target_frames - audio_latent.shape[2]
        pad = torch.zeros(
            (audio_latent.shape[0], audio_latent.shape[1], pad_frames, audio_latent.shape[3]),
            device=audio_latent.device,
            dtype=audio_latent.dtype,
        )
        audio_latent = torch.cat([audio_latent, pad], dim=2)
    elif audio_latent.shape[2] > target_frames:
        audio_latent = audio_latent[:, :, :target_frames, :]

    # Move the latent to the pipeline device/dtype for the conditioning object.
    audio_latent = audio_latent.to(device=self.device, dtype=self.dtype)

    return [AudioConditionByLatent(audio_latent, strength)]
|
| 210 |
+
|
| 211 |
+
def _prepare_output_waveform(
|
| 212 |
+
self,
|
| 213 |
+
input_waveform: torch.Tensor,
|
| 214 |
+
input_sample_rate: int,
|
| 215 |
+
target_sample_rate: int,
|
| 216 |
+
num_frames: int,
|
| 217 |
+
fps: float,
|
| 218 |
+
) -> torch.Tensor:
|
| 219 |
+
"""
|
| 220 |
+
Returns waveform on CPU, float32, resampled to target_sample_rate and
|
| 221 |
+
trimmed/padded to match video duration.
|
| 222 |
+
Output shape: (T,) for mono or (C, T) for multi-channel.
|
| 223 |
+
"""
|
| 224 |
+
wav = input_waveform
|
| 225 |
+
|
| 226 |
+
# Accept (T,), (C,T), (B,C,T)
|
| 227 |
+
if wav.ndim == 3:
|
| 228 |
+
wav = wav[0]
|
| 229 |
+
elif wav.ndim == 2:
|
| 230 |
+
pass
|
| 231 |
+
elif wav.ndim == 1:
|
| 232 |
+
wav = wav.unsqueeze(0)
|
| 233 |
+
else:
|
| 234 |
+
raise ValueError(f"input_waveform must be 1D/2D/3D, got {tuple(wav.shape)}")
|
| 235 |
+
|
| 236 |
+
# Now wav is (C, T)
|
| 237 |
+
wav = wav.detach().to("cpu", dtype=torch.float32)
|
| 238 |
+
|
| 239 |
+
# Resample if needed
|
| 240 |
+
if int(input_sample_rate) != int(target_sample_rate):
|
| 241 |
+
wav = torchaudio.functional.resample(wav, int(input_sample_rate), int(target_sample_rate))
|
| 242 |
+
|
| 243 |
+
# Match video duration
|
| 244 |
+
duration_sec = float(num_frames) / float(fps)
|
| 245 |
+
target_len = int(round(duration_sec * float(target_sample_rate)))
|
| 246 |
+
|
| 247 |
+
cur_len = int(wav.shape[-1])
|
| 248 |
+
if cur_len > target_len:
|
| 249 |
+
wav = wav[..., :target_len]
|
| 250 |
+
elif cur_len < target_len:
|
| 251 |
+
pad = target_len - cur_len
|
| 252 |
+
wav = torch.nn.functional.pad(wav, (0, pad))
|
| 253 |
+
|
| 254 |
+
# If mono, return (T,) for convenience
|
| 255 |
+
if wav.shape[0] == 1:
|
| 256 |
+
return wav[0]
|
| 257 |
+
return wav
|
| 258 |
+
|
| 259 |
|
| 260 |
@torch.inference_mode()
|
| 261 |
def __call__(
|
|
|
|
| 274 |
tiling_config: TilingConfig | None = None,
|
| 275 |
video_context: torch.Tensor | None = None,
|
| 276 |
audio_context: torch.Tensor | None = None,
|
| 277 |
+
input_waveform: torch.Tensor | None = None,
|
| 278 |
+
input_waveform_sample_rate: int | None = None,
|
| 279 |
+
audio_strength: float = 1.0, # or audio_scale, your naming
|
| 280 |
) -> None:
|
| 281 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 282 |
noiser = GaussianNoiser(generator=generator)
|
| 283 |
stepper = EulerDiffusionStep()
|
| 284 |
dtype = torch.bfloat16
|
| 285 |
|
| 286 |
+
audio_conditionings = None
|
| 287 |
+
if input_waveform is not None:
|
| 288 |
+
if input_waveform_sample_rate is None:
|
| 289 |
+
raise ValueError("input_waveform_sample_rate must be provided when input_waveform is set.")
|
| 290 |
+
audio_conditionings = self._build_audio_conditionings_from_waveform(
|
| 291 |
+
input_waveform=input_waveform,
|
| 292 |
+
input_sample_rate=int(input_waveform_sample_rate),
|
| 293 |
+
num_frames=num_frames,
|
| 294 |
+
fps=frame_rate,
|
| 295 |
+
strength=audio_strength,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
# Use pre-computed embeddings if provided, otherwise encode text
|
| 299 |
if video_context is None or audio_context is None:
|
| 300 |
text_encoder = self.model_ledger.text_encoder()
|
|
|
|
| 350 |
video_state, audio_state = denoise_audio_video(
|
| 351 |
output_shape=stage_1_output_shape,
|
| 352 |
conditionings=stage_1_conditionings,
|
| 353 |
+
audio_conditionings=audio_conditionings,
|
| 354 |
noiser=noiser,
|
| 355 |
sigmas=stage_1_sigmas,
|
| 356 |
stepper=stepper,
|
|
|
|
| 395 |
video_state, audio_state = denoise_audio_video(
|
| 396 |
output_shape=stage_2_output_shape,
|
| 397 |
conditionings=stage_2_conditionings,
|
| 398 |
+
audio_conditionings=audio_conditionings,
|
| 399 |
noiser=noiser,
|
| 400 |
sigmas=stage_2_sigmas,
|
| 401 |
stepper=stepper,
|
|
|
|
| 426 |
)
|
| 427 |
|
| 428 |
|
|
|
|
|
|
|
| 429 |
def _create_conditionings(
|
| 430 |
self,
|
| 431 |
images: list[tuple[str, int, float]],
|
|
|
|
| 472 |
)
|
| 473 |
)
|
| 474 |
|
| 475 |
+
return conditionings
|
packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py
CHANGED
|
@@ -245,6 +245,7 @@ def noise_audio_state(
|
|
| 245 |
device: torch.device,
|
| 246 |
noise_scale: float = 1.0,
|
| 247 |
initial_latent: torch.Tensor | None = None,
|
|
|
|
| 248 |
) -> tuple[LatentState, AudioLatentTools]:
|
| 249 |
"""Initialize and noise an audio latent state for the diffusion pipeline.
|
| 250 |
Creates an audio latent state from the output shape, applies conditionings,
|
|
@@ -262,6 +263,7 @@ def noise_audio_state(
|
|
| 262 |
device=device,
|
| 263 |
noise_scale=noise_scale,
|
| 264 |
initial_latent=initial_latent,
|
|
|
|
| 265 |
)
|
| 266 |
|
| 267 |
return audio_state, audio_tools
|
|
@@ -275,18 +277,35 @@ def create_noised_state(
|
|
| 275 |
device: torch.device,
|
| 276 |
noise_scale: float = 1.0,
|
| 277 |
initial_latent: torch.Tensor | None = None,
|
|
|
|
| 278 |
) -> LatentState:
|
| 279 |
-
"""Create a noised latent state from empty state, conditionings, and noiser.
|
| 280 |
-
Creates an empty latent state, applies conditionings, and then adds noise
|
| 281 |
-
using the provided noiser. Returns the final noised state ready for diffusion.
|
| 282 |
-
"""
|
| 283 |
state = tools.create_initial_state(device, dtype, initial_latent)
|
| 284 |
state = state_with_conditionings(state, conditionings, tools)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
state = noiser(state, noise_scale)
|
| 286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
return state
|
| 288 |
|
| 289 |
|
|
|
|
| 290 |
def state_with_conditionings(
|
| 291 |
latent_state: LatentState, conditioning_items: list[ConditioningItem], latent_tools: LatentTools
|
| 292 |
) -> LatentState:
|
|
@@ -302,7 +321,9 @@ def state_with_conditionings(
|
|
| 302 |
|
| 303 |
def post_process_latent(denoised: torch.Tensor, denoise_mask: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
|
| 304 |
"""Blend denoised output with clean state based on mask."""
|
| 305 |
-
|
|
|
|
|
|
|
| 306 |
|
| 307 |
|
| 308 |
def modality_from_latent_state(
|
|
@@ -386,10 +407,12 @@ def denoise_audio_video( # noqa: PLR0913
|
|
| 386 |
components: PipelineComponents,
|
| 387 |
dtype: torch.dtype,
|
| 388 |
device: torch.device,
|
|
|
|
| 389 |
noise_scale: float = 1.0,
|
| 390 |
initial_video_latent: torch.Tensor | None = None,
|
| 391 |
initial_audio_latent: torch.Tensor | None = None,
|
| 392 |
-
|
|
|
|
| 393 |
video_state, video_tools = noise_video_state(
|
| 394 |
output_shape=output_shape,
|
| 395 |
noiser=noiser,
|
|
@@ -403,7 +426,7 @@ def denoise_audio_video( # noqa: PLR0913
|
|
| 403 |
audio_state, audio_tools = noise_audio_state(
|
| 404 |
output_shape=output_shape,
|
| 405 |
noiser=noiser,
|
| 406 |
-
conditionings=[],
|
| 407 |
components=components,
|
| 408 |
dtype=dtype,
|
| 409 |
device=device,
|
|
@@ -411,13 +434,22 @@ def denoise_audio_video( # noqa: PLR0913
|
|
| 411 |
initial_latent=initial_audio_latent,
|
| 412 |
)
|
| 413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
video_state, audio_state = denoising_loop_fn(
|
| 415 |
sigmas,
|
| 416 |
video_state,
|
| 417 |
audio_state,
|
| 418 |
stepper,
|
|
|
|
| 419 |
)
|
| 420 |
|
|
|
|
|
|
|
|
|
|
| 421 |
video_state = video_tools.clear_conditioning(video_state)
|
| 422 |
video_state = video_tools.unpatchify(video_state)
|
| 423 |
audio_state = audio_tools.clear_conditioning(audio_state)
|
|
@@ -426,6 +458,7 @@ def denoise_audio_video( # noqa: PLR0913
|
|
| 426 |
return video_state, audio_state
|
| 427 |
|
| 428 |
|
|
|
|
| 429 |
_UNICODE_REPLACEMENTS = str.maketrans("\u2018\u2019\u201c\u201d\u2014\u2013\u00a0\u2032\u2212", "''\"\"-- '-")
|
| 430 |
|
| 431 |
|
|
|
|
| 245 |
device: torch.device,
|
| 246 |
noise_scale: float = 1.0,
|
| 247 |
initial_latent: torch.Tensor | None = None,
|
| 248 |
+
denoise_mask: torch.Tensor | None = None
|
| 249 |
) -> tuple[LatentState, AudioLatentTools]:
|
| 250 |
"""Initialize and noise an audio latent state for the diffusion pipeline.
|
| 251 |
Creates an audio latent state from the output shape, applies conditionings,
|
|
|
|
| 263 |
device=device,
|
| 264 |
noise_scale=noise_scale,
|
| 265 |
initial_latent=initial_latent,
|
| 266 |
+
denoise_mask=denoise_mask,
|
| 267 |
)
|
| 268 |
|
| 269 |
return audio_state, audio_tools
|
|
|
|
| 277 |
device: torch.device,
|
| 278 |
noise_scale: float = 1.0,
|
| 279 |
initial_latent: torch.Tensor | None = None,
|
| 280 |
+
denoise_mask: torch.Tensor | None = None, # <-- add
|
| 281 |
) -> LatentState:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
state = tools.create_initial_state(device, dtype, initial_latent)
|
| 283 |
state = state_with_conditionings(state, conditionings, tools)
|
| 284 |
+
|
| 285 |
+
if denoise_mask is not None:
|
| 286 |
+
# Convert any tensor mask into a single scalar (solid mask behavior)
|
| 287 |
+
if isinstance(denoise_mask, torch.Tensor):
|
| 288 |
+
mask_value = float(denoise_mask.mean().item())
|
| 289 |
+
else:
|
| 290 |
+
mask_value = float(denoise_mask)
|
| 291 |
+
|
| 292 |
+
state = replace(
|
| 293 |
+
state,
|
| 294 |
+
clean_latent=state.latent.clone(),
|
| 295 |
+
denoise_mask=torch.full_like(state.denoise_mask, mask_value), # <- matches internal shape
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
state = noiser(state, noise_scale)
|
| 299 |
|
| 300 |
+
if denoise_mask is not None:
|
| 301 |
+
m = state.denoise_mask.to(dtype=state.latent.dtype, device=state.latent.device)
|
| 302 |
+
clean = state.clean_latent.to(dtype=state.latent.dtype, device=state.latent.device)
|
| 303 |
+
state = replace(state, latent=state.latent * m + clean * (1 - m))
|
| 304 |
+
|
| 305 |
return state
|
| 306 |
|
| 307 |
|
| 308 |
+
|
| 309 |
def state_with_conditionings(
|
| 310 |
latent_state: LatentState, conditioning_items: list[ConditioningItem], latent_tools: LatentTools
|
| 311 |
) -> LatentState:
|
|
|
|
| 321 |
|
| 322 |
def post_process_latent(denoised: torch.Tensor, denoise_mask: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
    """Mix the denoised output with the clean latent according to the mask.

    Both ``denoise_mask`` and ``clean`` are cast to the dtype of ``denoised``
    first. Where the mask is 1 the denoised value is kept, where it is 0 the
    clean value is kept, and values in between blend the two linearly.
    """
    target_dtype = denoised.dtype
    mask = denoise_mask.to(dtype=target_dtype)
    reference = clean.to(dtype=target_dtype)
    return denoised * mask + reference * (1 - mask)
|
| 327 |
|
| 328 |
|
| 329 |
def modality_from_latent_state(
|
|
|
|
| 407 |
components: PipelineComponents,
|
| 408 |
dtype: torch.dtype,
|
| 409 |
device: torch.device,
|
| 410 |
+
audio_conditionings: list[ConditioningItem] | None = None,
|
| 411 |
noise_scale: float = 1.0,
|
| 412 |
initial_video_latent: torch.Tensor | None = None,
|
| 413 |
initial_audio_latent: torch.Tensor | None = None,
|
| 414 |
+
# mask_context: MaskInjection | None = None,
|
| 415 |
+
) -> tuple[LatentState | None, LatentState | None]:
|
| 416 |
video_state, video_tools = noise_video_state(
|
| 417 |
output_shape=output_shape,
|
| 418 |
noiser=noiser,
|
|
|
|
| 426 |
audio_state, audio_tools = noise_audio_state(
|
| 427 |
output_shape=output_shape,
|
| 428 |
noiser=noiser,
|
| 429 |
+
conditionings=audio_conditionings or [],
|
| 430 |
components=components,
|
| 431 |
dtype=dtype,
|
| 432 |
device=device,
|
|
|
|
| 434 |
initial_latent=initial_audio_latent,
|
| 435 |
)
|
| 436 |
|
| 437 |
+
loop_kwargs = {}
|
| 438 |
+
# if "preview_tools" in inspect.signature(denoising_loop_fn).parameters:
|
| 439 |
+
# loop_kwargs["preview_tools"] = video_tools
|
| 440 |
+
# if "mask_context" in inspect.signature(denoising_loop_fn).parameters:
|
| 441 |
+
# loop_kwargs["mask_context"] = mask_context
|
| 442 |
video_state, audio_state = denoising_loop_fn(
|
| 443 |
sigmas,
|
| 444 |
video_state,
|
| 445 |
audio_state,
|
| 446 |
stepper,
|
| 447 |
+
**loop_kwargs,
|
| 448 |
)
|
| 449 |
|
| 450 |
+
if video_state is None or audio_state is None:
|
| 451 |
+
return None, None
|
| 452 |
+
|
| 453 |
video_state = video_tools.clear_conditioning(video_state)
|
| 454 |
video_state = video_tools.unpatchify(video_state)
|
| 455 |
audio_state = audio_tools.clear_conditioning(audio_state)
|
|
|
|
| 458 |
return video_state, audio_state
|
| 459 |
|
| 460 |
|
| 461 |
+
|
| 462 |
_UNICODE_REPLACEMENTS = str.maketrans("\u2018\u2019\u201c\u201d\u2014\u2013\u00a0\u2032\u2212", "''\"\"-- '-")
|
| 463 |
|
| 464 |
|
packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py
CHANGED
|
@@ -36,10 +36,24 @@ from ltx_core.text_encoders.gemma import (
|
|
| 36 |
module_ops_from_gemma_root,
|
| 37 |
)
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
class ModelLedger:
|
| 41 |
"""
|
| 42 |
-
Central coordinator for loading and building
|
| 43 |
The ledger wires together multiple model builders (transformer, video VAE encoder/decoder,
|
| 44 |
audio VAE decoder, vocoder, text encoder, and optional latent upsampler) and exposes
|
| 45 |
factory methods for constructing model instances.
|
|
@@ -144,6 +158,14 @@ class ModelLedger:
|
|
| 144 |
registry=self.registry,
|
| 145 |
)
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
if self.gemma_root_path is not None:
|
| 148 |
self.text_encoder_builder = Builder(
|
| 149 |
model_path=self.checkpoint_path,
|
|
@@ -197,6 +219,14 @@ class ModelLedger:
|
|
| 197 |
.eval()
|
| 198 |
)
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
def video_decoder(self) -> VideoDecoder:
|
| 201 |
if not hasattr(self, "vae_decoder_builder"):
|
| 202 |
raise ValueError(
|
|
|
|
| 36 |
module_ops_from_gemma_root,
|
| 37 |
)
|
| 38 |
|
| 39 |
+
from ltx_core.model.audio_vae import (
|
| 40 |
+
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
|
| 41 |
+
VOCODER_COMFY_KEYS_FILTER,
|
| 42 |
+
AudioDecoder,
|
| 43 |
+
AudioDecoderConfigurator,
|
| 44 |
+
Vocoder,
|
| 45 |
+
VocoderConfigurator,
|
| 46 |
+
AudioEncoder,
|
| 47 |
+
)
|
| 48 |
+
from ltx_core.model.audio_vae.model_configurator import (
|
| 49 |
+
AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
|
| 50 |
+
AudioEncoderConfigurator,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
|
| 54 |
class ModelLedger:
|
| 55 |
"""
|
| 56 |
+
Central coordinator for loading and building models used in an LTX pipeline.
|
| 57 |
The ledger wires together multiple model builders (transformer, video VAE encoder/decoder,
|
| 58 |
audio VAE decoder, vocoder, text encoder, and optional latent upsampler) and exposes
|
| 59 |
factory methods for constructing model instances.
|
|
|
|
| 158 |
registry=self.registry,
|
| 159 |
)
|
| 160 |
|
| 161 |
+
self.audio_encoder_builder = Builder(
|
| 162 |
+
model_path=self.checkpoint_path,
|
| 163 |
+
model_class_configurator=AudioEncoderConfigurator,
|
| 164 |
+
model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
|
| 165 |
+
registry=self.registry,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
if self.gemma_root_path is not None:
|
| 170 |
self.text_encoder_builder = Builder(
|
| 171 |
model_path=self.checkpoint_path,
|
|
|
|
| 219 |
.eval()
|
| 220 |
)
|
| 221 |
|
| 222 |
+
def audio_encoder(self) -> AudioEncoder:
    """Build the audio VAE encoder, move it to ``self.device`` and set eval mode.

    Raises:
        ValueError: If no ``audio_encoder_builder`` has been set up on this ledger.
    """
    if hasattr(self, "audio_encoder_builder"):
        encoder = self.audio_encoder_builder.build(device=self._target_device(), dtype=self.dtype)
        encoder = encoder.to(self.device)
        return encoder.eval()

    raise ValueError(
        "Audio encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
    )
|
| 228 |
+
|
| 229 |
+
|
| 230 |
def video_decoder(self) -> VideoDecoder:
|
| 231 |
if not hasattr(self, "vae_decoder_builder"):
|
| 232 |
raise ValueError(
|
requirements.txt
CHANGED
|
@@ -6,9 +6,15 @@ safetensors
|
|
| 6 |
accelerate
|
| 7 |
flashpack==0.1.2
|
| 8 |
scikit-image>=0.25.2
|
|
|
|
|
|
|
| 9 |
av
|
| 10 |
tqdm
|
| 11 |
pillow
|
| 12 |
scipy>=1.14
|
| 13 |
flash-attn-3 @ https://huggingface.co/alexnasa/flash-attn-3/resolve/main/128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
|
| 14 |
-
bitsandbytes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
accelerate
|
| 7 |
flashpack==0.1.2
|
| 8 |
scikit-image>=0.25.2
|
| 9 |
+
imageio
|
| 10 |
+
imageio-ffmpeg
|
| 11 |
av
|
| 12 |
tqdm
|
| 13 |
pillow
|
| 14 |
scipy>=1.14
|
| 15 |
flash-attn-3 @ https://huggingface.co/alexnasa/flash-attn-3/resolve/main/128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
|
| 16 |
+
bitsandbytes
|
| 17 |
+
opencv-python
|
| 18 |
+
controlnet_aux
|
| 19 |
+
onnxruntime-gpu
|
| 20 |
+
matplotlib
|