""" ╔══════════════════════════════════════════════════════════════════════╗ ║ AI Video Background Studio — Professional Edition ║ ║ Powered by BiRefNet · Gradio · MoviePy · PyTorch ║ ╚══════════════════════════════════════════════════════════════════════╝ """ # ─── Standard Library ──────────────────────────────────────────────────────── import gc import logging import os import tempfile import time import uuid from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from pathlib import Path from typing import Generator, List, Optional, Tuple, Union # ─── Third-Party ───────────────────────────────────────────────────────────── import spaces import gradio as gr import numpy as np import torch from PIL import Image, ImageEnhance, ImageFilter from pydub import AudioSegment from torchvision import transforms from transformers import AutoModelForImageSegmentation from moviepy import VideoFileClip, vfx, concatenate_videoclips, ImageSequenceClip # ─── Logging Setup ─────────────────────────────────────────────────────────── logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s — %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) log = logging.getLogger("AIBGStudio") # ─── Configuration ──────────────────────────────────────────────────────────── @dataclass class Config: device: str = "cuda" if torch.cuda.is_available() else "cpu" input_size: Tuple[int, int] = (1024, 1024) # higher res → sharper masks lite_size: Tuple[int, int] = (768, 768) model_full: str = "ZhengPeng7/BiRefNet" model_lite: str = "ZhengPeng7/BiRefNet_lite" output_codec: str = "libx264" output_crf: int = 18 # visually lossless max_temp_age_hours: int = 2 imagenet_mean: List[float] = field(default_factory=lambda: [0.485, 0.456, 0.406]) imagenet_std: List[float] = field(default_factory=lambda: [0.229, 0.224, 0.225]) CFG = Config() torch.set_float32_matmul_precision("high") # ─── Model Manager 
class ModelManager:
    """Lazy-loads and caches segmentation models to avoid duplicate GPU copies."""

    _models: dict = {}

    @classmethod
    def get(cls, fast: bool) -> torch.nn.Module:
        """Return the cached BiRefNet model ("lite" when *fast*), loading it on first use."""
        key = "lite" if fast else "full"
        if key not in cls._models:
            name = CFG.model_lite if fast else CFG.model_full
            log.info(f"Loading model [{key}]: {name}")
            m = AutoModelForImageSegmentation.from_pretrained(name, trust_remote_code=True)
            # Re-check at call time: on ZeroGPU the GPU is attached only inside
            # @spaces.GPU-decorated calls, so CFG.device (import-time) may be stale.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            m.to(device).eval()
            cls._models[key] = m
        return cls._models[key]

# NOTE: Do NOT pre-load models at module level on ZeroGPU Spaces.
# Models are loaded on first GPU-decorated call via ModelManager.get()

# ─── Transform Builders ──────────────────────────────────────────────────────
def _build_transform(size: Tuple[int, int]) -> transforms.Compose:
    """Build the resize → tensor → ImageNet-normalize pipeline for BiRefNet."""
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(CFG.imagenet_mean, CFG.imagenet_std),
    ])

# Built once at import so per-frame calls don't re-create the pipeline.
_tf_full = _build_transform(CFG.input_size)
_tf_lite = _build_transform(CFG.lite_size)

# ─── Core Processing ─────────────────────────────────────────────────────────
def extract_mask(image: Image.Image, fast: bool) -> Image.Image:
    """Run BiRefNet inference and return a soft binary mask (same size as input)."""
    tf = _tf_lite if fast else _tf_full
    model = ModelManager.get(fast)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = tf(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        # BiRefNet returns multi-scale predictions; the last one is the finest.
        pred = model(tensor)[-1].sigmoid().squeeze().cpu()
    mask = transforms.ToPILImage()(pred).resize(image.size, Image.LANCZOS)
    return mask

def refine_mask(mask: Image.Image, blur_radius: float = 1.5, erode: int = 0) -> Image.Image:
    """Post-process mask: optional erosion to shrink halos, then blur for soft edges.

    FIX: `erode` was previously accepted but silently ignored. It now applies a
    MinFilter of radius *erode* pixels (default 0 → unchanged behavior).
    """
    if erode > 0:
        # MinFilter kernel must be odd; 2*erode+1 erodes by ~erode pixels.
        mask = mask.filter(ImageFilter.MinFilter(2 * erode + 1))
    if blur_radius > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    return mask

def compose_frame(
    fg: Image.Image,
    mask: Image.Image,
    background: Image.Image,
    fg_brightness: float = 1.0,
    fg_contrast: float = 1.0,
    edge_feather: float = 1.5,
) -> Image.Image:
    """Composite foreground over background using the provided mask.

    The background is resized to the foreground's size; optional brightness /
    contrast grading is applied to the foreground only.
    """
    size = fg.size
    bg = background.convert("RGBA").resize(size, Image.LANCZOS)
    fg_rgba = fg.convert("RGBA")
    # Optional FG color-grade (skipped at the neutral 1.0 defaults).
    if fg_brightness != 1.0:
        fg_rgba = ImageEnhance.Brightness(fg_rgba).enhance(fg_brightness)
    if fg_contrast != 1.0:
        fg_rgba = ImageEnhance.Contrast(fg_rgba).enhance(fg_contrast)
    # Soft edges
    soft_mask = refine_mask(mask, blur_radius=edge_feather)
    return Image.composite(fg_rgba, bg, soft_mask).convert("RGB")

def _parse_hex_color(color: str) -> Tuple[int, int, int]:
    """Parse '#RRGGBB' or short-form '#RGB' into an (r, g, b) tuple.

    Raises ValueError for anything else (caught by the caller's pipeline guard).
    """
    h = color.lstrip("#")
    if len(h) == 3:  # short CSS form, e.g. '#0f0' → '00ff00'
        h = "".join(c * 2 for c in h)
    if len(h) != 6:
        raise ValueError(f"Unrecognized color: {color!r}")
    return tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))

def build_background(
    bg_type: str,
    color: str,
    bg_image_path: Optional[str],
    bg_frame: Optional[Image.Image],
) -> Image.Image:
    """Return a background PIL Image for the current frame.

    FIX: hex parsing generalized to also accept 3-digit CSS colors; 6-digit
    values behave exactly as before.
    """
    if bg_type == "Transparent":
        return Image.new("RGBA", (1, 1), (0, 0, 0, 0))  # placeholder; handled in compose
    if bg_type == "Color":
        r, g, b = _parse_hex_color(color)
        return Image.new("RGBA", (2, 2), (r, g, b, 255))
    if bg_type == "Image" and bg_image_path:
        return Image.open(bg_image_path).convert("RGBA")
    if bg_type == "Video" and bg_frame is not None:
        return bg_frame
    return Image.new("RGBA", (2, 2), (0, 255, 0, 255))  # fallback green

def process_single_frame(
    frame_rgb: np.ndarray,
    bg: Image.Image,
    fast: bool,
    fg_brightness: float,
    fg_contrast: float,
    edge_feather: float,
    transparent_bg: bool,
) -> np.ndarray:
    """Full pipeline for a single video frame → numpy RGB array."""
    pil = Image.fromarray(frame_rgb)
    mask = extract_mask(pil, fast)
    if transparent_bg:
        pil_rgba = pil.convert("RGBA")
        soft_mask = refine_mask(mask, blur_radius=edge_feather)
        r, g, b, _ = pil_rgba.split()
        out = Image.merge("RGBA", (r, g, b, soft_mask))
        # NOTE(review): convert("RGB") discards the alpha channel just built,
        # so mp4 output cannot actually carry transparency — a true transparent
        # export would need an alpha-capable format (WebM/VP9, PNG sequence).
        return np.array(out.convert("RGB"))  # MoviePy needs RGB
    result = compose_frame(pil, mask, bg, fg_brightness, fg_contrast, edge_feather)
    return np.array(result)
# ─── Video Helpers ───────────────────────────────────────────────────────────
def prepare_bg_video(
    bg_video_path: str,
    target_duration: float,
    target_fps: float,
    handling: str,
) -> List[np.ndarray]:
    """Decode the background video into a list of RGB frames at *target_fps*.

    If the clip is shorter than *target_duration*, it is either looped
    ("loop") or stretched to exactly fit ("slow_down"), per *handling*.

    FIXES:
    - The slow-down factor was inverted (target/duration > 1 *sped up* an
      already-short clip); `final_duration` stretches it correctly.
    - `Clip.fx()` was removed in MoviePy 2.x (this file imports the v2 API
      and uses `with_audio`); `with_effects()` is the v2 equivalent.
    - The ffmpeg reader is now closed instead of leaking.
    """
    source = VideoFileClip(bg_video_path)
    clip = source
    try:
        if clip.duration < target_duration:
            if handling == "loop":
                loops = int(np.ceil(target_duration / clip.duration))
                clip = concatenate_videoclips([clip] * loops)
            else:  # slow_down
                clip = clip.with_effects(
                    [vfx.MultiplySpeed(final_duration=target_duration)]
                )
        # Materialize frames BEFORE closing the reader in `finally`.
        return list(clip.iter_frames(fps=target_fps))
    finally:
        clip.close()
        if clip is not source:
            source.close()

def safe_temp_path(suffix: str = ".mp4") -> str:
    """Return a unique file path inside the app's private temp directory."""
    tmp_dir = Path(tempfile.gettempdir()) / "aibg_studio"
    tmp_dir.mkdir(parents=True, exist_ok=True)
    return str(tmp_dir / f"{uuid.uuid4().hex}{suffix}")

def cleanup_old_temps(max_age_hours: int = CFG.max_temp_age_hours) -> None:
    """Best-effort removal of temp output files older than *max_age_hours*."""
    tmp_dir = Path(tempfile.gettempdir()) / "aibg_studio"
    if not tmp_dir.exists():
        return
    cutoff = time.time() - max_age_hours * 3600
    for f in tmp_dir.glob("*.mp4"):
        if f.stat().st_mtime < cutoff:
            try:
                f.unlink()
            except OSError:
                # File may still be open in another session — skip it.
                pass

# ─── Main Processing Generator ───────────────────────────────────────────────
@spaces.GPU
def fn(
    vid: str,
    bg_type: str = "Color",
    bg_image: Optional[str] = None,
    bg_video: Optional[str] = None,
    color: str = "#00FF00",
    fps: int = 0,
    video_handling: str = "loop",
    fast_mode: bool = True,
    fg_brightness: float = 1.0,
    fg_contrast: float = 1.0,
    edge_feather: float = 1.5,
    output_quality: str = "High (CRF 18)",
) -> Generator:
    """Replace the background of *vid*, streaming progress to the Gradio UI.

    Yields 3-tuples matching outputs=[stream_image, out_video, status_box]:
    a live-preview frame (or gr.update), the finished video path, and a
    human-readable status line.

    FIXES: the source VideoFileClip is now closed in `finally` (it leaked an
    ffmpeg reader on every call), the audio check is an explicit None test,
    and `torch.cuda.empty_cache()` is guarded for CPU-only runs.
    """
    cleanup_old_temps()
    t0 = time.time()

    def elapsed() -> str:
        return f"{time.time() - t0:.1f}s"

    def status(msg: str) -> str:
        return f"⏱ {elapsed()} | {msg}"

    video: Optional[VideoFileClip] = None
    try:
        # ── Guard: no video uploaded ─────────────────────────────────────────
        if not vid:
            yield gr.update(visible=False), gr.update(visible=True), status("❌ Please upload a video first.")
            return

        # ── Guard: safe color fallback ───────────────────────────────────────
        safe_color = color if (color and isinstance(color, str)) else "#00FF00"

        # ── Load source video ────────────────────────────────────────────────
        video = VideoFileClip(vid)
        target_fps = fps if fps > 0 else video.fps
        audio = video.audio
        frames = list(video.iter_frames(fps=target_fps))
        total = len(frames)
        log.info(f"Source: {total} frames @ {target_fps:.2f} fps")
        yield (
            gr.update(visible=True),
            gr.update(visible=False),
            status(f"Loaded {total} frames @ {target_fps:.1f} fps"),
        )

        # ── Prepare background ───────────────────────────────────────────────
        bg_frames_list: Optional[List[np.ndarray]] = None
        transparent_bg = bg_type == "Transparent"
        if bg_type == "Video" and bg_video:
            log.info("Preparing background video…")
            yield gr.update(), gr.update(), status("Preparing background video…")
            bg_frames_list = prepare_bg_video(bg_video, video.duration, target_fps, video_handling)

        # ── Pre-build static background (Color / Image) once ─────────────────
        static_bg: Optional[Image.Image] = None
        if bg_type not in ("Video", "Transparent"):
            static_bg = build_background(bg_type, safe_color, bg_image, None)

        # ── Frame processing (sequential — ZeroGPU requires GPU on main thread)
        crf_map = {"High (CRF 18)": 18, "Medium (CRF 23)": 23, "Low (CRF 28)": 28}
        crf = crf_map.get(output_quality, 18)

        processed_frames: List[np.ndarray] = []
        for i, frame in enumerate(frames):
            if bg_type == "Video" and bg_frames_list:
                # Clamp so a still-too-short BG video freezes on its last frame.
                idx = min(i, len(bg_frames_list) - 1)
                bg_pil = Image.fromarray(bg_frames_list[idx])
            else:
                bg_pil = static_bg
            arr = process_single_frame(
                frame, bg_pil, fast_mode,
                fg_brightness, fg_contrast, edge_feather, transparent_bg,
            )
            processed_frames.append(arr)
            pct = (i + 1) / total * 100
            yield (
                arr,
                gr.update(visible=False),
                status(f"Processing frames… {i+1}/{total} ({pct:.0f}%)"),
            )

        # ── Export video ─────────────────────────────────────────────────────
        yield gr.update(), gr.update(), status("Encoding final video…")
        out_path = safe_temp_path(".mp4")
        out_clip = ImageSequenceClip(processed_frames, fps=target_fps)
        if audio is not None:
            out_clip = out_clip.with_audio(audio)
        ffmpeg_params = ["-crf", str(crf), "-preset", "fast", "-pix_fmt", "yuv420p"]
        out_clip.write_videofile(
            out_path,
            codec=CFG.output_codec,
            ffmpeg_params=ffmpeg_params,
            logger=None,
        )

        # Free the (potentially very large) frame lists before yielding results.
        del processed_frames, frames
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        log.info(f"Done → {out_path} ({elapsed()})")
        yield gr.update(visible=False), gr.update(visible=True), status("✅ Complete!")
        yield out_clip.get_frame(0), out_path, status("✅ Video ready for download!")

    except Exception as exc:
        log.exception("Processing failed")
        yield (
            gr.update(visible=False),
            gr.update(visible=True),
            status(f"❌ Error: {exc}"),
        )
        yield None, None, status(f"❌ Error: {exc}")
    finally:
        # Release the ffmpeg reader on success AND error paths.
        if video is not None:
            video.close()

# ─── Custom CSS ──────────────────────────────────────────────────────────────
CSS = """
/* ── Global ──────────────────────────────────── */
:root {
  --accent: #6366f1; --accent-h: #4f46e5;
  --surface: #1e1e2e; --card: #252535; --border: #3a3a5c;
  --text: #e2e8f0; --muted: #94a3b8; --radius: 12px;
}
body, .gradio-container { background: var(--surface) !important; color: var(--text) !important; }

/* ── Header ──────────────────────────────────── */
.app-header {
  background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f3460 100%);
  border-bottom: 1px solid var(--border); border-radius: var(--radius);
  padding: 28px 36px; margin-bottom: 20px;
}
.app-header h1 { font-size: 2rem; font-weight: 800; color: #fff; margin: 0; letter-spacing: -0.5px; }
.app-header .badge {
  display: inline-block; background: var(--accent); color: #fff;
  font-size: 0.7rem; font-weight: 700; letter-spacing: 1.5px; text-transform: uppercase;
  padding: 2px 10px; border-radius: 999px; margin-left: 10px; vertical-align: middle;
}
.app-header p { color: var(--muted); margin: 6px 0 0; font-size: 0.95rem; }

/* ── Cards ───────────────────────────────────── */
.card { background: var(--card); border: 1px solid var(--border); border-radius: var(--radius); padding: 20px; }
.section-title {
  font-size: 0.75rem; font-weight: 700; letter-spacing: 1.5px; text-transform: uppercase;
  color: var(--muted); margin-bottom: 14px; display: flex; align-items: center; gap: 6px;
}
.section-title::before {
  content: ''; display: inline-block; width: 3px; height: 14px;
  background: var(--accent); border-radius: 2px;
}

/* ── Primary Button ──────────────────────────── */
.run-btn button {
  background: linear-gradient(135deg, var(--accent) 0%, var(--accent-h) 100%) !important;
  border: none !important; border-radius: 10px !important; color: #fff !important;
  font-size: 1rem !important; font-weight: 700 !important; padding: 14px 32px !important;
  letter-spacing: 0.3px; transition: opacity 0.2s, transform 0.1s !important;
  box-shadow: 0 4px 20px rgba(99,102,241,.45) !important;
}
.run-btn button:hover { opacity: .9; transform: translateY(-1px); }
.run-btn button:active { transform: translateY(0); }

/* ── Status Bar ──────────────────────────────── */
.status-box textarea {
  background: #0f172a !important; border: 1px solid var(--border) !important;
  border-radius: 8px !important; color: #7dd3fc !important;
  font-family: 'JetBrains Mono', 'Fira Code', monospace !important; font-size: 0.82rem !important;
}

/* ── Inputs ──────────────────────────────────── */
.gr-box, .gr-padded, .gr-input, .gr-panel { background: var(--card) !important; }
input[type="range"] { accent-color: var(--accent); }

/* ── Video preview ───────────────────────────── */
video { border-radius: 8px; }
"""
# ─── Gradio UI ───────────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Base(), css=CSS, title="AI Video BG Studio") as demo:

    # ── Header ───────────────────────────────────────────────────────────────
    # NOTE(review): the original HTML tags were lost in transit; this markup is
    # reconstructed from the text content using the classes defined in CSS
    # (.app-header, .badge) — verify rendering matches the intended design.
    gr.HTML("""
    <div class="app-header">
      <h1>🎬 AI Video Background Studio <span class="badge">Pro</span></h1>
      <p>Replace video backgrounds with any color, image, or video using
      state-of-the-art BiRefNet segmentation.</p>
    </div>
    """)

    # ── Preview Row ──────────────────────────────────────────────────────────
    with gr.Row(equal_height=True):
        in_video = gr.Video(label="📥 Input Video", interactive=True, height=340)
        stream_image = gr.Image(label="⚡ Live Preview", visible=False, height=340)
        out_video = gr.Video(label="📤 Output Video", height=340, interactive=False)

    # ── Controls Row ─────────────────────────────────────────────────────────
    with gr.Row():
        # ─ Column 1: Background ──────────────────────────────────────────────
        with gr.Column(scale=2):
            gr.HTML('<div class="section-title">🖼 Background</div>')
            bg_type = gr.Radio(
                ["Color", "Image", "Video", "Transparent"],
                value="Color",
                label="Type",
                interactive=True,
            )
            color_picker = gr.ColorPicker(label="Solid Color", value="#00FF00", visible=True, interactive=True)
            bg_image = gr.Image(label="Background Image", type="filepath", visible=False, interactive=True)
            bg_video = gr.Video(label="Background Video", visible=False, interactive=True)
            with gr.Group(visible=False) as video_handling_grp:
                video_handling = gr.Radio(
                    ["loop", "slow_down"],
                    label="If BG video is shorter…",
                    value="loop",
                    interactive=True,
                )

        # ─ Column 2: Quality ─────────────────────────────────────────────────
        with gr.Column(scale=2):
            gr.HTML('<div class="section-title">⚙ Quality &amp; Performance</div>')
            fast_mode = gr.Checkbox(
                label="⚡ Fast Mode (BiRefNet-Lite — faster, slightly less precise)",
                value=True,
                interactive=True,
            )
            fps_slider = gr.Slider(0, 60, step=1, value=0, label="Output FPS (0 = inherit source)", interactive=True)
            output_quality = gr.Radio(
                ["High (CRF 18)", "Medium (CRF 23)", "Low (CRF 28)"],
                value="High (CRF 18)",
                label="Output Quality",
                interactive=True,
            )

        # ─ Column 3: Refinement ──────────────────────────────────────────────
        with gr.Column(scale=2):
            gr.HTML('<div class="section-title">✨ Fine-Tuning</div>')
            edge_feather = gr.Slider(0.0, 5.0, step=0.5, value=1.5, label="Edge Feathering (soft edges)", interactive=True)
            fg_brightness = gr.Slider(0.5, 2.0, step=0.05, value=1.0, label="Foreground Brightness", interactive=True)
            fg_contrast = gr.Slider(0.5, 2.0, step=0.05, value=1.0, label="Foreground Contrast", interactive=True)

    # ── Action Row ───────────────────────────────────────────────────────────
    with gr.Row():
        with gr.Column(elem_classes="run-btn"):
            submit_btn = gr.Button("🚀 Process Video", variant="primary")
            status_box = gr.Textbox(
                label="Status",
                interactive=False,
                elem_classes="status-box",
                lines=1,
                placeholder="Ready — load a video and click Process.",
            )

    # ── Dynamic BG visibility ────────────────────────────────────────────────
    def _toggle_bg(t):
        """Show only the background input relevant to the selected type."""
        return (
            gr.update(visible=(t == "Color")),
            gr.update(visible=(t == "Image")),
            gr.update(visible=(t == "Video")),
            gr.update(visible=(t == "Video")),
        )

    bg_type.change(
        _toggle_bg,
        bg_type,
        [color_picker, bg_image, bg_video, video_handling_grp],
    )

    # ── Examples ─────────────────────────────────────────────────────────────
    gr.HTML('<div class="section-title">📚 Example Presets</div>')
    gr.Examples(
        examples=[
            ["rickroll-2sec.mp4", "Video", None, "background.mp4", "#00FF00", 0, "loop", True, 1.0, 1.0, 1.5, "High (CRF 18)"],
            ["rickroll-2sec.mp4", "Image", "images.webp", None, "#00FF00", 0, "loop", True, 1.0, 1.0, 1.5, "High (CRF 18)"],
            ["rickroll-2sec.mp4", "Color", None, None, "#1a1a2e", 0, "loop", True, 1.0, 1.0, 1.5, "High (CRF 18)"],
            ["rickroll-2sec.mp4", "Transparent", None, None, "#00FF00", 0, "loop", False, 1.0, 1.0, 2.0, "High (CRF 18)"],
        ],
        inputs=[
            in_video, bg_type, bg_image, bg_video, color_picker,
            fps_slider, video_handling, fast_mode,
            fg_brightness, fg_contrast, edge_feather, output_quality,
        ],
        outputs=[stream_image, out_video, status_box],
        fn=fn,
        # NOTE(review): eager caching runs fn for every example at startup,
        # which needs GPU time on ZeroGPU — confirm this is intended.
        cache_examples=True,
        cache_mode="eager",
    )

    # ── Wire Button ──────────────────────────────────────────────────────────
    submit_btn.click(
        fn=fn,
        inputs=[
            in_video, bg_type, bg_image, bg_video, color_picker,
            fps_slider, video_handling, fast_mode,
            fg_brightness, fg_contrast, edge_feather, output_quality,
        ],
        outputs=[stream_image, out_video, status_box],
    )

    # Footer markup reconstructed — see header NOTE(review) above.
    gr.HTML("""
    <div style="text-align:center; color: var(--muted); padding: 14px 0;">
      Powered by BiRefNet &nbsp;·&nbsp; MoviePy &nbsp;·&nbsp; Gradio &nbsp;·&nbsp;
      PyTorch &nbsp;·&nbsp; © AI Video Background Studio
    </div>
    """)

# ─── Entry Point ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    demo.launch(
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )