"""
Oculus Object Detection Demo

Demonstrates object detection, captioning, and VQA on images using the
extended training model.
"""
| |
|
| | import sys |
| | import requests |
| | from io import BytesIO |
| | from PIL import Image, ImageDraw, ImageFont |
| | import torch |
| | import numpy as np |
| |
|
| | |
| | from pathlib import Path |
| | sys.path.insert(0, str(Path(__file__).parent)) |
| |
|
| | from oculus_unified_model import OculusForConditionalGeneration |
| |
|
def visualize_results(image, output, filename="output_car_parts.png"):
    """Draw detection bounding boxes and class labels on *image* and save it.

    Args:
        image: PIL Image to annotate (modified in place).
        output: detection result with parallel sequences ``boxes``
            (normalized ``[x1, y1, x2, y2]`` in 0..1), ``labels``
            (class indices, presumably COCO — TODO confirm against the
            training pipeline), and ``confidences``.
        filename: path the annotated image is written to.
    """
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
    except OSError:
        # Helvetica path is macOS-specific; fall back elsewhere.
        font = ImageFont.load_default()

    width, height = image.size

    COCO_CLASSES = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
        'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
        'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
        'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
        'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
        'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
        'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
        'toothbrush'
    ]

    for box, label, conf in zip(output.boxes, output.labels, output.confidences):
        x1, y1, x2, y2 = box

        # Clamp normalized coordinates into [0, 1].
        x1 = max(0.0, min(1.0, x1))
        y1 = max(0.0, min(1.0, y1))
        x2 = max(0.0, min(1.0, x2))
        y2 = max(0.0, min(1.0, y2))

        # Skip degenerate (zero- or negative-area) boxes after clamping.
        if x2 <= x1 or y2 <= y1:
            continue

        # Scale normalized coordinates to pixel space.
        x1 *= width
        y1 *= height
        x2 *= width
        y2 *= height

        # Red = low confidence, green = high confidence.
        color = "red" if conf < 0.5 else "green"

        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Resolve the class name; fall back to the raw label for
        # out-of-range ids.  (A plain index would let negative labels
        # silently alias end-of-list classes, and a bare except would
        # hide real errors.)
        try:
            idx = int(label)
        except (TypeError, ValueError):
            idx = -1
        if 0 <= idx < len(COCO_CLASSES):
            class_name = COCO_CLASSES[idx]
        else:
            class_name = str(label)

        label_text = f"{class_name} ({conf:.2f})"

        # Filled background behind the text so it stays legible.
        text_bbox = draw.textbbox((x1, y1), label_text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((x1, y1), label_text, fill="white", font=font)

    image.save(filename)
    # BUG FIX: original printed the literal "(unknown)" instead of the
    # actual output path.
    print(f"Saved visualization to {filename}")
| |
|
def _resolve_model_path():
    """Pick the newest V2 ``step_<N>`` checkpoint, else V2 ``final``,
    else fall back to the V1 model directory.

    Returns:
        str: path to the model directory to load (may not exist if the
        V1 fallback is also missing; ``main`` reports that as a load error).
    """
    checkpoint_dir = Path("checkpoints/oculus_detection_v2")
    model_path = None

    if checkpoint_dir.exists():
        steps = []
        for d in checkpoint_dir.iterdir():
            if d.is_dir() and d.name.startswith("step_"):
                try:
                    steps.append((int(d.name.split("_")[1]), d))
                except (ValueError, IndexError):
                    # Ignore directories that don't follow "step_<N>".
                    pass

        if steps:
            steps.sort(key=lambda x: x[0], reverse=True)
            model_path = str(steps[0][1])
            print(f"✨ Found latest checkpoint: {model_path}")

    if model_path is None:
        model_path = str(checkpoint_dir / "final")

    # Applies to both the step checkpoint and the "final" default: if the
    # chosen V2 path is missing, fall back to V1.
    if not Path(model_path).exists():
        model_path = "checkpoints/oculus_detection/final"
        print(f"⚠️ Extended V2 model not found, falling back to V1: {model_path}")

    return model_path


def _load_image(image_path):
    """Open *image_path* as RGB, or download a sample image if it is missing.

    Raises:
        requests.RequestException: if the sample download fails.
    """
    if Path(image_path).exists():
        return Image.open(image_path).convert('RGB')

    url = "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/President_Barack_Obama.jpg/800px-President_Barack_Obama.jpg"
    print(f"Image not found, downloading sample {url}...")
    # Timeout + status check: the original could hang forever and would
    # feed an HTTP error page into Image.open on a failed download.
    response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')


def main():
    """CLI entry point: parse args, load the model, run the chosen mode."""
    import argparse
    parser = argparse.ArgumentParser(description="Oculus General Object Detection Demo")
    parser.add_argument("--image", type=str, help="Path to image file to test")
    parser.add_argument("--prompt", type=str, default="Detect objects", help="Text prompt for the model")
    parser.add_argument("--mode", type=str, default="box", choices=["box", "vqa", "caption"], help="Inference mode")
    parser.add_argument("--threshold", type=float, default=0.2, help="Detection threshold")
    parser.add_argument("--output", type=str, default="detection_result.png", help="Output filename")
    args = parser.parse_args()

    model_path = _resolve_model_path()

    print(f"Loading model from {model_path}...")
    try:
        model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Detection heads are stored separately from the base weights.
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            heads = torch.load(heads_path, map_location="cpu")
            model.detection_head.load_state_dict(heads['detection'])
            print("✓ Loaded detection heads")
    except Exception as e:
        # Top-level boundary: report and exit rather than traceback.
        print(f"Error loading model: {e}")
        return

    if args.image:
        image_path = args.image
        print(f"\nProcessing Custom Image: {image_path}...")
    else:
        image_path = "data/coco/images/000000071345.jpg"
        print(f"\nProcessing Default Image: {image_path}...")

    try:
        image = _load_image(image_path)

        if args.mode == "box":
            print(f"Running detection with prompt: '{args.prompt}'...")
            output = model.generate(
                image,
                mode="box",
                prompt=args.prompt,
                threshold=args.threshold
            )
            print(f"Found {len(output.boxes)} objects")
            visualize_results(image, output, args.output)

        elif args.mode == "caption":
            print("Generating caption...")
            output = model.generate(image, mode="text", prompt="A photo of")
            print(f"\n📝 Caption: {output.text}\n")

        elif args.mode == "vqa":
            # Default detection prompt isn't a question; substitute one.
            question = args.prompt if args.prompt != "Detect objects" else "What is in this image?"
            print(f"Thinking about question: '{question}'...")
            output = model.generate(image, mode="text", prompt=question)
            print(f"\n🤔 Answer: {output.text}\n")

    except Exception as e:
        # Top-level boundary: report and exit rather than traceback.
        print(f"Error processing image: {e}")
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|