import gradio as gr
import numpy as np
from PIL import Image, ImageFilter
from transformers import pipeline

try:
    import cv2
except ImportError:
    cv2 = None  # OpenCV is optional; PIL fallbacks are used when missing

# --- Models ---
depth_pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Large-hf")

# A robust, widely available panoptic segmentation model in transformers
try:
    seg_pipe = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic")
except Exception:
    seg_pipe = None  # We'll handle gracefully

# --- Global state ---
current_original_image = None   # PIL.Image of the last uploaded image
current_depth_norm = None       # float32 HxW depth normalized to [0, 1]
current_depth_map_pil = None    # grayscale PIL rendering of the depth map
current_near_softmask = None    # float32 0..1, soft foreground mask
current_edge_band = None        # float32 0..1, where to snap edges


# ---------- Utilities ----------
def _np_from_pil_gray(img_pil):
    """Convert a PIL image to a float32 grayscale array in [0, 1]."""
    return np.array(img_pil.convert("L"), dtype=np.float32) / 255.0


def _ensure_float01(x):
    """Return *x* as float32 rescaled to [0, 1] if it falls outside that range."""
    x = x.astype(np.float32)
    if x.max() > 1.0 or x.min() < 0.0:
        x = (x - x.min()) / max(x.max() - x.min(), 1e-6)
    return x


def _soften_mask(mask01, sigma_px=2.0):
    """Gaussian blur softening for a binary mask in [0,1]."""
    if cv2 is not None:
        return _ensure_float01(cv2.GaussianBlur(mask01, (0, 0), sigmaX=max(sigma_px, 1e-6)))
    # PIL fallback
    pil = Image.fromarray((mask01 * 255).astype(np.uint8), mode="L").filter(
        ImageFilter.GaussianBlur(radius=float(sigma_px))
    )
    return np.array(pil, dtype=np.float32) / 255.0


def _edge_band_from_mask(mask01, band_px=6):
    """Thin band around mask boundary where we enforce snapping.

    Returns a float32 array in {0, 1}: 1 inside the dilated boundary band.
    """
    m = (mask01 > 0.5).astype(np.uint8)
    if cv2 is not None:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        edge = cv2.morphologyEx(m, cv2.MORPH_GRADIENT, k)  # 1px outline
        if band_px > 1:
            edge = cv2.dilate(
                edge,
                cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (int(band_px), int(band_px))),
            )
        return edge.astype(np.float32)
    # PIL fallback: crude edge detection, then widen the 1px outline with a
    # simple 4-neighbour max-filter dilation repeated ~band_px/2 times.
    pil = Image.fromarray(m * 255, mode="L").filter(ImageFilter.FIND_EDGES)
    arr = (np.array(pil) > 0).astype(np.uint8)
    for _ in range(max(int(band_px // 2), 1)):
        arr = np.maximum.reduce([
            np.pad(arr, ((1, 0), (0, 0)))[:-1, :],
            np.pad(arr, ((0, 1), (0, 0)))[1:, :],
            np.pad(arr, ((0, 0), (1, 0)))[:, :-1],
            np.pad(arr, ((0, 0), (0, 1)))[:, 1:],
            arr,
        ])
    return arr.astype(np.float32)


def _pick_nearest_segment_mask(segments, depth01, invert_depth=False):
    """
    Given pipeline outputs (each with a 'mask' PIL) and normalized depth map,
    choose the segment with smallest median *distance to camera*.
    If invert_depth=True, we treat 1-depth as distance.
    """
    if not segments:
        return None
    h, w = depth01.shape
    best_mask = None
    best_score = None
    # Interpret "nearness": if invert_depth True, higher values mean near
    near_map = (1.0 - depth01) if invert_depth else depth01
    # If your model outputs "larger = farther", set invert_depth=True via UI
    for seg in segments:
        m = np.array(seg["mask"].resize((w, h)).convert("L")) > 127
        if m.sum() < 50:  # ignore tiny segments
            continue
        # Higher median(nearness) = nearer
        score = float(np.median(near_map[m]))
        if (best_score is None) or (score > best_score):
            best_score = score
            best_mask = m.astype(np.float32)
    if best_mask is None:
        return None
    return _ensure_float01(best_mask)


def _depth_only_near_mask(depth01, near_percent=15, invert_depth=False):
    """
    Fallback: use top N% nearest pixels (by depth) as a binary mask.
    near_percent in [1..40] typical.
    """
    flat = depth01.flatten()
    if invert_depth:
        flat = 1.0 - flat
    q = np.quantile(flat, 1.0 - (near_percent / 100.0))
    m = ((1.0 - depth01) if invert_depth else depth01) >= q
    return m.astype(np.float32)


# ---------- Preprocess depth ----------
def preprocess_depth(depth_norm, smoothing_radius):
    """Edge-preserving smoothing of the normalized depth map.

    Uses a bilateral filter when OpenCV is available; otherwise returns the
    input unchanged (smoothing slider is a no-op without cv2).
    """
    if smoothing_radius > 0 and cv2 is not None:
        depth_uint8 = (depth_norm * 255.0).astype(np.uint8)
        sigma = max(smoothing_radius * 10.0, 1.0)
        smoothed = cv2.bilateralFilter(depth_uint8, d=5, sigmaColor=sigma, sigmaSpace=sigma)
        return smoothed.astype(np.float32) / 255.0
    return depth_norm


# ---------- Effect ----------
def apply_effect(threshold, depth_scale, feather, red_brightness, blue_brightness,
                 gamma, black_level_percent, white_level_percent, smoothing_percent,
                 use_segmentation, edge_snap_strength, edge_band_px, invert_depth,
                 near_percent):
    """
    Adds segmentation-assisted edge snapping:
    - use_segmentation: run / use segmentation-assisted mask if available
    - edge_snap_strength: 0..100 weight of snapping in edge band
    - edge_band_px: band width around boundary to force crisp transition
    - invert_depth: flip near/far interpretation if needed
    - near_percent: fallback mask when segmentation is off/failed
    """
    global current_original_image, current_depth_norm
    global current_near_softmask, current_edge_band

    if current_original_image is None or current_depth_norm is None:
        return None

    # Levels adjustment (same as yours)
    black_level = black_level_percent * 2.55
    white_level = white_level_percent * 2.55
    gray = np.array(current_original_image.convert("L"), dtype=np.float32)
    denom = max(white_level - black_level, 1e-6)
    adjusted_gray = np.clip((gray - black_level) / denom, 0.0, 1.0)

    # Gamma: slider 0..100 maps to exponent 0.1..3.0
    gamma_val = 0.1 + (gamma / 100.0) * 2.9
    adjusted_gray = np.clip(adjusted_gray ** gamma_val, 0.0, 1.0)

    # Depth smoothing
    smoothing_radius = smoothing_percent / 10.0
    depth_smoothed = preprocess_depth(current_depth_norm, smoothing_radius)
    depth_for_blend = (1.0 - depth_smoothed) if invert_depth else depth_smoothed

    # Logistic blend from depth
    threshold_norm = threshold / 100.0
    steepness = max(depth_scale, 1e-3)
    feather_norm = feather / 100.0
    steepness_adj = steepness / (feather_norm * 10.0 + 1.0)
    blend = 1.0 / (1.0 + np.exp(-steepness_adj * (depth_for_blend - threshold_norm)))

    # Edge snapping: mix blend with (soft) near-object mask ONLY within edge band.
    # FIX: honor the use_segmentation checkbox — when it is off (or no
    # segmentation mask was computed), fall back to the depth-only near mask
    # as promised by the UI label.
    snap_w = np.clip(edge_snap_strength / 100.0, 0.0, 1.0)
    if use_segmentation and current_near_softmask is not None:
        current_near_softmask_local = current_near_softmask
        current_edge_band_local = current_edge_band
    else:
        fallback = _depth_only_near_mask(current_depth_norm,
                                         near_percent=near_percent,
                                         invert_depth=invert_depth)
        current_near_softmask_local = _soften_mask(fallback, sigma_px=max(edge_band_px / 2.0, 1.0))
        current_edge_band_local = _edge_band_from_mask(fallback, band_px=max(int(edge_band_px), 1))

    if snap_w > 0.0:
        # Per-pixel alpha only near edges
        per_pixel_alpha = snap_w * _ensure_float01(current_edge_band_local)
        blend = (1.0 - per_pixel_alpha) * blend + per_pixel_alpha * _ensure_float01(current_near_softmask_local)

    # Map brightness to factors (0-2)
    red_factor = red_brightness / 50.0
    blue_factor = blue_brightness / 50.0
    red_channel = red_factor * adjusted_gray * blend
    blue_channel = blue_factor * adjusted_gray * (1.0 - blend)
    red_img = np.clip(red_channel * 255.0, 0, 255).astype(np.uint8)
    blue_img = np.clip(blue_channel * 255.0, 0, 255).astype(np.uint8)

    h, w = red_img.shape
    output = np.zeros((h, w, 3), dtype=np.uint8)
    output[..., 0] = red_img
    output[..., 1] = 0
    output[..., 2] = blue_img
    return Image.fromarray(output, mode="RGB")


# ---------- Pipeline steps ----------
def _compute_segmentation_assist(img_pil, depth01, invert_depth, edge_band_px, near_percent):
    """
    Build current_near_softmask and current_edge_band using segmentation if
    enabled; else fallback to depth-only near mask.
    """
    global current_near_softmask, current_edge_band
    h, w = depth01.shape

    near_mask = None
    if seg_pipe is not None:
        try:
            segs = seg_pipe(img_pil)
            # Some models return dict with 'segments'; normalize to list
            if isinstance(segs, dict) and "segments_info" in segs and "segmentation" in segs:
                # Panoptic map; fall back to simple depth-based mask
                segs = []  # transformers panoptic map formats vary; keep generic path below
            # segs should be a list of dicts, each with a 'mask' PIL
            candidates = [s for s in segs if isinstance(s.get("mask", None), Image.Image)]
            near_mask = _pick_nearest_segment_mask(candidates, depth01, invert_depth=invert_depth)
        except Exception:
            near_mask = None

    if near_mask is None:
        # Fallback: depth-only near mask
        near_mask = _depth_only_near_mask(depth01, near_percent=near_percent, invert_depth=invert_depth)

    # Build soft mask + edge band
    current_near_softmask = _soften_mask(near_mask, sigma_px=max(edge_band_px / 2.0, 1.0))
    current_edge_band = _edge_band_from_mask(near_mask, band_px=max(int(edge_band_px), 1))


def generate_depth_map(input_image, use_segmentation, edge_band_px, invert_depth, near_percent):
    """Generate normalized depth map and initial effect image."""
    global current_original_image, current_depth_norm, current_depth_map_pil

    if input_image is None:
        current_original_image = None
        current_depth_norm = None
        current_depth_map_pil = None
        return None, None

    current_original_image = input_image

    # Depth estimation
    result = depth_pipe(input_image)
    depth = np.array(result["depth"], dtype=np.float32)
    depth -= depth.min()
    max_val = depth.max()
    if max_val > 0:
        depth /= max_val
    current_depth_norm = depth
    current_depth_map_pil = Image.fromarray((depth * 255.0).astype(np.uint8), mode="L")

    # Build segmentation assist unconditionally so toggling the checkbox later
    # takes effect immediately without regenerating.
    # (FIX: replaced dead condition `if use_segmentation or True:`.)
    _compute_segmentation_assist(input_image, current_depth_norm, invert_depth,
                                 edge_band_px, near_percent)

    # Default effect parameters
    effect = apply_effect(
        threshold=50, depth_scale=50, feather=10,
        red_brightness=50, blue_brightness=50, gamma=50,
        black_level_percent=0, white_level_percent=100, smoothing_percent=0,
        use_segmentation=use_segmentation, edge_snap_strength=60,
        edge_band_px=edge_band_px, invert_depth=invert_depth,
        near_percent=near_percent,
    )
    return current_depth_map_pil.convert("RGB"), effect


def update_effect(threshold, depth_scale, feather, red_brightness, blue_brightness,
                  gamma, black_level, white_level, smoothing,
                  use_segmentation, edge_snap_strength, edge_band_px,
                  invert_depth, near_percent):
    """Update the effect when any control changes."""
    return apply_effect(
        threshold=threshold, depth_scale=depth_scale, feather=feather,
        red_brightness=red_brightness, blue_brightness=blue_brightness, gamma=gamma,
        black_level_percent=black_level, white_level_percent=white_level,
        smoothing_percent=smoothing, use_segmentation=use_segmentation,
        edge_snap_strength=edge_snap_strength, edge_band_px=edge_band_px,
        invert_depth=invert_depth, near_percent=near_percent,
    )


def clear_results():
    """Reset global state and clear outputs."""
    global current_original_image, current_depth_norm, current_depth_map_pil
    global current_near_softmask, current_edge_band
    current_original_image = None
    current_depth_norm = None
    current_depth_map_pil = None
    current_near_softmask = None
    current_edge_band = None
    return None, None


# ---------- UI ----------
with gr.Blocks(title="ChromoStereoizer Enhanced", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ChromoStereoizer Enhanced (Segmentation-Assisted)")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="pil", height=400)
            with gr.Accordion("Segmentation & Mask Options", open=False):
                use_segmentation = gr.Checkbox(
                    value=True,
                    label="Use segmentation-assisted edge snapping (falls back to depth-only)",
                )
                edge_band_px = gr.Slider(1, 20, value=6, step=1, label="Edge Band Width (px)")
                invert_depth = gr.Checkbox(
                    value=False,
                    label="Invert depth (toggle if near/far feels flipped)",
                )
                near_percent = gr.Slider(1, 40, value=15, step=1, label="Fallback: top N% nearest pixels")
            generate_btn = gr.Button("Generate Depth Map", variant="primary", size="lg")
        with gr.Column(scale=1):
            gr.Markdown("**Depth Map**")
            depth_output = gr.Image(type="pil", height=400, interactive=False,
                                    show_download_button=True, show_label=False)
            gr.Markdown("**ChromoStereoizer Result**")
            chromo_output = gr.Image(type="pil", height=400, interactive=False,
                                     show_download_button=True, show_label=False)

    gr.Markdown("## Controls")
    threshold_slider = gr.Slider(0, 100, value=50, step=1, label="Threshold (%)")
    depth_scale_slider = gr.Slider(0, 100, value=50, step=1, label="Depth Scale (Steepness)")
    feather_slider = gr.Slider(0, 100, value=10, step=1, label="Feather (%)")
    red_slider = gr.Slider(0, 100, value=50, step=1, label="Red Brightness")
    blue_slider = gr.Slider(0, 100, value=50, step=1, label="Blue Brightness")
    gamma_slider = gr.Slider(0, 100, value=50, step=1, label="Gamma")
    black_slider = gr.Slider(0, 100, value=0, step=1, label="Black Level (%)")
    white_slider = gr.Slider(0, 100, value=100, step=1, label="White Level (%)")
    smoothing_slider = gr.Slider(0, 100, value=0, step=1, label="Depth Smoothing (%)")
    edge_snap_strength = gr.Slider(0, 100, value=60, step=1, label="Edge Snap Strength (%)")
    clear_btn = gr.Button("Clear", variant="secondary")

    # Events
    generate_btn.click(
        fn=generate_depth_map,
        inputs=[input_image, use_segmentation, edge_band_px, invert_depth, near_percent],
        outputs=[depth_output, chromo_output],
        show_progress=True,
    )

    # Every control re-renders the effect from cached depth/masks.
    for ctrl in [
        threshold_slider, depth_scale_slider, feather_slider,
        red_slider, blue_slider, gamma_slider,
        black_slider, white_slider, smoothing_slider,
        use_segmentation, edge_snap_strength, edge_band_px,
        invert_depth, near_percent,
    ]:
        ctrl.change(
            fn=update_effect,
            inputs=[
                threshold_slider, depth_scale_slider, feather_slider,
                red_slider, blue_slider, gamma_slider,
                black_slider, white_slider, smoothing_slider,
                use_segmentation, edge_snap_strength, edge_band_px,
                invert_depth, near_percent,
            ],
            outputs=chromo_output,
            show_progress=False,
        )

    clear_btn.click(
        fn=clear_results,
        inputs=[],
        outputs=[depth_output, chromo_output],
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)