Maybe implement a feature tracker
croppa/main.py | 261
@@ -4,7 +4,7 @@ import cv2
import argparse
import numpy as np
from pathlib import Path
-from typing import List
+from typing import List, Optional, Tuple, Dict, Any
import time
import re
import threading
@@ -25,6 +25,173 @@ def load_image_utf8(image_path):
    except Exception as e:
        raise ValueError(f"Could not load image file: {image_path} - {e}")


class FeatureTracker:
    """Semi-automatic feature tracking with SIFT/SURF/ORB support and full state serialization"""

    def __init__(self):
        # Feature detection parameters
        self.detector_type = 'SIFT'  # 'SIFT', 'SURF', 'ORB'
        self.max_features = 1000
        self.match_threshold = 0.7

        # Tracking state
        self.features = {}  # {frame_number: {'keypoints': [...], 'descriptors': [...], 'positions': [...]}}
        self.tracking_enabled = False
        self.auto_tracking = False

        # Initialize detectors
        self._init_detectors()

    def _init_detectors(self):
        """Initialize feature detectors based on type"""
        try:
            if self.detector_type == 'SIFT':
                self.detector = cv2.SIFT_create(nfeatures=self.max_features)
                self.matcher = cv2.BFMatcher()
            elif self.detector_type == 'SURF':
                self.detector = cv2.xfeatures2d.SURF_create(hessianThreshold=400)
                self.matcher = cv2.BFMatcher()
            elif self.detector_type == 'ORB':
                self.detector = cv2.ORB_create(nfeatures=self.max_features)
                self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
            else:
                raise ValueError(f"Unknown detector type: {self.detector_type}")
        except Exception as e:
            print(f"Warning: Could not initialize {self.detector_type} detector: {e}")
            # Fall back to ORB, which ships with stock OpenCV builds
            self.detector_type = 'ORB'
            self.detector = cv2.ORB_create(nfeatures=self.max_features)
            self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
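Review note: SURF is nonfree code that only exists in opencv-contrib builds compiled with OPENCV_ENABLE_NONFREE, so the ORB fallback above will trigger on most installs. A quick probe of what a given build exposes, as a sketch assuming only cv2 is installed:

# Not part of the diff: check which detector factories this build exposes.
# Note SURF_create can exist yet still raise at call time if nonfree is disabled.
import cv2

for name in ('SIFT_create', 'ORB_create'):
    print(name, 'available' if hasattr(cv2, name) else 'missing')
has_surf = hasattr(cv2, 'xfeatures2d') and hasattr(cv2.xfeatures2d, 'SURF_create')
print('SURF_create', 'available' if has_surf else 'missing')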
    def set_detector_type(self, detector_type: str):
        """Change detector type and reinitialize"""
        if detector_type in ['SIFT', 'SURF', 'ORB']:
            self.detector_type = detector_type
            self._init_detectors()
            print(f"Switched to {detector_type} detector")
        else:
            print(f"Invalid detector type: {detector_type}")
    def extract_features(self, frame: np.ndarray, frame_number: int) -> bool:
        """Extract features from a frame and store them"""
        try:
            # Convert to grayscale if needed
            if len(frame.shape) == 3:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            else:
                gray = frame

            # Extract keypoints and descriptors
            keypoints, descriptors = self.detector.detectAndCompute(gray, None)

            # detectAndCompute returns an empty keypoint sequence (not None)
            # when nothing is found; descriptors are None in that case
            if not keypoints or descriptors is None:
                return False

            # Store features
            self.features[frame_number] = {
                'keypoints': keypoints,
                'descriptors': descriptors,
                'positions': [(int(kp.pt[0]), int(kp.pt[1])) for kp in keypoints]
            }

            print(f"Extracted {len(keypoints)} features from frame {frame_number}")
            return True

        except Exception as e:
            print(f"Error extracting features from frame {frame_number}: {e}")
            return False
    def get_tracking_position(self, frame_number: int) -> Optional[Tuple[float, float]]:
        """Get the average tracking position for a frame"""
        if frame_number not in self.features:
            return None

        positions = self.features[frame_number]['positions']
        if not positions:
            return None

        avg_x = sum(pos[0] for pos in positions) / len(positions)
        avg_y = sum(pos[1] for pos in positions) / len(positions)
        return (avg_x, avg_y)
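Review note: __init__ wires up self.matcher and self.match_threshold, but nothing in this commit uses them yet (hence the "Maybe implement" title). For illustration only, frame-to-frame matching could apply Lowe's ratio test with that threshold; a hypothetical sketch, not committed code:

# Hypothetical helper: match stored descriptors between two frames
# using the tracker's matcher and ratio threshold.
def match_frames(tracker, frame_a, frame_b):
    da = tracker.features[frame_a]['descriptors']
    db = tracker.features[frame_b]['descriptors']
    if da is None or db is None:  # descriptors are dropped on deserialization
        return []
    if tracker.detector_type == 'ORB':
        # The ORB matcher was built with crossCheck=True, so use plain match()
        return sorted(tracker.matcher.match(da, db), key=lambda m: m.distance)
    # Float descriptors (SIFT/SURF): Lowe's ratio test over 2-NN matches
    good = []
    for pair in tracker.matcher.knnMatch(da, db, k=2):
        if len(pair) == 2 and pair[0].distance < tracker.match_threshold * pair[1].distance:
            good.append(pair[0])
    return good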
    def clear_features(self):
        """Clear all stored features"""
        self.features.clear()
        print("All features cleared")

    def get_feature_count(self, frame_number: int) -> int:
        """Get number of features for a frame"""
        if frame_number in self.features:
            return len(self.features[frame_number]['positions'])
        return 0
    def serialize_features(self) -> Dict[str, Any]:
        """Serialize features for state saving"""
        serialized = {}

        for frame_num, frame_data in self.features.items():
            frame_key = str(frame_num)
            serialized[frame_key] = {
                'positions': frame_data['positions'],
                'keypoints': None,   # keypoints are not serialized (too large)
                'descriptors': None  # descriptors are not serialized (too large)
            }

        return serialized
    def deserialize_features(self, serialized_data: Dict[str, Any]):
        """Deserialize features from state loading"""
        self.features.clear()

        for frame_key, frame_data in serialized_data.items():
            frame_num = int(frame_key)
            self.features[frame_num] = {
                'positions': frame_data['positions'],
                'keypoints': None,
                'descriptors': None
            }

        print(f"Deserialized features for {len(self.features)} frames")
    def get_state_dict(self) -> Dict[str, Any]:
        """Get complete state for serialization"""
        return {
            'detector_type': self.detector_type,
            'max_features': self.max_features,
            'match_threshold': self.match_threshold,
            'tracking_enabled': self.tracking_enabled,
            'auto_tracking': self.auto_tracking,
            'features': self.serialize_features()
        }
    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load complete state from serialization"""
        # Load scalar parameters first so the detectors are rebuilt
        # with the restored max_features
        if 'max_features' in state_dict:
            self.max_features = state_dict['max_features']

        if 'match_threshold' in state_dict:
            self.match_threshold = state_dict['match_threshold']

        if 'detector_type' in state_dict:
            self.detector_type = state_dict['detector_type']
            self._init_detectors()

        if 'tracking_enabled' in state_dict:
            self.tracking_enabled = state_dict['tracking_enabled']

        if 'auto_tracking' in state_dict:
            self.auto_tracking = state_dict['auto_tracking']

        if 'features' in state_dict:
            self.deserialize_features(state_dict['features'])

        print("Feature tracker state loaded")
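Since keypoints and descriptors are dropped, the state dict is plain JSON-serializable data; a minimal round-trip sketch (the image path is hypothetical):

# Not part of the diff: serialize tracker state to JSON and restore it.
import json

tracker = FeatureTracker()
tracker.extract_features(load_image_utf8("example.png"), 0)  # hypothetical input

restored = FeatureTracker()
restored.load_state_dict(json.loads(json.dumps(tracker.get_state_dict())))
# Positions survive the round trip; keypoints/descriptors stay None
assert restored.get_feature_count(0) == tracker.get_feature_count(0)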
class Cv2BufferedCap:
    """Buffered wrapper around cv2.VideoCapture that handles frame loading, seeking, and caching correctly"""
@@ -595,6 +762,9 @@ class VideoEditor:
        self.tracking_points = {}  # {frame_number: [(x, y), ...]} in original frame coords
        self.tracking_enabled = False

        # Feature tracking system
        self.feature_tracker = FeatureTracker()

        # Project view mode
        self.project_view_mode = False
        self.project_view = None
@@ -639,7 +809,8 @@ class VideoEditor:
            'seek_multiplier': getattr(self, 'seek_multiplier', 1.0),
            'is_playing': getattr(self, 'is_playing', False),
            'tracking_enabled': self.tracking_enabled,
-           'tracking_points': {str(k): v for k, v in self.tracking_points.items()}
+           'tracking_points': {str(k): v for k, v in self.tracking_points.items()},
+           'feature_tracker': self.feature_tracker.get_state_dict()
        }

        with open(state_file, 'w') as f:
@@ -721,6 +892,11 @@ class VideoEditor:
        if 'tracking_points' in state and isinstance(state['tracking_points'], dict):
            self.tracking_points = {int(k): v for k, v in state['tracking_points'].items()}
            print(f"Loaded tracking_points: {sum(len(v) for v in self.tracking_points.values())} points")

        # Load feature tracker state
        if 'feature_tracker' in state:
            self.feature_tracker.load_state_dict(state['feature_tracker'])
            print("Loaded feature tracker state")

        # Validate cut markers against current video length
        if self.cut_start_frame is not None and self.cut_start_frame >= self.total_frames:
@@ -1054,6 +1230,18 @@ class VideoEditor:
        """Seek to specific frame"""
        self.current_frame = max(0, min(frame_number, self.total_frames - 1))
        self.load_current_frame()

        # Auto-extract features if feature tracking is enabled and auto-tracking is on
        if (not self.is_image_mode and
                self.feature_tracker.tracking_enabled and
                self.feature_tracker.auto_tracking and
                self.current_display_frame is not None):

            # Only extract if we don't already have features for this frame
            if self.current_frame not in self.feature_tracker.features:
                # Extract features from the original frame (before transformations)
                # so they stay in the original coordinate system
                self.feature_tracker.extract_features(self.current_display_frame, self.current_frame)

    def jump_to_previous_marker(self):
        """Jump to the previous tracking marker (frame with tracking points)."""
@@ -1270,6 +1458,14 @@ class VideoEditor:

    def _get_interpolated_tracking_position(self, frame_number):
        """Linear interpolation in ROTATED frame coords. Returns (rx, ry) or None."""
        # First try feature tracking if enabled
        if self.feature_tracker.tracking_enabled:
            feature_pos = self.feature_tracker.get_tracking_position(frame_number)
            if feature_pos:
                # Features are extracted from original frames, so coordinates are already correct
                return feature_pos

        # Fall back to manual tracking points
        if not self.tracking_points:
            return None
        frames = sorted(self.tracking_points.keys())
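The hunk cuts off at frames = sorted(...); for context, the usual continuation clamps outside the marked range and lerps between the two nearest marker frames. A sketch of that pattern, not the committed code:

# Not part of the diff: linear interpolation between marker frames.
import bisect

def lerp_position(tracking_points, frames, frame_number):
    i = bisect.bisect_left(frames, frame_number)
    if i == 0:
        return tracking_points[frames[0]][0]   # clamp before first marker
    if i == len(frames):
        return tracking_points[frames[-1]][0]  # clamp after last marker
    f0, f1 = frames[i - 1], frames[i]
    (x0, y0), (x1, y1) = tracking_points[f0][0], tracking_points[f1][0]
    t = (frame_number - f0) / (f1 - f0)
    return (x0 + t * (x1 - x0), y0 + t * (y1 - y0))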
@@ -1940,13 +2136,19 @@ class VideoEditor:
        motion_text = (
            f" | Motion: {self.tracking_enabled}" if self.tracking_enabled else ""
        )
        feature_text = (
            f" | Features: {self.feature_tracker.tracking_enabled}" if self.feature_tracker.tracking_enabled else ""
        )
        if self.feature_tracker.tracking_enabled and self.current_frame in self.feature_tracker.features:
            feature_count = self.feature_tracker.get_feature_count(self.current_frame)
            feature_text = f" | Features: {feature_count} pts"
        autorepeat_text = (
            f" | Loop: ON" if self.looping_between_markers else ""
        )
        if self.is_image_mode:
-           info_text = f"Image | Zoom: {self.zoom_factor:.1f}x{rotation_text}{brightness_text}{contrast_text}{motion_text}"
+           info_text = f"Image | Zoom: {self.zoom_factor:.1f}x{rotation_text}{brightness_text}{contrast_text}{motion_text}{feature_text}"
        else:
-           info_text = f"Frame: {self.current_frame}/{self.total_frames} | Speed: {self.playback_speed:.1f}x | Zoom: {self.zoom_factor:.1f}x{seek_multiplier_text}{rotation_text}{brightness_text}{contrast_text}{motion_text}{autorepeat_text} | {'Playing' if self.is_playing else 'Paused'}"
+           info_text = f"Frame: {self.current_frame}/{self.total_frames} | Speed: {self.playback_speed:.1f}x | Zoom: {self.zoom_factor:.1f}x{seek_multiplier_text}{rotation_text}{brightness_text}{contrast_text}{motion_text}{feature_text}{autorepeat_text} | {'Playing' if self.is_playing else 'Paused'}"
        cv2.putText(
            canvas,
            info_text,
@@ -2039,6 +2241,16 @@ class VideoEditor:
            cv2.circle(canvas, (sx, sy), 6, (255, 0, 0), -1)
            cv2.circle(canvas, (sx, sy), 6, (255, 255, 255), 1)

        # Draw feature tracking points (green circles)
        if (not self.is_image_mode and
                self.feature_tracker.tracking_enabled and
                self.current_frame in self.feature_tracker.features):
            feature_positions = self.feature_tracker.features[self.current_frame]['positions']
            for (fx, fy) in feature_positions:
                sx, sy = self._map_rotated_to_screen(fx, fy)
                cv2.circle(canvas, (sx, sy), 4, (0, 255, 0), -1)  # green fill for features
                cv2.circle(canvas, (sx, sy), 4, (255, 255, 255), 1)  # white outline

        # Draw previous and next tracking points with motion path visualization
        if not self.is_image_mode and self.tracking_points:
            prev_result = self._get_previous_tracking_point()
@@ -3273,6 +3485,47 @@ class VideoEditor:
            self.tracking_points = {}
            self.show_feedback_message("Tracking points cleared")
            self.save_state()
        elif key == ord("f"):
            # Toggle feature tracking on/off
            self.feature_tracker.tracking_enabled = not self.feature_tracker.tracking_enabled
            self.show_feedback_message(f"Feature tracking {'ON' if self.feature_tracker.tracking_enabled else 'OFF'}")
            self.save_state()
        elif key == ord("F"):
            # Extract features from current frame
            if not self.is_image_mode and self.current_display_frame is not None:
                # Extract features from the original frame (before transformations)
                # so they stay in the original coordinate system
                success = self.feature_tracker.extract_features(self.current_display_frame, self.current_frame)
                if success:
                    count = self.feature_tracker.get_feature_count(self.current_frame)
                    self.show_feedback_message(f"Extracted {count} features from frame {self.current_frame}")
                else:
                    self.show_feedback_message("Failed to extract features")
                self.save_state()
            else:
                self.show_feedback_message("No frame data available")
        elif key == ord("g"):
            # Toggle auto tracking
            self.feature_tracker.auto_tracking = not self.feature_tracker.auto_tracking
            self.show_feedback_message(f"Auto tracking {'ON' if self.feature_tracker.auto_tracking else 'OFF'}")
            self.save_state()
        elif key == ord("G"):
            # Clear all feature tracking data
            self.feature_tracker.clear_features()
            self.show_feedback_message("Feature tracking data cleared")
            self.save_state()
        elif key == ord("h"):
            # Switch detector type (SIFT -> SURF -> ORB -> SIFT)
            current_type = self.feature_tracker.detector_type
            if current_type == 'SIFT':
                new_type = 'SURF'
            elif current_type == 'SURF':
                new_type = 'ORB'
            else:
                new_type = 'SIFT'
            self.feature_tracker.set_detector_type(new_type)
            self.show_feedback_message(f"Detector switched to {new_type}")
            self.save_state()
        elif key == ord("t"):
            # Marker looping only for videos
            if not self.is_image_mode:
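Taken together, the new bindings above: f toggles feature tracking, F extracts features from the current frame, g toggles auto-extraction on seek, G clears all stored feature data, and h cycles the detector (SIFT -> SURF -> ORB).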