import cv2
import numpy as np
from typing import Dict, Any


class FeatureTracker:
    """Semi-automatic feature tracking with SIFT/SURF/ORB support and full state serialization"""

    def __init__(self):
        # Feature detection parameters
        self.detector_type = 'SIFT'  # 'SIFT', 'SURF', 'ORB'
        self.max_features = 1000
        self.match_threshold = 0.7

        # Tracking state
        self.features = {}  # {frame_number: {'keypoints': [...], 'descriptors': [...], 'positions': [...]}}
        self.tracking_enabled = False
        self.auto_tracking = False

        # Initialize detectors
        self._init_detectors()

    def _init_detectors(self):
        """Initialize feature detectors based on type"""
        try:
            if self.detector_type == 'SIFT':
                self.detector = cv2.SIFT_create(nfeatures=self.max_features)
            elif self.detector_type == 'SURF':
                # SURF requires opencv-contrib-python, fallback to SIFT
                print("Warning: SURF requires opencv-contrib-python package. Using SIFT instead.")
                self.detector = cv2.SIFT_create(nfeatures=self.max_features)
                self.detector_type = 'SIFT'
            elif self.detector_type == 'ORB':
                self.detector = cv2.ORB_create(nfeatures=self.max_features)
            else:
                raise ValueError(f"Unknown detector type: {self.detector_type}")
        except Exception as e:
            print(f"Warning: Could not initialize {self.detector_type} detector: {e}")
            # Fallback to ORB
            self.detector_type = 'ORB'
            self.detector = cv2.ORB_create(nfeatures=self.max_features)

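    # Note: SIFT produces 128-dimensional float descriptors while ORB produces
    # 32-byte binary descriptors, so features detected with different detector
    # types cannot be concatenated; extract_features_from_region below checks
    # for this and replaces the stored features instead.
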
    def set_detector_type(self, detector_type: str):
        """Change detector type and reinitialize"""
        if detector_type in ['SIFT', 'SURF', 'ORB']:
            self.detector_type = detector_type
            self._init_detectors()
            print(f"Switched to {detector_type} detector")
        else:
            print(f"Invalid detector type: {detector_type}")

    def extract_features(self, frame: np.ndarray, frame_number: int, coord_mapper=None) -> bool:
        """Extract features from a frame and store them"""
        try:
            # Convert to grayscale if needed
            if len(frame.shape) == 3:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            else:
                gray = frame

            # Extract keypoints and descriptors
            keypoints, descriptors = self.detector.detectAndCompute(gray, None)

            # detectAndCompute returns an empty keypoint tuple and None descriptors
            # when nothing is found
            if descriptors is None or len(keypoints) == 0:
                return False

            # Map coordinates back to original frame space if mapper provided
            if coord_mapper:
                mapped_positions = []
                for kp in keypoints:
                    orig_x, orig_y = coord_mapper(kp.pt[0], kp.pt[1])
                    mapped_positions.append((int(orig_x), int(orig_y)))
            else:
                mapped_positions = [(int(kp.pt[0]), int(kp.pt[1])) for kp in keypoints]

            # Store features
            self.features[frame_number] = {
                'keypoints': keypoints,
                'descriptors': descriptors,
                'positions': mapped_positions
            }

            print(f"Extracted {len(keypoints)} features from frame {frame_number}")
            return True

        except Exception as e:
            print(f"Error extracting features from frame {frame_number}: {e}")
            return False

    def extract_features_from_region(self, frame: np.ndarray, frame_number: int, coord_mapper=None) -> bool:
        """Extract features from a frame and ADD them to existing features"""
        try:
            # Convert to grayscale if needed
            if len(frame.shape) == 3:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            else:
                gray = frame

            # Extract keypoints and descriptors
            keypoints, descriptors = self.detector.detectAndCompute(gray, None)

            if descriptors is None or len(keypoints) == 0:
                return False

            # Map coordinates back to original frame space if mapper provided
            if coord_mapper:
                mapped_positions = []
                for kp in keypoints:
                    orig_x, orig_y = coord_mapper(kp.pt[0], kp.pt[1])
                    mapped_positions.append((int(orig_x), int(orig_y)))
            else:
                mapped_positions = [(int(kp.pt[0]), int(kp.pt[1])) for kp in keypoints]

            # Add to existing features or create new entry
            if frame_number in self.features:
                existing_features = self.features[frame_number]
                # Check that existing descriptors are present and have matching dimensions
                # (e.g. SIFT and ORB descriptors cannot be concatenated)
                if (existing_features['descriptors'] is None
                        or existing_features['descriptors'].shape[1] != descriptors.shape[1]):
                    print("Warning: Descriptor dimension mismatch or missing descriptors. "
                          "Cannot concatenate. Replacing features.")
                    # Replace instead of concatenate when dimensions don't match
                    existing_features['keypoints'] = keypoints
                    existing_features['descriptors'] = descriptors
                    existing_features['positions'] = mapped_positions
                else:
                    # Append to existing features; keypoints are cv2.KeyPoint objects,
                    # so join them as Python lists rather than numpy arrays
                    existing_features['keypoints'] = list(existing_features['keypoints']) + list(keypoints)
                    existing_features['descriptors'] = np.concatenate([existing_features['descriptors'], descriptors])
                    existing_features['positions'].extend(mapped_positions)
                    print(f"Added {len(keypoints)} features to frame {frame_number} (total: {len(existing_features['positions'])})")
            else:
                # Create new features entry
                self.features[frame_number] = {
                    'keypoints': keypoints,
                    'descriptors': descriptors,
                    'positions': mapped_positions
                }
                print(f"Extracted {len(keypoints)} features from frame {frame_number}")

            return True

        except Exception as e:
            print(f"Error extracting features from frame {frame_number}: {e}")
            return False

    def track_features_optical_flow(self, prev_frame, curr_frame, prev_points):
        """Track features using Lucas-Kanade optical flow.

        prev_points must be a float32 array of shape (N, 1, 2), as expected by
        cv2.calcOpticalFlowPyrLK.
        """
        try:
            # Convert to grayscale if needed
            if len(prev_frame.shape) == 3:
                prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
            else:
                prev_gray = prev_frame

            if len(curr_frame.shape) == 3:
                curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
            else:
                curr_gray = curr_frame

            # Parameters for Lucas-Kanade optical flow
            lk_params = dict(winSize=(15, 15),
                             maxLevel=2,
                             criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))

            # Calculate optical flow
            new_points, status, _ = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, prev_points, None, **lk_params)

            # Filter out bad tracks
            good_new = new_points[status == 1]
            good_old = prev_points[status == 1]

            return good_new, good_old, status

        except Exception as e:
            print(f"Error in optical flow tracking: {e}")
            return None, None, None

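    # A minimal sketch (an assumption, not part of the original code) of how the
    # stored integer positions can be fed to track_features_optical_flow: the
    # points are converted to a float32 array of shape (N, 1, 2) first.
    #
    #   pts = np.array(tracker.features[0]['positions'],
    #                  dtype=np.float32).reshape(-1, 1, 2)
    #   good_new, good_old, status = tracker.track_features_optical_flow(
    #       prev_frame, curr_frame, pts)
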
    def clear_features(self):
        """Clear all stored features"""
        self.features.clear()
        print("All features cleared")

    def get_feature_count(self, frame_number: int) -> int:
        """Get number of features for a frame"""
        if frame_number in self.features:
            return len(self.features[frame_number]['positions'])
        return 0

    def serialize_features(self) -> Dict[str, Any]:
        """Serialize features for state saving"""
        serialized = {}

        for frame_num, frame_data in self.features.items():
            frame_key = str(frame_num)
            serialized[frame_key] = {
                'positions': frame_data['positions'],
                'keypoints': None,  # Keypoints are not serialized (too large)
                'descriptors': None  # Descriptors are not serialized (too large)
            }

        return serialized

    def deserialize_features(self, serialized_data: Dict[str, Any]):
        """Restore features from serialized state data"""
        self.features.clear()

        for frame_key, frame_data in serialized_data.items():
            frame_num = int(frame_key)
            self.features[frame_num] = {
                'positions': frame_data['positions'],
                'keypoints': None,
                'descriptors': None
            }

        print(f"Deserialized features for {len(self.features)} frames")

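    # Since only the integer feature positions are serialized, the dictionary
    # returned by get_state_dict() below is JSON-friendly. A rough sketch
    # (an assumption, not part of the original module):
    #
    #   import json
    #   with open("tracker_state.json", "w") as f:
    #       json.dump(tracker.get_state_dict(), f)
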
    def get_state_dict(self) -> Dict[str, Any]:
        """Get complete state for serialization"""
        return {
            'detector_type': self.detector_type,
            'max_features': self.max_features,
            'match_threshold': self.match_threshold,
            'tracking_enabled': self.tracking_enabled,
            'auto_tracking': self.auto_tracking,
            'features': self.serialize_features()
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load complete state from serialization"""
        # Restore numeric parameters first so the detector is re-created
        # with the saved max_features value
        if 'max_features' in state_dict:
            self.max_features = state_dict['max_features']

        if 'match_threshold' in state_dict:
            self.match_threshold = state_dict['match_threshold']

        if 'detector_type' in state_dict:
            self.detector_type = state_dict['detector_type']
            self._init_detectors()

        if 'tracking_enabled' in state_dict:
            self.tracking_enabled = state_dict['tracking_enabled']

        if 'auto_tracking' in state_dict:
            self.auto_tracking = state_dict['auto_tracking']

        if 'features' in state_dict:
            self.deserialize_features(state_dict['features'])

        print("Feature tracker state loaded")
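

# A minimal usage sketch (not part of the original module): the video path is a
# placeholder, and the flow shown here is only an assumption about how the
# tracker might be driven.
if __name__ == "__main__":
    tracker = FeatureTracker()
    cap = cv2.VideoCapture("example.mp4")  # placeholder input path
    ok, frame = cap.read()
    cap.release()
    if ok and tracker.extract_features(frame, frame_number=0):
        print(f"Frame 0 feature count: {tracker.get_feature_count(0)}")
        # Round-trip the tracker state through its serialized form
        state = tracker.get_state_dict()
        restored = FeatureTracker()
        restored.load_state_dict(state)
        print(f"Restored feature count: {restored.get_feature_count(0)}")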