Maybe implement a feature tracker
This commit is contained in:
		
							
								
								
									
										261
									
								
								croppa/main.py
									
									
									
									
									
								
							
							
						
						
									
										261
									
								
								croppa/main.py
									
									
									
									
									
								
							| @@ -4,7 +4,7 @@ import cv2 | ||||
| import argparse | ||||
| import numpy as np | ||||
| from pathlib import Path | ||||
| from typing import List | ||||
| from typing import List, Optional, Tuple, Dict, Any | ||||
| import time | ||||
| import re | ||||
| import threading | ||||
| @@ -25,6 +25,173 @@ def load_image_utf8(image_path): | ||||
|     except Exception as e: | ||||
|         raise ValueError(f"Could not load image file: {image_path} - {e}") | ||||
|  | ||||
|  | ||||
| class FeatureTracker: | ||||
|     """Semi-automatic feature tracking with SIFT/SURF/ORB support and full state serialization""" | ||||
|      | ||||
|     def __init__(self): | ||||
|         # Feature detection parameters | ||||
|         self.detector_type = 'SIFT'  # 'SIFT', 'SURF', 'ORB' | ||||
|         self.max_features = 1000 | ||||
|         self.match_threshold = 0.7 | ||||
|          | ||||
|         # Tracking state | ||||
|         self.features = {}  # {frame_number: {'keypoints': [...], 'descriptors': [...], 'positions': [...]}} | ||||
|         self.tracking_enabled = False | ||||
|         self.auto_tracking = False | ||||
|          | ||||
|         # Initialize detectors | ||||
|         self._init_detectors() | ||||
|          | ||||
|     def _init_detectors(self): | ||||
|         """Initialize feature detectors based on type""" | ||||
|         try: | ||||
|             if self.detector_type == 'SIFT': | ||||
|                 self.detector = cv2.SIFT_create(nfeatures=self.max_features) | ||||
|                 self.matcher = cv2.BFMatcher() | ||||
|             elif self.detector_type == 'SURF': | ||||
|                 self.detector = cv2.xfeatures2d.SURF_create(hessianThreshold=400) | ||||
|                 self.matcher = cv2.BFMatcher() | ||||
|             elif self.detector_type == 'ORB': | ||||
|                 self.detector = cv2.ORB_create(nfeatures=self.max_features) | ||||
|                 self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) | ||||
|             else: | ||||
|                 raise ValueError(f"Unknown detector type: {self.detector_type}") | ||||
|         except Exception as e: | ||||
|             print(f"Warning: Could not initialize {self.detector_type} detector: {e}") | ||||
|             # Fallback to ORB | ||||
|             self.detector_type = 'ORB' | ||||
|             self.detector = cv2.ORB_create(nfeatures=self.max_features) | ||||
|             self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) | ||||
|      | ||||
|     def set_detector_type(self, detector_type: str): | ||||
|         """Change detector type and reinitialize""" | ||||
|         if detector_type in ['SIFT', 'SURF', 'ORB']: | ||||
|             self.detector_type = detector_type | ||||
|             self._init_detectors() | ||||
|             print(f"Switched to {detector_type} detector") | ||||
|         else: | ||||
|             print(f"Invalid detector type: {detector_type}") | ||||
|      | ||||
|     def extract_features(self, frame: np.ndarray, frame_number: int) -> bool: | ||||
|         """Extract features from a frame and store them""" | ||||
|         try: | ||||
|             # Convert to grayscale if needed | ||||
|             if len(frame.shape) == 3: | ||||
|                 gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) | ||||
|             else: | ||||
|                 gray = frame | ||||
|              | ||||
|             # Extract keypoints and descriptors | ||||
|             keypoints, descriptors = self.detector.detectAndCompute(gray, None) | ||||
|              | ||||
|             if keypoints is None or descriptors is None: | ||||
|                 return False | ||||
|              | ||||
|             # Store features | ||||
|             self.features[frame_number] = { | ||||
|                 'keypoints': keypoints, | ||||
|                 'descriptors': descriptors, | ||||
|                 'positions': [(int(kp.pt[0]), int(kp.pt[1])) for kp in keypoints] | ||||
|             } | ||||
|              | ||||
|             print(f"Extracted {len(keypoints)} features from frame {frame_number}") | ||||
|             return True | ||||
|              | ||||
|         except Exception as e: | ||||
|             print(f"Error extracting features from frame {frame_number}: {e}") | ||||
|             return False | ||||
|      | ||||
|      | ||||
|      | ||||
|     def get_tracking_position(self, frame_number: int) -> Optional[Tuple[float, float]]: | ||||
|         """Get the average tracking position for a frame""" | ||||
|         if frame_number not in self.features or not self.features[frame_number]['positions']: | ||||
|             return None | ||||
|          | ||||
|         positions = self.features[frame_number]['positions'] | ||||
|         if not positions: | ||||
|             return None | ||||
|          | ||||
|         avg_x = sum(pos[0] for pos in positions) / len(positions) | ||||
|         avg_y = sum(pos[1] for pos in positions) / len(positions) | ||||
|         return (avg_x, avg_y) | ||||
|      | ||||
|      | ||||
|     def clear_features(self): | ||||
|         """Clear all stored features""" | ||||
|         self.features.clear() | ||||
|         print("All features cleared") | ||||
|      | ||||
|     def get_feature_count(self, frame_number: int) -> int: | ||||
|         """Get number of features for a frame""" | ||||
|         if frame_number in self.features: | ||||
|             return len(self.features[frame_number]['positions']) | ||||
|         return 0 | ||||
|      | ||||
|     def serialize_features(self) -> Dict[str, Any]: | ||||
|         """Serialize features for state saving""" | ||||
|         serialized = {} | ||||
|          | ||||
|         for frame_num, frame_data in self.features.items(): | ||||
|             frame_key = str(frame_num) | ||||
|             serialized[frame_key] = { | ||||
|                 'positions': frame_data['positions'], | ||||
|                 'keypoints': None,  # Keypoints are not serialized (too large) | ||||
|                 'descriptors': None  # Descriptors are not serialized (too large) | ||||
|             } | ||||
|          | ||||
|         return serialized | ||||
|      | ||||
|     def deserialize_features(self, serialized_data: Dict[str, Any]): | ||||
|         """Deserialize features from state loading""" | ||||
|         self.features.clear() | ||||
|          | ||||
|         for frame_key, frame_data in serialized_data.items(): | ||||
|             frame_num = int(frame_key) | ||||
|             self.features[frame_num] = { | ||||
|                 'positions': frame_data['positions'], | ||||
|                 'keypoints': None, | ||||
|                 'descriptors': None | ||||
|             } | ||||
|          | ||||
|         print(f"Deserialized features for {len(self.features)} frames") | ||||
|      | ||||
|     def get_state_dict(self) -> Dict[str, Any]: | ||||
|         """Get complete state for serialization""" | ||||
|         return { | ||||
|             'detector_type': self.detector_type, | ||||
|             'max_features': self.max_features, | ||||
|             'match_threshold': self.match_threshold, | ||||
|             'tracking_enabled': self.tracking_enabled, | ||||
|             'auto_tracking': self.auto_tracking, | ||||
|             'features': self.serialize_features() | ||||
|         } | ||||
|      | ||||
|     def load_state_dict(self, state_dict: Dict[str, Any]): | ||||
|         """Load complete state from serialization""" | ||||
|         if 'detector_type' in state_dict: | ||||
|             self.detector_type = state_dict['detector_type'] | ||||
|             self._init_detectors() | ||||
|          | ||||
|         if 'max_features' in state_dict: | ||||
|             self.max_features = state_dict['max_features'] | ||||
|          | ||||
|         if 'match_threshold' in state_dict: | ||||
|             self.match_threshold = state_dict['match_threshold'] | ||||
|          | ||||
|         if 'tracking_enabled' in state_dict: | ||||
|             self.tracking_enabled = state_dict['tracking_enabled'] | ||||
|          | ||||
|         if 'auto_tracking' in state_dict: | ||||
|             self.auto_tracking = state_dict['auto_tracking'] | ||||
|          | ||||
|         if 'features' in state_dict: | ||||
|             self.deserialize_features(state_dict['features']) | ||||
|          | ||||
|         print("Feature tracker state loaded") | ||||
|  | ||||
|  | ||||
| class Cv2BufferedCap: | ||||
|     """Buffered wrapper around cv2.VideoCapture that handles frame loading, seeking, and caching correctly""" | ||||
|      | ||||
| @@ -595,6 +762,9 @@ class VideoEditor: | ||||
|         self.tracking_points = {}  # {frame_number: [(x, y), ...]} in original frame coords | ||||
|         self.tracking_enabled = False | ||||
|  | ||||
|         # Feature tracking system | ||||
|         self.feature_tracker = FeatureTracker() | ||||
|  | ||||
|         # Project view mode | ||||
|         self.project_view_mode = False | ||||
|         self.project_view = None | ||||
| @@ -639,7 +809,8 @@ class VideoEditor: | ||||
|                 'seek_multiplier': getattr(self, 'seek_multiplier', 1.0), | ||||
|                 'is_playing': getattr(self, 'is_playing', False), | ||||
|                 'tracking_enabled': self.tracking_enabled, | ||||
|                 'tracking_points': {str(k): v for k, v in self.tracking_points.items()} | ||||
|                 'tracking_points': {str(k): v for k, v in self.tracking_points.items()}, | ||||
|                 'feature_tracker': self.feature_tracker.get_state_dict() | ||||
|             } | ||||
|  | ||||
|             with open(state_file, 'w') as f: | ||||
| @@ -721,6 +892,11 @@ class VideoEditor: | ||||
|             if 'tracking_points' in state and isinstance(state['tracking_points'], dict): | ||||
|                 self.tracking_points = {int(k): v for k, v in state['tracking_points'].items()} | ||||
|                 print(f"Loaded tracking_points: {sum(len(v) for v in self.tracking_points.values())} points") | ||||
|              | ||||
|             # Load feature tracker state | ||||
|             if 'feature_tracker' in state: | ||||
|                 self.feature_tracker.load_state_dict(state['feature_tracker']) | ||||
|                 print(f"Loaded feature tracker state") | ||||
|  | ||||
|             # Validate cut markers against current video length | ||||
|             if self.cut_start_frame is not None and self.cut_start_frame >= self.total_frames: | ||||
| @@ -1054,6 +1230,18 @@ class VideoEditor: | ||||
|         """Seek to specific frame""" | ||||
|         self.current_frame = max(0, min(frame_number, self.total_frames - 1)) | ||||
|         self.load_current_frame() | ||||
|          | ||||
|         # Auto-extract features if feature tracking is enabled and auto-tracking is on | ||||
|         if (not self.is_image_mode and  | ||||
|             self.feature_tracker.tracking_enabled and  | ||||
|             self.feature_tracker.auto_tracking and | ||||
|             self.current_display_frame is not None): | ||||
|              | ||||
|             # Only extract if we don't already have features for this frame | ||||
|             if self.current_frame not in self.feature_tracker.features: | ||||
|                 # Extract features from the original frame (before transformations) | ||||
|                 # This ensures features are in the original coordinate system | ||||
|                 self.feature_tracker.extract_features(self.current_display_frame, self.current_frame) | ||||
|  | ||||
|     def jump_to_previous_marker(self): | ||||
|         """Jump to the previous tracking marker (frame with tracking points).""" | ||||
| @@ -1270,6 +1458,14 @@ class VideoEditor: | ||||
|  | ||||
|     def _get_interpolated_tracking_position(self, frame_number): | ||||
|         """Linear interpolation in ROTATED frame coords. Returns (rx, ry) or None.""" | ||||
|         # First try feature tracking if enabled | ||||
|         if self.feature_tracker.tracking_enabled: | ||||
|             feature_pos = self.feature_tracker.get_tracking_position(frame_number) | ||||
|             if feature_pos: | ||||
|                 # Features are extracted from original frames, so coordinates are already correct | ||||
|                 return feature_pos | ||||
|          | ||||
|         # Fall back to manual tracking points | ||||
|         if not self.tracking_points: | ||||
|             return None | ||||
|         frames = sorted(self.tracking_points.keys()) | ||||
| @@ -1940,13 +2136,19 @@ class VideoEditor: | ||||
|         motion_text = ( | ||||
|             f" | Motion: {self.tracking_enabled}" if self.tracking_enabled else "" | ||||
|         ) | ||||
|         feature_text = ( | ||||
|             f" | Features: {self.feature_tracker.tracking_enabled}" if self.feature_tracker.tracking_enabled else "" | ||||
|         ) | ||||
|         if self.feature_tracker.tracking_enabled and self.current_frame in self.feature_tracker.features: | ||||
|             feature_count = self.feature_tracker.get_feature_count(self.current_frame) | ||||
|             feature_text = f" | Features: {feature_count} pts" | ||||
|         autorepeat_text = ( | ||||
|             f" | Loop: ON" if self.looping_between_markers else "" | ||||
|         ) | ||||
|         if self.is_image_mode: | ||||
|             info_text = f"Image | Zoom: {self.zoom_factor:.1f}x{rotation_text}{brightness_text}{contrast_text}{motion_text}" | ||||
|             info_text = f"Image | Zoom: {self.zoom_factor:.1f}x{rotation_text}{brightness_text}{contrast_text}{motion_text}{feature_text}" | ||||
|         else: | ||||
|             info_text = f"Frame: {self.current_frame}/{self.total_frames} | Speed: {self.playback_speed:.1f}x | Zoom: {self.zoom_factor:.1f}x{seek_multiplier_text}{rotation_text}{brightness_text}{contrast_text}{motion_text}{autorepeat_text} | {'Playing' if self.is_playing else 'Paused'}" | ||||
|             info_text = f"Frame: {self.current_frame}/{self.total_frames} | Speed: {self.playback_speed:.1f}x | Zoom: {self.zoom_factor:.1f}x{seek_multiplier_text}{rotation_text}{brightness_text}{contrast_text}{motion_text}{feature_text}{autorepeat_text} | {'Playing' if self.is_playing else 'Paused'}" | ||||
|         cv2.putText( | ||||
|             canvas, | ||||
|             info_text, | ||||
| @@ -2039,6 +2241,16 @@ class VideoEditor: | ||||
|             cv2.circle(canvas, (sx, sy), 6, (255, 0, 0), -1) | ||||
|             cv2.circle(canvas, (sx, sy), 6, (255, 255, 255), 1) | ||||
|          | ||||
|         # Draw feature tracking points (green circles) | ||||
|         if (not self.is_image_mode and  | ||||
|             self.feature_tracker.tracking_enabled and  | ||||
|             self.current_frame in self.feature_tracker.features): | ||||
|             feature_positions = self.feature_tracker.features[self.current_frame]['positions'] | ||||
|             for (fx, fy) in feature_positions: | ||||
|                 sx, sy = self._map_rotated_to_screen(fx, fy) | ||||
|                 cv2.circle(canvas, (sx, sy), 4, (0, 255, 0), -1)  # Green circles for features | ||||
|                 cv2.circle(canvas, (sx, sy), 4, (255, 255, 255), 1) | ||||
|          | ||||
|         # Draw previous and next tracking points with motion path visualization | ||||
|         if not self.is_image_mode and self.tracking_points: | ||||
|             prev_result = self._get_previous_tracking_point() | ||||
| @@ -3273,6 +3485,47 @@ class VideoEditor: | ||||
|                 self.tracking_points = {} | ||||
|                 self.show_feedback_message("Tracking points cleared") | ||||
|                 self.save_state() | ||||
|             elif key == ord("f"): | ||||
|                 # Toggle feature tracking on/off | ||||
|                 self.feature_tracker.tracking_enabled = not self.feature_tracker.tracking_enabled | ||||
|                 self.show_feedback_message(f"Feature tracking {'ON' if self.feature_tracker.tracking_enabled else 'OFF'}") | ||||
|                 self.save_state() | ||||
|             elif key == ord("F"): | ||||
|                 # Extract features from current frame | ||||
|                 if not self.is_image_mode and self.current_display_frame is not None: | ||||
|                     # Extract features from the original frame (before transformations) | ||||
|                     # This ensures features are in the original coordinate system | ||||
|                     success = self.feature_tracker.extract_features(self.current_display_frame, self.current_frame) | ||||
|                     if success: | ||||
|                         count = self.feature_tracker.get_feature_count(self.current_frame) | ||||
|                         self.show_feedback_message(f"Extracted {count} features from frame {self.current_frame}") | ||||
|                     else: | ||||
|                         self.show_feedback_message("Failed to extract features") | ||||
|                     self.save_state() | ||||
|                 else: | ||||
|                     self.show_feedback_message("No frame data available") | ||||
|             elif key == ord("g"): | ||||
|                 # Toggle auto tracking | ||||
|                 self.feature_tracker.auto_tracking = not self.feature_tracker.auto_tracking | ||||
|                 self.show_feedback_message(f"Auto tracking {'ON' if self.feature_tracker.auto_tracking else 'OFF'}") | ||||
|                 self.save_state() | ||||
|             elif key == ord("G"): | ||||
|                 # Clear all feature tracking data | ||||
|                 self.feature_tracker.clear_features() | ||||
|                 self.show_feedback_message("Feature tracking data cleared") | ||||
|                 self.save_state() | ||||
|             elif key == ord("h"): | ||||
|                 # Switch detector type (SIFT -> SURF -> ORB -> SIFT) | ||||
|                 current_type = self.feature_tracker.detector_type | ||||
|                 if current_type == 'SIFT': | ||||
|                     new_type = 'SURF' | ||||
|                 elif current_type == 'SURF': | ||||
|                     new_type = 'ORB' | ||||
|                 else: | ||||
|                     new_type = 'SIFT' | ||||
|                 self.feature_tracker.set_detector_type(new_type) | ||||
|                 self.show_feedback_message(f"Detector switched to {new_type}") | ||||
|                 self.save_state() | ||||
|             elif key == ord("t"): | ||||
|                 # Marker looping only for videos | ||||
|                 if not self.is_image_mode: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user