| |
| """ |
| Use MediaPipe to detect poses in images and extract landmark coordinates. |
| |
| Features: |
| 1. Run MediaPipe pose detection on images in the train folder |
| 2. Use the nose as the head reference point (headPos) |
| 3. Process coordinates as: pos = (pos - headPos) * 100 and round to 2 decimals |
| 4. Save processed landmarks into JSON files named after the image files |
| |
| Usage: |
| python pose_detection.py [--input INPUT_DIR] [--output OUTPUT_DIR] [--batch-size N] |
| """ |
| import os |
| import json |
| import argparse |
| from pathlib import Path |
| import cv2 |
| import mediapipe as mp |
|
|
|
|
class PoseDetector:
    """Run MediaPipe pose detection on still images and normalise the landmarks.

    Every detected landmark is expressed relative to the nose landmark
    (index 0), scaled by 100 and rounded to 2 decimals.

    The detector holds a MediaPipe ``Pose`` object that must be released with
    :meth:`close`. The class also works as a context manager, which calls
    :meth:`close` on exit::

        with PoseDetector() as detector:
            result = detector.detect_pose("some/image.jpg")
    """

    def __init__(self, model_complexity=2, min_detection_confidence=0.5):
        """Initialize MediaPipe pose detector.

        Args:
            model_complexity: forwarded to ``mp.solutions.pose.Pose``
                (default 2, the value this script has always used).
            min_detection_confidence: forwarded to ``mp.solutions.pose.Pose``
                (default 0.5).
        """
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            static_image_mode=True,
            model_complexity=model_complexity,
            enable_segmentation=False,
            min_detection_confidence=min_detection_confidence,
        )

        # Output key for each landmark, by position in the landmark list.
        # Index 0 ('nose') is the head reference used by get_head_position().
        self.landmark_names = [
            'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
            'right_eye_inner', 'right_eye', 'right_eye_outer',
            'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
            'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
            'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
            'left_index', 'right_index', 'left_thumb', 'right_thumb',
            'left_hip', 'right_hip', 'left_knee', 'right_knee',
            'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
            'left_foot_index', 'right_foot_index',
        ]

    def __enter__(self):
        """Return the detector itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release MediaPipe resources; never suppresses an exception."""
        self.close()
        return False

    def get_head_position(self, landmarks):
        """
        Compute the head reference position (use the nose landmark).

        Args:
            landmarks: MediaPipe detected landmarks (indexable; item 0 is
                treated as the nose)

        Returns:
            tuple: (x, y, z) head coordinates
        """
        nose = landmarks[0]
        return (nose.x, nose.y, nose.z)

    def process_landmarks(self, landmarks, head_pos):
        """
        Process landmarks: pos = (pos - headPos) * 100 and round to 2 decimals.

        Landmarks are paired with ``self.landmark_names`` by position; any
        landmark beyond the last known name is ignored.

        Args:
            landmarks: MediaPipe detected landmarks
            head_pos: head coordinates (x, y, z)

        Returns:
            dict: landmark name -> {'x', 'y', 'z', 'visibility'}, where
            visibility is rounded to 3 decimals
        """
        head_x, head_y, head_z = head_pos[0], head_pos[1], head_pos[2]

        processed_landmarks = {}
        # zip() stops at the shorter sequence, so surplus landmarks are dropped.
        for name, landmark in zip(self.landmark_names, landmarks):
            processed_landmarks[name] = {
                'x': round((landmark.x - head_x) * 100, 2),
                'y': round((landmark.y - head_y) * 100, 2),
                'z': round((landmark.z - head_z) * 100, 2),
                'visibility': round(landmark.visibility, 3),
            }

        return processed_landmarks

    def detect_pose(self, image_path):
        """
        Detect pose for a single image.

        Failures are reported on stdout and signalled by returning None so
        that a batch run can carry on with the next image.

        Args:
            image_path: path to the image file (``str`` or ``pathlib.Path``)

        Returns:
            dict: processed landmarks and metadata, or None on failure
        """
        try:
            # Accept plain strings too: the metadata below needs Path
            # attributes (.name / .parent).
            image_path = Path(image_path)

            image = cv2.imread(str(image_path))
            if image is None:
                print(f"Unable to read image: {image_path}")
                return None

            # OpenCV loads BGR; convert to RGB before handing to MediaPipe.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            results = self.pose.process(image_rgb)

            if results.pose_landmarks is None:
                print(f"No pose detected: {image_path}")
                return None

            landmarks = results.pose_landmarks.landmark
            head_pos = self.get_head_position(landmarks)
            processed_landmarks = self.process_landmarks(landmarks, head_pos)

            # The label is the name of the directory containing the image.
            label = image_path.parent.name

            return {
                'image_path': str(image_path),
                'image_name': image_path.name,
                'label': label,
                'head_position': {
                    'x': round(head_pos[0], 4),
                    'y': round(head_pos[1], 4),
                    'z': round(head_pos[2], 4),
                },
                'landmarks': processed_landmarks,
                'total_landmarks': len(processed_landmarks),
            }

        except Exception as e:
            # Deliberate best-effort boundary: one bad image must not abort
            # a whole batch.
            print(f"Error processing image {image_path}: {e}")
            return None

    def close(self):
        """Close MediaPipe resources."""
        self.pose.close()
|
|
|
|
def process_all_training_data(input_dir, output_dir, batch_size=100):
    """
    Process all images in the training dataset and write JSON files.

    Only immediate sub-directories of ``input_dir`` whose name starts with
    ``label_`` are scanned (non-recursively). For every image with a
    detected pose, ``<output_dir>/<label>/<image stem>.json`` is written.
    Images that share a stem but differ in extension therefore map to the
    same JSON file, and the later one overwrites the earlier one.

    Directories and files are visited in sorted order so that repeated runs
    produce the same log output.

    Args:
        input_dir: input images directory (TrainData/train)
        output_dir: output JSON directory (PoseData)
        batch_size: print a progress line every ``batch_size`` successfully
            saved images; a value <= 0 disables progress lines
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    detector = PoseDetector()

    try:
        total_images = 0
        success_count = 0
        failed_count = 0
        label_stats = {}

        print(f"Starting processing dataset: {input_path}")
        print(f"Output directory: {output_path}")

        # First pass: collect the work list so totals are known up front.
        print("Counting images...")
        label_dirs = []
        for item in sorted(input_path.iterdir()):
            if not (item.is_dir() and item.name.startswith('label_')):
                continue
            label = item.name
            image_files = sorted(
                f for f in item.iterdir()
                if f.is_file() and f.suffix.lower() in image_extensions
            )
            if image_files:
                label_dirs.append((item, label, image_files))
                total_images += len(image_files)
                label_stats[label] = {'total': len(image_files), 'success': 0, 'failed': 0}

        print(f"Found {len(label_dirs)} label directories, total {total_images} images")
        for label, stats in label_stats.items():
            print(f" {label}: {stats['total']} images")

        print("\nStarting to process images...")

        # Second pass: detect, save, and keep per-label statistics.
        for _label_dir, label_name, image_files in label_dirs:
            print(f"\n--- Processing {label_name} ({len(image_files)} images) ---")

            output_label_dir = output_path / label_name
            output_label_dir.mkdir(parents=True, exist_ok=True)

            for i, image_file in enumerate(image_files, 1):
                json_path = output_label_dir / (image_file.stem + '.json')

                result = detector.detect_pose(image_file)

                if result is None:
                    failed_count += 1
                    label_stats[label_name]['failed'] += 1
                    # Throttled: only every 10th failure is echoed.
                    if failed_count % 10 == 0:
                        print(f" Detection failed: {image_file.name}")
                    continue

                # Keep the try body to the write itself so that bookkeeping
                # errors can never be misreported as a save failure.
                try:
                    with open(json_path, 'w', encoding='utf-8') as f:
                        json.dump(result, f, ensure_ascii=False, indent=2)
                except (OSError, TypeError, ValueError) as e:
                    print(f" Failed to save JSON {json_path}: {e}")
                    failed_count += 1
                    label_stats[label_name]['failed'] += 1
                else:
                    success_count += 1
                    label_stats[label_name]['success'] += 1

                    # Guard against batch_size <= 0 (modulo by zero).
                    if batch_size > 0 and success_count % batch_size == 0:
                        progress = (success_count / total_images) * 100 if total_images else 0
                        print(f" Progress: {success_count}/{total_images} ({progress:.1f}%) - Current: {label_name} {i}/{len(image_files)}")

            stats = label_stats[label_name]
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            print(f" {label_name} Done: Success {stats['success']}, Failed {stats['failed']}, Success rate: {success_rate:.1f}%")

        print("\n" + "=" * 60)
        print("Processing complete!")
        print(f"Total images: {total_images}")
        print(f"Successfully processed: {success_count}")
        print(f"Failed: {failed_count}")
        total_success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
        print(f"Overall success rate: {total_success_rate:.1f}%")

        print("\nPer-label statistics:")
        for label, stats in label_stats.items():
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            print(f" {label}: {stats['success']}/{stats['total']} ({success_rate:.1f}%)")

        print(f"\nJSON files saved to: {output_path.absolute()}")
        print("Directory structure:")
        # Show the real output directory name instead of a fixed "PoseData".
        print(f"{output_path.resolve().name}/")
        for label in sorted(label_stats):
            print(f"├── {label}/")
            print("│ └── *.json")

    finally:
        # Always release MediaPipe resources, even if processing aborted.
        detector.close()
|
|
|
|
def process_directory(input_dir, output_dir):
    """
    Process all images in a directory tree and write JSON files.

    The tree under ``input_dir`` is walked recursively and its directory
    layout is mirrored under ``output_dir`` (an output directory is created
    for every visited directory, whether or not it contains images). Each
    image with a detected pose produces ``<image stem>.json`` in the
    mirrored directory.

    Args:
        input_dir: input images directory
        output_dir: output JSON directory
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    detector = PoseDetector()

    try:
        total_images = 0
        success_count = 0
        failed_count = 0

        print(f"Starting to process directory: {input_path}")
        print(f"Output directory: {output_path}")

        for root, _dirs, files in os.walk(input_path):
            root_path = Path(root)

            # Mirror the input layout below the output directory.
            relative_path = root_path.relative_to(input_path)
            current_output_dir = output_path / relative_path
            current_output_dir.mkdir(parents=True, exist_ok=True)

            image_files = [f for f in files if Path(f).suffix.lower() in image_extensions]

            if image_files:
                print(f"\nProcessing directory: {root_path}")
                print(f"Found {len(image_files)} images")

            for filename in image_files:
                total_images += 1
                image_path = root_path / filename
                json_path = current_output_dir / (Path(filename).stem + '.json')

                result = detector.detect_pose(image_path)

                if result is None:
                    failed_count += 1
                    continue

                # Only the write can fail here; keep bookkeeping outside.
                try:
                    with open(json_path, 'w', encoding='utf-8') as f:
                        json.dump(result, f, ensure_ascii=False, indent=2)
                except (OSError, TypeError, ValueError) as e:
                    print(f"Failed to save JSON {json_path}: {e}")
                    failed_count += 1
                else:
                    success_count += 1
                    if success_count % 50 == 0:
                        print(f"Successfully processed {success_count} images...")

        print("\nProcessing complete!")
        print(f"Total images: {total_images}")
        print(f"Successfully processed: {success_count}")
        print(f"Failed: {failed_count}")
        # An empty tree has no images; avoid dividing by zero in that case.
        success_rate = (success_count / total_images) * 100 if total_images > 0 else 0
        print(f"Success rate: {success_rate:.1f}%")

    finally:
        # Always release MediaPipe resources, even if processing aborted.
        detector.close()
|
|
|
|
def main():
    """Parse command-line options and run the label-directory batch job.

    Prints an error and returns without processing when ``--input`` does not
    exist or is not a directory. A non-positive ``--batch-size`` is rejected
    through ``parser.error`` (argparse usage error).
    """
    parser = argparse.ArgumentParser(description="Run MediaPipe pose detection and save landmark data")
    parser.add_argument("--input", "-i", default="TrainData/train",
                        help="input images directory (default: TrainData/train)")
    parser.add_argument("--output", "-o", default="PoseData",
                        help="output JSON directory (default: PoseData)")
    parser.add_argument("--batch-size", "-b", type=int, default=100,
                        help="batch size for progress reporting (default: 100)")

    args = parser.parse_args()

    # Progress is reported every N images, so N must be at least 1.
    if args.batch_size <= 0:
        parser.error("--batch-size must be a positive integer")

    input_path = Path(args.input)
    if not input_path.exists():
        print(f"Error: input directory does not exist: {args.input}")
        return
    # The batch job iterates the directory; a regular file would crash later.
    if not input_path.is_dir():
        print(f"Error: input path is not a directory: {args.input}")
        return

    print("MediaPipe pose detection tool")
    print("=" * 60)
    print(f"Input directory: {args.input}")
    print(f"Output directory: {args.output}")
    print("Processing rule: pos = (pos - headPos) * 100, round to 2 decimals")
    print("Head reference: nose")
    print(f"Batch size: show progress every {args.batch_size} images")
    print("=" * 60)

    process_all_training_data(args.input, args.output, args.batch_size)
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, so importing the module has no side effects.
if __name__ == "__main__":
    main()
|
|