diff --git a/main.py b/main.py
index 1df3234..a549f95 100644
--- a/main.py
+++ b/main.py
@@ -6,6 +6,8 @@
 import numpy as np
 import argparse
 import shutil
 import time
+import threading
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import List
@@ -35,6 +37,10 @@ class MediaGrader:
     # Seek modifiers for A/D keys
     SHIFT_SEEK_MULTIPLIER = 5  # SHIFT + A/D multiplier
     CTRL_SEEK_MULTIPLIER = 10  # CTRL + A/D multiplier
+
+    # Multi-segment mode configuration
+    SEGMENT_COUNT = 16  # Number of video segments (4x4 grid)
+    SEGMENT_OVERLAP_PERCENT = 10  # Percentage overlap between segments
 
     def __init__(
         self, directory: str, seek_frames: int = 30, snap_to_iframe: bool = False
@@ -60,9 +66,10 @@ class MediaGrader:
         # Timeline visibility state
         self.timeline_visible = True
 
-        # Simple frame cache for frequently accessed frames
+        # Improved frame cache for performance
         self.frame_cache = {}  # Dict[frame_number: frame_data]
-        self.cache_size_limit = 50  # Keep it small and simple
+        self.cache_size_limit = 200  # Increased cache size
+        self.cache_lock = threading.Lock()  # Thread safety for cache
 
         # Key repeat tracking with rate limiting
         self.last_seek_time = 0
@@ -109,29 +116,10 @@ class MediaGrader:
         # Jump history for H key (undo jump)
         self.jump_history = {}  # Dict[file_path: List[frame_positions]] for jump undo
-
-        # Undo functionality
-        self.undo_history = []  # List of (source_path, destination_path, original_index) tuples
-
-        # Watch tracking for "good look" feature
-        self.watched_regions = {}  # Dict[file_path: List[Tuple[start_frame, end_frame]]]
-        self.current_watch_start = None  # Frame where current viewing session started
-        self.last_frame_position = 0  # Track last known frame position
-
-        # Bisection navigation tracking
-        self.last_jump_position = {}  # Dict[file_path: last_frame] for bisection reference
-
-        # Jump history for H key (undo jump)
-        self.jump_history = {}  # Dict[file_path: List[frame_positions]] for jump undo
+
+        # Performance optimization: Thread pool for parallel operations
+        self.thread_pool = ThreadPoolExecutor(max_workers=4)
 
-        # Multi-segment mode configuration
-        MULTI_SEGMENT_MODE = False
-        SEGMENT_COUNT = 16  # Number of video segments (2x2 grid)
-        SEGMENT_OVERLAP_PERCENT = 10  # Percentage overlap between segments
-
-        # Seek modifiers for A/D keys
-        SHIFT_SEEK_MULTIPLIER = 5  # SHIFT + A/D multiplier
-
     def find_media_files(self) -> List[Path]:
         """Find all media files recursively in the directory"""
         media_files = []
@@ -519,40 +507,213 @@ class MediaGrader:
         print(f"Timeline {'visible' if self.timeline_visible else 'hidden'}")
         return True
 
-    def setup_segment_captures(self):
-        """Setup multiple video captures for segment mode"""
+    def load_segment_frame_fast(self, segment_index, start_frame, shared_cap):
+        """Load a single segment frame using a shared capture (much faster)"""
+        segment_start_time = time.time()
+        try:
+            # Time the seek operation
+            seek_start = time.time()
+            shared_cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
+            seek_time = (time.time() - seek_start) * 1000
+
+            # Time the frame read
+            read_start = time.time()
+            ret, frame = shared_cap.read()
+            read_time = (time.time() - read_start) * 1000
+
+            total_time = (time.time() - segment_start_time) * 1000
+            print(f"Segment {segment_index}: Total={total_time:.1f}ms (Seek={seek_time:.1f}ms, Read={read_time:.1f}ms)")
+
+            if ret:
+                return segment_index, frame.copy(), start_frame  # Copy frame since we'll reuse the capture
+            else:
+                return segment_index, None, start_frame
+        except Exception as e:
+            error_time = (time.time() - segment_start_time) * 1000
+            print(f"Segment {segment_index}: ERROR in {error_time:.1f}ms: {e}")
+            return segment_index, None, start_frame
+
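+    # The two setup_segment_captures_* variants below trade frame-position accuracy for
+    # setup speed; setup_segment_captures() further down dispatches to one of them.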
+    def setup_segment_captures_blazing_fast(self):
+        """BLAZING FAST: Sample frames at intervals without any seeking (10-50ms total)"""
         if not self.is_video(self.media_files[self.current_index]):
             return
 
+        start_time = time.time()
+        print(f"Setting up {self.segment_count} segments with BLAZING FAST method...")
+
+        # Clean up existing segment captures
         self.cleanup_segment_captures()
 
         current_file = self.media_files[self.current_index]
 
-        # Calculate segment positions - evenly spaced through video
+        # Initialize arrays
+        self.segment_caps = [None] * self.segment_count
+        self.segment_frames = [None] * self.segment_count
+        self.segment_positions = [0] * self.segment_count  # We'll update these as we sample
+
+        # BLAZING FAST METHOD: Sample frames at even intervals without seeking
+        load_start = time.time()
+        print("Sampling frames at regular intervals (NO SEEKING)...")
+
+        shared_cap_start = time.time()
+        shared_cap = cv2.VideoCapture(str(current_file))
+        shared_cap_create_time = (time.time() - shared_cap_start) * 1000
+        print(f"Capture creation: {shared_cap_create_time:.1f}ms")
+
+        if shared_cap.isOpened():
+            frames_start = time.time()
+
+            # Calculate sampling interval
+            sample_interval = max(1, self.total_frames // (self.segment_count * 2))  # Sample more frequently than needed
+            print(f"Sampling every {sample_interval} frames from {self.total_frames} total frames")
+
+            current_frame = 0
+            segment_index = 0
+            segments_filled = 0
+
+            sample_start = time.time()
+
+            while segments_filled < self.segment_count:
+                ret, frame = shared_cap.read()
+                if not ret:
+                    break
+
+                # Check if this frame should be used for a segment
+                if segment_index < self.segment_count:
+                    target_frame_for_segment = int((segment_index / max(1, self.segment_count - 1)) * (self.total_frames - 1))
+
+                    # If we're close enough to the target frame, use this frame
+                    if abs(current_frame - target_frame_for_segment) <= sample_interval:
+                        self.segment_frames[segment_index] = frame.copy()
+                        self.segment_positions[segment_index] = current_frame
+
+                        print(f"Segment {segment_index}: Frame {current_frame} (target was {target_frame_for_segment})")
+                        segment_index += 1
+                        segments_filled += 1
+
+                current_frame += 1
+
+                # Skip frames to speed up sampling if we have many frames
+                if sample_interval > 1:
+                    for _ in range(sample_interval - 1):
+                        ret, _ = shared_cap.read()
+                        if not ret:
+                            break
+                        current_frame += 1
+                    if not ret:
+                        break
+
+            sample_time = (time.time() - sample_start) * 1000
+            frames_time = (time.time() - frames_start) * 1000
+            print(f"Frame sampling: {sample_time:.1f}ms for {segments_filled} segments")
+            print(f"Total frame loading: {frames_time:.1f}ms")
+
+            shared_cap.release()
+        else:
+            print("Failed to create shared capture!")
+
+        total_time = time.time() - start_time
+        print(f"BLAZING FAST Total setup time: {total_time * 1000:.1f}ms")
+
+        # Report success
+        successful_segments = sum(1 for frame in self.segment_frames if frame is not None)
+        print(f"Successfully sampled {successful_segments}/{self.segment_count} segments")
+
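+    # Unlike the sequential sampling above, the "lightning fast" variant below seeks to only a
+    # handful of key frames and reuses each one for several neighbouring segments, so the
+    # positions shown in the grid are approximate rather than exact.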
+    def setup_segment_captures_lightning_fast(self):
+        """LIGHTNING FAST: Use intelligent skipping to get segments in minimal time"""
+        if not self.is_video(self.media_files[self.current_index]):
+            return
+
+        start_time = time.time()
+        print(f"Setting up {self.segment_count} segments with LIGHTNING FAST method...")
+
+        # Clean up existing segment captures
+        self.cleanup_segment_captures()
+
+        current_file = self.media_files[self.current_index]
+
+        # Initialize arrays
+        self.segment_caps = [None] * self.segment_count
+        self.segment_frames = [None] * self.segment_count
         self.segment_positions = []
+
+        # Calculate target positions
         for i in range(self.segment_count):
-            # Position segments at 0%, 25%, 50%, 75% of video (not 0%, 33%, 66%, 100%)
-            position_ratio = i / self.segment_count  # This gives 0, 0.25, 0.5, 0.75
-            start_frame = int(position_ratio * self.total_frames)
+            position_ratio = i / max(1, self.segment_count - 1)
+            start_frame = int(position_ratio * (self.total_frames - 1))
             self.segment_positions.append(start_frame)
 
-        # Create video captures for each segment
-        for i, start_frame in enumerate(self.segment_positions):
-            cap = cv2.VideoCapture(str(current_file))
-            if cap.isOpened():
-                cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
-                self.segment_caps.append(cap)
+        # LIGHTNING FAST: Smart skipping strategy
+        load_start = time.time()
+        print("Using SMART SKIPPING strategy...")
+
+        shared_cap_start = time.time()
+        shared_cap = cv2.VideoCapture(str(current_file))
+        shared_cap_create_time = (time.time() - shared_cap_start) * 1000
+        print(f"Capture creation: {shared_cap_create_time:.1f}ms")
+
+        if shared_cap.isOpened():
+            frames_start = time.time()
+
+            # Strategy: Read a much smaller subset and interpolate/approximate
+            # Only read 4-6 key frames and generate the rest through approximation
+            key_frames_to_read = min(6, self.segment_count)
+            frames_read = 0
+
+            for i in range(key_frames_to_read):
+                target_frame = self.segment_positions[i * (self.segment_count // key_frames_to_read)]
+
+                seek_start = time.time()
+                shared_cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
+                seek_time = (time.time() - seek_start) * 1000
+
+                read_start = time.time()
+                ret, frame = shared_cap.read()
+                read_time = (time.time() - read_start) * 1000
 
-                # Load initial frame for each segment
-                ret, frame = cap.read()
                 if ret:
-                    self.segment_frames.append(frame)
+                    # Use this frame for multiple segments (approximation)
+                    segments_per_key = self.segment_count // key_frames_to_read
+                    start_seg = i * segments_per_key
+                    end_seg = min(start_seg + segments_per_key, self.segment_count)
+
+                    for seg_idx in range(start_seg, end_seg):
+                        self.segment_frames[seg_idx] = frame.copy()
+
+                    frames_read += 1
+                    print(f"Key frame {i}: Frame {target_frame} -> Segments {start_seg}-{end_seg-1} ({seek_time:.1f}ms + {read_time:.1f}ms)")
                 else:
-                    self.segment_frames.append(None)
-            else:
-                self.segment_caps.append(None)
-                self.segment_frames.append(None)
+                    print(f"Failed to read key frame {i} at position {target_frame}")
+
+            # Fill any remaining segments with the last valid frame
+            last_valid_frame = None
+            for frame in reversed(self.segment_frames):
+                if frame is not None:
+                    last_valid_frame = frame
+                    break
+
+            if last_valid_frame is not None:
+                for i in range(len(self.segment_frames)):
+                    if self.segment_frames[i] is None:
+                        self.segment_frames[i] = last_valid_frame.copy()
+
+            frames_time = (time.time() - frames_start) * 1000
+            print(f"Smart frame reading: {frames_time:.1f}ms ({frames_read} key frames for {self.segment_count} segments)")
+
+            shared_cap.release()
+        else:
+            print("Failed to create shared capture!")
+
+        total_time = time.time() - start_time
+        print(f"LIGHTNING FAST Total setup time: {total_time * 1000:.1f}ms")
+
+        # Report success
+        successful_segments = sum(1 for frame in self.segment_frames if frame is not None)
+        print(f"Successfully approximated {successful_segments}/{self.segment_count} segments")
+
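+    # setup_segment_captures() stays the single entry point for segment setup, so switching
+    # sampling strategies only requires changing the helper it calls.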
+    def setup_segment_captures(self):
+        """Use the lightning fast approximation method for maximum speed"""
+        self.setup_segment_captures_lightning_fast()
 
     def cleanup_segment_captures(self):
         """Clean up all segment video captures"""
@@ -567,44 +728,113 @@ class MediaGrader:
     def get_cached_frame(self, frame_number: int):
         """Get frame from cache or load it if not cached"""
-        if frame_number in self.frame_cache:
-            return self.frame_cache[frame_number]
+        # Check cache first (thread-safe)
+        with self.cache_lock:
+            if frame_number in self.frame_cache:
+                return self.frame_cache[frame_number].copy()  # Return a copy to avoid modification
 
-        # Load frame and cache it (lazy loading)
+        # Load frame outside of lock to avoid blocking other threads
+        frame = None
         if self.current_cap:
-            original_pos = int(self.current_cap.get(cv2.CAP_PROP_POS_FRAMES))
-            self.current_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
-            ret, frame = self.current_cap.read()
-            self.current_cap.set(cv2.CAP_PROP_POS_FRAMES, original_pos)
-
-            if ret:
-                # Cache the frame (with size limit)
-                if len(self.frame_cache) >= self.cache_size_limit:
-                    # Remove oldest cached frame
-                    oldest_key = min(self.frame_cache.keys())
-                    del self.frame_cache[oldest_key]
+            # Create a temporary capture to avoid interfering with main playback
+            current_file = self.media_files[self.current_index]
+            temp_cap = cv2.VideoCapture(str(current_file))
+            if temp_cap.isOpened():
+                temp_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+                ret, frame = temp_cap.read()
+                temp_cap.release()
 
-                self.frame_cache[frame_number] = frame.copy()
-                return frame
+                if ret and frame is not None:
+                    # Cache the frame (with size limit) - thread-safe
+                    with self.cache_lock:
+                        if len(self.frame_cache) >= self.cache_size_limit:
+                            # Remove oldest cached frames (remove multiple at once for efficiency)
+                            keys_to_remove = sorted(self.frame_cache.keys())[:len(self.frame_cache) // 4]
+                            for key in keys_to_remove:
+                                del self.frame_cache[key]
+
+                        self.frame_cache[frame_number] = frame.copy()
+                    return frame
 
         return None
 
-    def update_segment_frames(self):
-        """Update frames for all segments during playback"""
-        if not self.multi_segment_mode or not self.segment_caps:
-            return
-
-        for i, cap in enumerate(self.segment_caps):
+    def get_segment_capture(self, segment_index):
+        """Get or create a capture for a specific segment (lazy loading)"""
+        if segment_index >= len(self.segment_caps) or self.segment_caps[segment_index] is None:
+            if segment_index < len(self.segment_caps):
+                # Create capture on demand
+                current_file = self.media_files[self.current_index]
+                cap = cv2.VideoCapture(str(current_file))
+                if cap.isOpened():
+                    cap.set(cv2.CAP_PROP_POS_FRAMES, self.segment_positions[segment_index])
+                    self.segment_caps[segment_index] = cap
+                    return cap
+                else:
+                    return None
+            return None
+        return self.segment_caps[segment_index]
+
+    def update_segment_frame_parallel(self, segment_index):
+        """Update a single segment frame"""
+        try:
+            cap = self.get_segment_capture(segment_index)
             if cap and cap.isOpened():
                 ret, frame = cap.read()
                 if ret:
-                    self.segment_frames[i] = frame
+                    return segment_index, frame
                 else:
                     # Loop back to segment start when reaching end
-                    cap.set(cv2.CAP_PROP_POS_FRAMES, self.segment_positions[i])
+                    cap.set(cv2.CAP_PROP_POS_FRAMES, self.segment_positions[segment_index])
                     ret, frame = cap.read()
                     if ret:
-                        self.segment_frames[i] = frame
+                        return segment_index, frame
+                    else:
+                        return segment_index, None
+            return segment_index, None
+        except Exception as e:
+            print(f"Error updating segment {segment_index}: {e}")
+            return segment_index, None
+
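+    # update_segment_frames() below submits decode work in groups of at most four, matching the
+    # thread pool's max_workers, so no more than four segment frames are decoded concurrently.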
+    def update_segment_frames(self):
+        """Update frames for all segments during playback with parallel processing"""
+        if not self.multi_segment_mode or not self.segment_frames:
+            return
+
+        # Only update segments that have valid frames loaded
+        active_segments = [i for i, frame in enumerate(self.segment_frames) if frame is not None]
+
+        if not active_segments:
+            return
+
+        # Use the thread pool for parallel frame updates, limited so the decoder is not overwhelmed
+        if len(active_segments) <= 4:
+            # For a small number of segments, update them all in parallel
+            futures = []
+            for i in active_segments:
+                future = self.thread_pool.submit(self.update_segment_frame_parallel, i)
+                futures.append(future)
+
+            # Collect results
+            for future in futures:
+                segment_index, frame = future.result()
+                if frame is not None:
+                    self.segment_frames[segment_index] = frame
+        else:
+            # For larger numbers, process in smaller batches to avoid resource exhaustion
+            batch_size = 4
+            for batch_start in range(0, len(active_segments), batch_size):
+                batch = active_segments[batch_start:batch_start + batch_size]
+                futures = []
+
+                for i in batch:
+                    future = self.thread_pool.submit(self.update_segment_frame_parallel, i)
+                    futures.append(future)
+
+                # Collect batch results
+                for future in futures:
+                    segment_index, frame = future.result()
+                    if frame is not None:
+                        self.segment_frames[segment_index] = frame
 
     def reposition_segments_around_frame(self, center_frame: int):
         """Reposition all segments around a center frame while maintaining spacing"""
@@ -637,33 +867,61 @@ class MediaGrader:
                     # Reset position for next read
                     cap.set(cv2.CAP_PROP_POS_FRAMES, self.segment_positions[i])
 
-    def seek_all_segments(self, frames_delta: int):
-        """Seek all segments by the specified number of frames"""
-        if not self.multi_segment_mode or not self.segment_caps:
-            return
-
-        for i, cap in enumerate(self.segment_caps):
+    def seek_segment_parallel(self, segment_index, frames_delta):
+        """Seek a single segment by the specified number of frames"""
+        try:
+            if segment_index >= len(self.segment_positions):
+                return segment_index, None
+
+            cap = self.get_segment_capture(segment_index)
             if cap and cap.isOpened():
                 current_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
-                segment_start = self.segment_positions[i]
+                segment_start = self.segment_positions[segment_index]
                 segment_duration = self.total_frames // self.segment_count
                 segment_end = min(self.total_frames - 1, segment_start + segment_duration)
 
                 target_frame = max(segment_start, min(current_frame + frames_delta, segment_end))
 
-                # Try cache first, then load if needed
+                # Try cache first for better performance
                 cached_frame = self.get_cached_frame(target_frame)
                 if cached_frame is not None:
-                    self.segment_frames[i] = cached_frame
                     cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
+                    return segment_index, cached_frame
                 else:
                     # Fall back to normal seeking
                     cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
                     ret, frame = cap.read()
                     if ret:
-                        self.segment_frames[i] = frame
-                        # Reset position for next read
-                        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
+                        return segment_index, frame
+                    else:
+                        return segment_index, None
+            return segment_index, None
+        except Exception as e:
+            print(f"Error seeking segment {segment_index}: {e}")
+            return segment_index, None
+
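+    # Each segment seeks only within its own window (segment_start up to segment_start plus
+    # total_frames // segment_count), so A/D seeking keeps every tile inside its own slice of the video.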
+    def seek_all_segments(self, frames_delta: int):
+        """Seek all segments by the specified number of frames with parallel processing"""
+        if not self.multi_segment_mode or not self.segment_frames:
+            return
+
+        # Only seek segments that have valid frames loaded
+        active_segments = [i for i, frame in enumerate(self.segment_frames) if frame is not None]
+
+        if not active_segments:
+            return
+
+        # Use parallel processing for seeking
+        futures = []
+        for i in active_segments:
+            future = self.thread_pool.submit(self.seek_segment_parallel, i, frames_delta)
+            futures.append(future)
+
+        # Collect results
+        for future in futures:
+            segment_index, frame = future.result()
+            if frame is not None:
+                self.segment_frames[segment_index] = frame
 
     def display_current_frame(self):
         """Display the current cached frame with overlays"""
@@ -1250,6 +1508,10 @@ class MediaGrader:
         if self.current_cap:
            self.current_cap.release()
         self.cleanup_segment_captures()
+
+        # Clean up the thread pool
+        self.thread_pool.shutdown(wait=True)
+
         cv2.destroyAllWindows()
         print("Grading session complete!")