py-media-grader/croppa/tracking.py
import cv2
import numpy as np
from typing import Dict, Any


class FeatureTracker:
    """Semi-automatic feature tracking with SIFT/SURF/ORB support and full state serialization"""

    def __init__(self):
        # Feature detection parameters
        self.detector_type = 'SIFT'  # 'SIFT', 'SURF', 'ORB'
        self.max_features = 1000
        self.match_threshold = 0.7

        # Tracking state
        self.features = {}  # {frame_number: {'keypoints': [...], 'descriptors': [...], 'positions': [...]}}
        self.tracking_enabled = False
        self.auto_tracking = False

        # Initialize detectors
        self._init_detectors()
    def _init_detectors(self):
        """Initialize feature detectors based on type"""
        try:
            if self.detector_type == 'SIFT':
                self.detector = cv2.SIFT_create(nfeatures=self.max_features)
            elif self.detector_type == 'SURF':
                # SURF requires opencv-contrib-python; fall back to SIFT
                print("Warning: SURF requires opencv-contrib-python package. Using SIFT instead.")
                self.detector = cv2.SIFT_create(nfeatures=self.max_features)
                self.detector_type = 'SIFT'
            elif self.detector_type == 'ORB':
                self.detector = cv2.ORB_create(nfeatures=self.max_features)
            else:
                raise ValueError(f"Unknown detector type: {self.detector_type}")
        except Exception as e:
            print(f"Warning: Could not initialize {self.detector_type} detector: {e}")
            # Fall back to ORB
            self.detector_type = 'ORB'
            self.detector = cv2.ORB_create(nfeatures=self.max_features)
    def set_detector_type(self, detector_type: str):
        """Change detector type and reinitialize"""
        if detector_type in ['SIFT', 'SURF', 'ORB']:
            self.detector_type = detector_type
            self._init_detectors()
            print(f"Switched to {detector_type} detector")
        else:
            print(f"Invalid detector type: {detector_type}")
    def extract_features(self, frame: np.ndarray, frame_number: int, coord_mapper=None) -> bool:
        """Extract features from a frame and store them"""
        try:
            # Convert to grayscale if needed
            if len(frame.shape) == 3:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            else:
                gray = frame

            # Extract keypoints and descriptors
            keypoints, descriptors = self.detector.detectAndCompute(gray, None)
            # detectAndCompute returns an empty keypoint sequence (not None) when nothing is found
            if not keypoints or descriptors is None:
                return False

            # Map coordinates back to original frame space if mapper provided
            if coord_mapper:
                mapped_positions = []
                for kp in keypoints:
                    orig_x, orig_y = coord_mapper(kp.pt[0], kp.pt[1])
                    mapped_positions.append((int(orig_x), int(orig_y)))
            else:
                mapped_positions = [(int(kp.pt[0]), int(kp.pt[1])) for kp in keypoints]

            # Store features
            self.features[frame_number] = {
                'keypoints': keypoints,
                'descriptors': descriptors,
                'positions': mapped_positions
            }
            print(f"Extracted {len(keypoints)} features from frame {frame_number}")
            return True
        except Exception as e:
            print(f"Error extracting features from frame {frame_number}: {e}")
            return False
    def extract_features_from_region(self, frame: np.ndarray, frame_number: int, coord_mapper=None) -> bool:
        """Extract features from a frame and ADD them to existing features"""
        try:
            # Convert to grayscale if needed
            if len(frame.shape) == 3:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            else:
                gray = frame

            # Extract keypoints and descriptors
            keypoints, descriptors = self.detector.detectAndCompute(gray, None)
            if not keypoints or descriptors is None:
                return False

            # Map coordinates back to original frame space if mapper provided
            if coord_mapper:
                mapped_positions = []
                for kp in keypoints:
                    orig_x, orig_y = coord_mapper(kp.pt[0], kp.pt[1])
                    mapped_positions.append((int(orig_x), int(orig_y)))
            else:
                mapped_positions = [(int(kp.pt[0]), int(kp.pt[1])) for kp in keypoints]

            # Add to existing features or create new entry
            if frame_number in self.features:
                existing_features = self.features[frame_number]
                # Check if descriptor dimensions match (e.g. ORB and SIFT descriptors differ in width)
                if existing_features['descriptors'].shape[1] != descriptors.shape[1]:
                    print(f"Warning: Descriptor dimension mismatch ({existing_features['descriptors'].shape[1]} vs {descriptors.shape[1]}). Cannot concatenate. Replacing features.")
                    # Replace instead of concatenate when dimensions don't match
                    existing_features['keypoints'] = keypoints
                    existing_features['descriptors'] = descriptors
                    existing_features['positions'] = mapped_positions
                else:
                    # Append to existing features; keypoints are cv2.KeyPoint objects,
                    # so join them as a plain list rather than with np.concatenate
                    existing_features['keypoints'] = list(existing_features['keypoints']) + list(keypoints)
                    existing_features['descriptors'] = np.concatenate([existing_features['descriptors'], descriptors])
                    existing_features['positions'].extend(mapped_positions)
                    print(f"Added {len(keypoints)} features to frame {frame_number} (total: {len(existing_features['positions'])})")
            else:
                # Create new features entry
                self.features[frame_number] = {
                    'keypoints': keypoints,
                    'descriptors': descriptors,
                    'positions': mapped_positions
                }
                print(f"Extracted {len(keypoints)} features from frame {frame_number}")
            return True
        except Exception as e:
            print(f"Error extracting features from frame {frame_number}: {e}")
            return False
    def track_features_optical_flow(self, prev_frame, curr_frame, prev_points):
        """Track features using Lucas-Kanade optical flow"""
        try:
            # Convert to grayscale if needed
            if len(prev_frame.shape) == 3:
                prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
            else:
                prev_gray = prev_frame
            if len(curr_frame.shape) == 3:
                curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
            else:
                curr_gray = curr_frame

            # calcOpticalFlowPyrLK expects float32 points of shape (N, 1, 2)
            prev_points = np.asarray(prev_points, dtype=np.float32).reshape(-1, 1, 2)

            # Parameters for Lucas-Kanade optical flow
            lk_params = dict(winSize=(15, 15),
                             maxLevel=2,
                             criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))

            # Calculate optical flow
            new_points, status, _ = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, prev_points, None, **lk_params)

            # Filter out points that could not be tracked
            good_new = new_points[status == 1]
            good_old = prev_points[status == 1]
            return good_new, good_old, status
        except Exception as e:
            print(f"Error in optical flow tracking: {e}")
            return None, None, None
    def clear_features(self):
        """Clear all stored features"""
        self.features.clear()
        print("All features cleared")

    def get_feature_count(self, frame_number: int) -> int:
        """Get number of features for a frame"""
        if frame_number in self.features:
            return len(self.features[frame_number]['positions'])
        return 0
    def serialize_features(self) -> Dict[str, Any]:
        """Serialize features for state saving"""
        serialized = {}
        for frame_num, frame_data in self.features.items():
            frame_key = str(frame_num)
            serialized[frame_key] = {
                'positions': frame_data['positions'],
                'keypoints': None,  # Keypoints are not serialized (too large)
                'descriptors': None  # Descriptors are not serialized (too large)
            }
        return serialized

    def deserialize_features(self, serialized_data: Dict[str, Any]):
        """Deserialize features from state loading"""
        self.features.clear()
        for frame_key, frame_data in serialized_data.items():
            frame_num = int(frame_key)
            self.features[frame_num] = {
                'positions': frame_data['positions'],
                'keypoints': None,
                'descriptors': None
            }
        print(f"Deserialized features for {len(self.features)} frames")
    def get_state_dict(self) -> Dict[str, Any]:
        """Get complete state for serialization"""
        return {
            'detector_type': self.detector_type,
            'max_features': self.max_features,
            'match_threshold': self.match_threshold,
            'tracking_enabled': self.tracking_enabled,
            'auto_tracking': self.auto_tracking,
            'features': self.serialize_features()
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load complete state from serialization"""
        # Restore max_features before reinitializing the detector so the restored
        # value is actually applied to the new detector instance
        if 'max_features' in state_dict:
            self.max_features = state_dict['max_features']
        if 'detector_type' in state_dict:
            self.detector_type = state_dict['detector_type']
            self._init_detectors()
        if 'match_threshold' in state_dict:
            self.match_threshold = state_dict['match_threshold']
        if 'tracking_enabled' in state_dict:
            self.tracking_enabled = state_dict['tracking_enabled']
        if 'auto_tracking' in state_dict:
            self.auto_tracking = state_dict['auto_tracking']
        if 'features' in state_dict:
            self.deserialize_features(state_dict['features'])
        print("Feature tracker state loaded")