bplusplus-2.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bplusplus/inference.py ADDED
@@ -0,0 +1,1337 @@
+ import cv2
+ import time
+ import os
+ import yaml
+ import json
+ import numpy as np
+ import pandas as pd
+ from datetime import datetime
+ from collections import defaultdict
+ import uuid
+ import argparse
+
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms, models
+ from PIL import Image
+ import requests
+ import logging
+
+ from .tracker import InsectTracker
+ from .detector import (
+     DEFAULT_DETECTION_CONFIG,
+     get_default_config,
+     build_detection_params,
+     extract_motion_detections,
+     analyze_path_topology,
+     check_track_consistency,
+ )
+
+ # Torch serialization compatibility: add_safe_globals expects the classes
+ # themselves, not their dotted names as strings
+ if hasattr(torch.serialization, 'add_safe_globals'):
+     torch.serialization.add_safe_globals([
+         torch.LongTensor,
+         torch.cuda.LongTensor,
+         torch.FloatStorage,
+         torch.cuda.FloatStorage,
+     ])
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ # ============================================================================
+ # CONFIGURATION
+ # ============================================================================
+
+ def load_config(config_path):
+     """
+     Load detection configuration from a YAML or JSON file.
+
+     Args:
+         config_path: Path to config file (.yaml, .yml, or .json)
+
+     Returns:
+         dict: Configuration parameters merged over the defaults
+     """
+     if not os.path.exists(config_path):
+         raise FileNotFoundError(f"Config file not found: {config_path}")
+
+     ext = os.path.splitext(config_path)[1].lower()
+
+     with open(config_path, 'r') as f:
+         if ext in ['.yaml', '.yml']:
+             config = yaml.safe_load(f)
+         elif ext == '.json':
+             config = json.load(f)
+         else:
+             raise ValueError(f"Unsupported config format: {ext}")
+
+     # Merge with defaults
+     params = get_default_config()
+     for key, value in config.items():
+         if key in params:
+             params[key] = value
+         else:
+             logger.warning(f"Unknown config parameter ignored: {key}")
+
+     return params
+
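+ # Illustrative example (editor's note, not part of the released file): a
+ # minimal YAML config that load_config() would accept. Only keys already
+ # present in get_default_config() are kept; anything else is warned about
+ # and ignored. The values shown are arbitrary.
+ #
+ #     # detection_config.yaml
+ #     min_area: 50
+ #     max_area: 5000
+ #     min_displacement: 40
+ #     lost_track_seconds: 2.0
+ #
+ #     params = load_config("detection_config.yaml")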
+ # ============================================================================
+ # TAXONOMY UTILITIES
+ # ============================================================================
+
+ def get_taxonomy(species_list):
+     """
+     Retrieve taxonomic information from the GBIF API.
+
+     Args:
+         species_list: List of species names
+
+     Returns:
+         tuple: (taxonomy_dict, species_to_genus, genus_to_family)
+     """
+     taxonomy = {1: [], 2: {}, 3: {}}
+     species_to_genus = {}
+     genus_to_family = {}
+
+     species_for_gbif = [s for s in species_list if s.lower() != 'unknown']
+     has_unknown = len(species_for_gbif) != len(species_list)
+
+     logger.info(f"Building taxonomy from GBIF for {len(species_for_gbif)} species")
+     print(f"\n{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
+     print("-" * 80)
+
+     for species_name in species_for_gbif:
+         url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
+         try:
+             # Timeout so a slow or unreachable GBIF endpoint cannot hang inference
+             response = requests.get(url, timeout=10)
+             data = response.json()
+
+             if data.get('status') in ['ACCEPTED', 'SYNONYM']:
+                 family = data.get('family')
+                 genus = data.get('genus')
+
+                 if family and genus:
+                     print(f"{species_name:<30} {family:<20} {genus:<20} OK")
+                     species_to_genus[species_name] = genus
+                     genus_to_family[genus] = family
+                     if family not in taxonomy[1]:
+                         taxonomy[1].append(family)
+                     taxonomy[2][genus] = family
+                     taxonomy[3][species_name] = genus
+                 else:
+                     print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                     logger.error(f"Species '{species_name}' missing family/genus")
+             else:
+                 print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                 logger.error(f"Species '{species_name}' not found in GBIF")
+         except Exception as e:
+             print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
+             logger.error(f"Error for '{species_name}': {e}")
+
+     if has_unknown:
+         if "Unknown" not in taxonomy[1]:
+             taxonomy[1].append("Unknown")
+         taxonomy[2]["Unknown"] = "Unknown"
+         taxonomy[3]["unknown"] = "Unknown"
+         species_to_genus["unknown"] = "Unknown"
+         genus_to_family["Unknown"] = "Unknown"
+         print(f"{'unknown':<30} {'Unknown':<20} {'Unknown':<20} OK")
+
+     taxonomy[1] = sorted(set(taxonomy[1]))
+     print("-" * 80)
+
+     # Genus keys are sorted so the printed indices match the index order
+     # produced by create_mappings()
+     for level, name, items in [(1, "Family", taxonomy[1]),
+                                (2, "Genus", sorted(taxonomy[2].keys())),
+                                (3, "Species", species_list)]:
+         print(f"\n{name} indices:")
+         for i, item in enumerate(items):
+             print(f" {i}: {item}")
+
+     logger.info(f"Taxonomy: {len(taxonomy[1])} families, {len(taxonomy[2])} genera, {len(taxonomy[3])} species")
+     return taxonomy, species_to_genus, genus_to_family
+
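+ # Illustrative example (editor's note, not part of the released file): the
+ # fields of the GBIF match response that get_taxonomy() actually reads.
+ # A successful lookup returns JSON along the lines of
+ #
+ #     {"status": "ACCEPTED", "family": "Apidae", "genus": "Apis", ...}
+ #
+ # so a species is kept only when 'status' is ACCEPTED/SYNONYM and both
+ # 'family' and 'genus' are present.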
+
+ def create_mappings(taxonomy, species_list=None):
+     """Create index mappings from taxonomy."""
+     level_to_idx = {}
+     idx_to_level = {}
+
+     for level, labels in taxonomy.items():
+         if isinstance(labels, list):
+             level_to_idx[level] = {label: idx for idx, label in enumerate(labels)}
+             idx_to_level[level] = {idx: label for idx, label in enumerate(labels)}
+         else:
+             sorted_keys = species_list if level == 3 and species_list else sorted(labels.keys())
+             level_to_idx[level] = {label: idx for idx, label in enumerate(sorted_keys)}
+             idx_to_level[level] = {idx: label for idx, label in enumerate(sorted_keys)}
+
+     return level_to_idx, idx_to_level
+
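+ # Illustrative example (editor's note, not part of the released file): with
+ # taxonomy = {1: ["Apidae"], 2: {"Apis": "Apidae"}, 3: {"Apis mellifera": "Apis"}},
+ # create_mappings() yields level_to_idx[1] == {"Apidae": 0} and
+ # idx_to_level[2] == {0: "Apis"}; level 3 follows species_list order when given.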
+
+ # ============================================================================
+ # MODEL
+ # ============================================================================
+
+ class HierarchicalInsectClassifier(nn.Module):
+     """Hierarchical classifier with ResNet backbone and multi-branch heads."""
+
+     def __init__(self, num_classes_per_level, backbone: str = "resnet50"):
+         super().__init__()
+         self.backbone = self._build_backbone(backbone)
+         self.backbone_name = backbone
+         backbone_features = self.backbone.fc.in_features
+         self.backbone.fc = nn.Identity()
+
+         self.branches = nn.ModuleList([
+             nn.Sequential(
+                 nn.Linear(backbone_features, 512),
+                 nn.ReLU(),
+                 nn.Dropout(0.5),
+                 nn.Linear(512, num_classes)
+             ) for num_classes in num_classes_per_level
+         ])
+         self.num_levels = len(num_classes_per_level)
+
+     @staticmethod
+     def _build_backbone(backbone: str):
+         """Build ResNet backbone by name."""
+         name = backbone.lower()
+         if name == "resnet18":
+             return models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
+         if name == "resnet50":
+             return models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+         if name == "resnet101":
+             return models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
+         raise ValueError(f"Unsupported backbone '{backbone}'. Choose from 'resnet18', 'resnet50', 'resnet101'.")
+
+     def forward(self, x):
+         features = self.backbone(x)
+         return [branch(features) for branch in self.branches]
+
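+ # Illustrative example (editor's note, not part of the released file): the
+ # model returns one logit tensor per taxonomic level. With 4 families,
+ # 7 genera and 10 species:
+ #
+ #     model = HierarchicalInsectClassifier([4, 7, 10], backbone="resnet18")
+ #     logits = model(torch.randn(1, 3, 60, 60))
+ #     # len(logits) == 3; shapes: (1, 4), (1, 7), (1, 10)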
+
+ # ============================================================================
+ # VISUALIZATION
+ # ============================================================================
+
+ class FrameVisualizer:
+     """Visualization utilities for detection overlay."""
+
+     COLORS = [
+         (68, 189, 50), (255, 59, 48), (0, 122, 255), (255, 149, 0),
+         (175, 82, 222), (255, 204, 0), (50, 173, 230), (255, 45, 85),
+         (48, 209, 88), (90, 200, 250), (255, 159, 10), (191, 90, 242),
+     ]
+
+     @staticmethod
+     def get_track_color(track_id):
+         if track_id is None:
+             return (68, 189, 50)
+         try:
+             track_uuid = uuid.UUID(track_id)
+         except (ValueError, TypeError):
+             track_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(track_id))
+         return FrameVisualizer.COLORS[track_uuid.int % len(FrameVisualizer.COLORS)]
+
+     @staticmethod
+     def draw_path(frame, path, track_id):
+         """Draw track path on frame."""
+         if len(path) < 2:
+             return
+         color = FrameVisualizer.get_track_color(track_id)
+         path_points = np.array(path, dtype=np.int32)
+         cv2.polylines(frame, [path_points], False, color, 2)
+         # Draw center point
+         cx, cy = path[-1]
+         cv2.circle(frame, (int(cx), int(cy)), 4, color, -1)
+
+     @staticmethod
+     def draw_detection(frame, x1, y1, x2, y2, track_id, detection_data):
+         """Draw bounding box and classification label on frame."""
+         color = FrameVisualizer.get_track_color(track_id)
+         cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
+
+         track_display = f"ID: {str(track_id)[:8]}" if track_id else "NEW"
+         lines = [track_display]
+
+         for level, conf_key in [("family", "family_confidence"),
+                                 ("genus", "genus_confidence"),
+                                 ("species", "species_confidence")]:
+             if detection_data.get(level):
+                 name = detection_data[level]
+                 conf = detection_data.get(conf_key, 0)
+                 name = name[:15] + "..." if len(name) > 18 else name
+                 lines.append(f"{level[0].upper()}: {name}")
+                 lines.append(f" {conf:.1%}")
+
+         if not lines[1:] and track_id is None:
+             return
+
+         font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.45, 1
+         padding, spacing = 8, 6
+         text_sizes = [cv2.getTextSize(line, font, scale, thickness)[0] for line in lines]
+         max_w = max(s[0] for s in text_sizes)
+         text_h = text_sizes[0][1]
+
+         total_h = len(lines) * (text_h + spacing) + padding * 2
+         label_x1 = max(0, int(x1))
+         # Place the label above the box, or below it when it would run off
+         # the top of the frame
+         label_y1 = int(y1) - total_h - 5
+         if label_y1 < 0:
+             label_y1 = int(y2) + 5
+         label_x2 = min(frame.shape[1], label_x1 + max_w + padding * 2)
+         label_y2 = min(frame.shape[0], label_y1 + total_h)
+
+         overlay = frame.copy()
+         cv2.rectangle(overlay, (label_x1, label_y1), (label_x2, label_y2), (20, 20, 20), -1)
+         cv2.addWeighted(overlay, 0.85, frame, 0.15, 0, frame)
+         cv2.rectangle(frame, (label_x1, label_y1), (label_x2, label_y2), color, 1)
+
+         y = label_y1 + padding + text_h
+         for i, line in enumerate(lines):
+             text_color = color if i == 0 else ((160, 160, 160) if "%" in line else (255, 255, 255))
+             cv2.putText(frame, line, (label_x1 + padding, y), font, scale, text_color, thickness, cv2.LINE_AA)
+             y += text_h + spacing
+
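+ # Illustrative example (editor's note, not part of the released file):
+ # drawing one classified detection onto a BGR frame in place:
+ #
+ #     FrameVisualizer.draw_detection(
+ #         frame, 100, 120, 180, 200, track_id="a1b2c3d4",
+ #         detection_data={"species": "Apis mellifera", "species_confidence": 0.91},
+ #     )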
+
+ # ============================================================================
+ # VIDEO PROCESSOR
+ # ============================================================================
+
+ class VideoInferenceProcessor:
+     """
+     Processes video frames for insect detection and classification.
+
+     Combines motion-based detection with hierarchical classification
+     and track-based prediction aggregation.
+     """
+
+     def __init__(self, species_list, hierarchical_model_path, params, backbone="resnet50", img_size=60):
+         """
+         Initialize the processor.
+
+         Args:
+             species_list: List of species names for classification
+             hierarchical_model_path: Path to trained model weights
+             params: Detection parameters dict
+             backbone: ResNet backbone ('resnet18', 'resnet50', 'resnet101')
+             img_size: Image size for classification (should match training)
+         """
+         self.img_size = img_size
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.species_list = species_list
+         self.params = params
+
+         print(f"Using device: {self.device}")
+
+         # Build taxonomy
+         self.taxonomy, self.species_to_genus, self.genus_to_family = get_taxonomy(species_list)
+         self.level_to_idx, self.idx_to_level = create_mappings(self.taxonomy, species_list)
+         self.family_list = sorted(self.taxonomy[1])
+         self.genus_list = sorted(self.taxonomy[2].keys())
+
+         # Motion detection setup
+         self.back_sub = cv2.createBackgroundSubtractorMOG2(
+             history=500, varThreshold=16, detectShadows=False
+         )
+         self.morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
+
+         # Load classification model
+         print(f"Loading hierarchical model from {hierarchical_model_path}")
+         checkpoint = torch.load(hierarchical_model_path, map_location='cpu')
+         state_dict = checkpoint.get("model_state_dict", checkpoint)
+
+         # Use backbone from checkpoint if available, otherwise use provided
+         model_backbone = checkpoint.get("backbone", backbone)
+         if model_backbone != backbone:
+             print(f"Note: Using backbone '{model_backbone}' from checkpoint (overrides '{backbone}')")
+
+         num_classes = [len(self.family_list), len(self.genus_list), len(self.species_list)]
+         print(f"Model architecture: {num_classes} classes per level, backbone: {model_backbone}")
+
+         self.model = HierarchicalInsectClassifier(num_classes, backbone=model_backbone)
+         self.model.load_state_dict(state_dict, strict=False)
+         self.model.to(self.device)
+         self.model.eval()
+
+         # ImageNet normalization, matching the pretrained torchvision backbones
+         self.transform = transforms.Compose([
+             transforms.Resize((self.img_size, self.img_size)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+         # Track state
+         self.all_detections = []
+         self.track_paths = defaultdict(list)
+         self.track_areas = defaultdict(list)
+
+         print("Processor initialized successfully!")
+
+     def _extract_detections(self, frame):
+         """Extract motion-based detections."""
+         detections, fg_mask = extract_motion_detections(
+             frame, self.back_sub, self.morph_kernel, self.params
+         )
+         return detections, fg_mask
+
+     def _classify(self, frame, x1, y1, x2, y2):
+         """Classify a detection crop."""
+         # Clamp to the frame bounds: a negative index would silently wrap
+         # around, and an out-of-range slice would hand cvtColor an empty crop
+         h, w = frame.shape[:2]
+         x1, y1 = max(0, int(x1)), max(0, int(y1))
+         x2, y2 = min(w, int(x2)), min(h, int(y2))
+         crop = frame[y1:y2, x1:x2]
+         crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
+         tensor = self.transform(Image.fromarray(crop_rgb)).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model(tensor)
+
+         probs = [torch.softmax(o, dim=1).cpu().numpy().flatten() for o in outputs]
+         idxs = [np.argmax(p) for p in probs]
+
+         return {
+             "family": self.family_list[idxs[0]] if idxs[0] < len(self.family_list) else f"Family_{idxs[0]}",
+             "genus": self.genus_list[idxs[1]] if idxs[1] < len(self.genus_list) else f"Genus_{idxs[1]}",
+             "species": self.species_list[idxs[2]] if idxs[2] < len(self.species_list) else f"Species_{idxs[2]}",
+             "family_confidence": float(probs[0][idxs[0]]),
+             "genus_confidence": float(probs[1][idxs[1]]),
+             "species_confidence": float(probs[2][idxs[2]]),
+             "family_probs": probs[0].tolist(),
+             "genus_probs": probs[1].tolist(),
+             "species_probs": probs[2].tolist(),
+         }
+
+     def process_frame(self, frame, frame_time, tracker, frame_number):
+         """
+         Process a single frame: detect and track only (no classification).
+         Classification happens later for confirmed tracks only.
+
+         Args:
+             frame: BGR image frame
+             frame_time: Time in seconds
+             tracker: InsectTracker instance
+             frame_number: Frame index
+
+         Returns:
+             tuple: (foreground_mask, list of detections with track_ids)
+         """
+         detections, fg_mask = self._extract_detections(frame)
+         track_ids = tracker.update(detections, frame_number)
+
+         height, width = frame.shape[:2]
+         frame_detections = []
+
+         for i, det in enumerate(detections):
+             x1, y1, x2, y2 = det
+             track_id = track_ids[i] if i < len(track_ids) else None
+
+             # Track consistency check
+             if track_id:
+                 cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
+                 area = (x2 - x1) * (y2 - y1)
+
+                 if self.track_paths[track_id]:
+                     prev_pos = self.track_paths[track_id][-1]
+                     prev_area = self.track_areas[track_id][-1] if self.track_areas[track_id] else area
+
+                     if not check_track_consistency(
+                         prev_pos, (cx, cy), prev_area, area,
+                         self.params["max_frame_jump"]
+                     ):
+                         # Reset track
+                         self.track_paths[track_id] = []
+                         self.track_areas[track_id] = []
+
+                 self.track_paths[track_id].append((cx, cy))
+                 self.track_areas[track_id].append(area)
+
+             # Store detection WITHOUT classification
+             detection_data = {
+                 "timestamp": datetime.now().isoformat(),
+                 "frame_number": frame_number,
+                 "frame_time_seconds": frame_time,
+                 "track_id": track_id,
+                 "bbox": [x1, y1, x2, y2],
+                 "bbox_normalized": [
+                     (x1 + x2) / (2 * width), (y1 + y2) / (2 * height),
+                     (x2 - x1) / width, (y2 - y1) / height
+                 ],
+             }
+             self.all_detections.append(detection_data)
+             frame_detections.append(detection_data)
+
+             # Log (detection only, no classification yet)
+             track_display = str(track_id)[:8] if track_id else "NEW"
+             print(f"Frame {frame_time:6.2f}s | Track {track_display} | Detected")
+
+         return fg_mask, frame_detections
+
+     def classify_confirmed_tracks(self, video_path, confirmed_track_ids, crops_dir=None):
+         """
+         Classify only the confirmed tracks by re-reading relevant frames.
+
+         Args:
+             video_path: Path to original video
+             confirmed_track_ids: Set of track IDs that passed topology analysis
+             crops_dir: Optional directory to save cropped frames
+
+         Returns:
+             dict: track_id -> list of classifications
+         """
+         if not confirmed_track_ids:
+             print("No confirmed tracks to classify.")
+             return {}
+
+         print(f"\nClassifying {len(confirmed_track_ids)} confirmed tracks...")
+
+         # Setup crops directory if requested
+         if crops_dir:
+             os.makedirs(crops_dir, exist_ok=True)
+             # Create subdirectory for each track
+             for track_id in confirmed_track_ids:
+                 track_dir = os.path.join(crops_dir, str(track_id)[:8])
+                 os.makedirs(track_dir, exist_ok=True)
+             print(f" Saving crops to: {crops_dir}")
+
+         # Group detections by frame for confirmed tracks
+         frames_to_classify = defaultdict(list)
+         for det in self.all_detections:
+             if det['track_id'] in confirmed_track_ids:
+                 frames_to_classify[det['frame_number']].append(det)
+
+         if not frames_to_classify:
+             return {}
+
+         cap = cv2.VideoCapture(video_path)
+         track_classifications = defaultdict(list)
+
+         frame_numbers = sorted(frames_to_classify.keys())
+         current_frame = 0
+         classified_count = 0
+
+         for target_frame in frame_numbers:
+             # Seek to frame
+             while current_frame < target_frame:
+                 cap.read()
+                 current_frame += 1
+
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             current_frame += 1
+
+             # Classify each detection in this frame
+             for det in frames_to_classify[target_frame]:
+                 x1, y1, x2, y2 = det['bbox']
+                 classification = self._classify(frame, x1, y1, x2, y2)
+
+                 # Update detection with classification
+                 det.update(classification)
+
+                 track_classifications[det['track_id']].append(classification)
+                 classified_count += 1
+
+                 # Save crop if requested
+                 if crops_dir:
+                     track_id = det['track_id']
+                     track_dir = os.path.join(crops_dir, str(track_id)[:8])
+                     crop = frame[int(y1):int(y2), int(x1):int(x2)]
+                     if crop.size > 0:
+                         crop_path = os.path.join(track_dir, f"frame_{target_frame:06d}.jpg")
+                         cv2.imwrite(crop_path, crop)
+
+                 if classified_count % 20 == 0:
+                     print(f" Classified {classified_count} detections...", end='\r')
+
+         cap.release()
+         print(f"\n✓ Classified {classified_count} detections from {len(confirmed_track_ids)} tracks")
+         if crops_dir:
+             print(f"✓ Saved {classified_count} crops to {crops_dir}")
+
+         return track_classifications
+
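+ # Illustrative example (editor's note, not part of the released file): with
+ # crops_dir="out/video_crops", saved crops land in one subdirectory per
+ # track, named by the first 8 characters of the track id:
+ #
+ #     out/video_crops/a1b2c3d4/frame_000042.jpg
+ #     out/video_crops/a1b2c3d4/frame_000043.jpg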
+     def analyze_tracks(self):
+         """
+         Analyze all tracks to determine which pass topology (before classification).
+
+         Returns:
+             tuple: (confirmed_track_ids set, all_track_info dict)
+         """
+         print("\n" + "="*60)
+         print("TRACK TOPOLOGY ANALYSIS")
+         print("="*60)
+
+         track_detections = defaultdict(list)
+         for det in self.all_detections:
+             if det['track_id']:
+                 track_detections[det['track_id']].append(det)
+
+         confirmed_track_ids = set()
+         all_track_info = {}
+
+         for track_id, detections in track_detections.items():
+             # Path topology analysis
+             path = self.track_paths.get(track_id, [])
+             passes_topology, topology_metrics = analyze_path_topology(path, self.params)
+
+             frame_times = [d['frame_time_seconds'] for d in detections]
+
+             track_info = {
+                 'track_id': track_id,
+                 'num_detections': len(detections),
+                 'first_frame_time': min(frame_times),
+                 'last_frame_time': max(frame_times),
+                 'duration': max(frame_times) - min(frame_times),
+                 'passes_topology': passes_topology,
+                 **topology_metrics
+             }
+             all_track_info[track_id] = track_info
+
+             status = "✓ CONFIRMED" if passes_topology else "? unconfirmed"
+             print(f"Track {str(track_id)[:8]}: {len(detections)} detections, "
+                   f"{track_info['duration']:.1f}s - {status}")
+
+             if passes_topology:
+                 confirmed_track_ids.add(track_id)
+
+         print(f"\n✓ {len(confirmed_track_ids)} confirmed / {len(track_detections)} total tracks")
+         return confirmed_track_ids, all_track_info
+
+     def hierarchical_aggregation(self, confirmed_track_ids):
+         """
+         Aggregate predictions for confirmed tracks using hierarchical selection.
+         Must be called AFTER classify_confirmed_tracks().
+
+         Args:
+             confirmed_track_ids: Set of confirmed track IDs
+
+         Returns:
+             list: Aggregated results for confirmed tracks only
+         """
+         print("\n" + "="*60)
+         print("HIERARCHICAL AGGREGATION (Confirmed Tracks)")
+         print("="*60)
+
+         track_detections = defaultdict(list)
+         for det in self.all_detections:
+             if det['track_id'] in confirmed_track_ids:
+                 track_detections[det['track_id']].append(det)
+
+         results = []
+         for track_id, detections in track_detections.items():
+             # Check if classifications exist
+             if 'family_probs' not in detections[0]:
+                 print(f"Warning: Track {str(track_id)[:8]} has no classifications, skipping")
+                 continue
+
+             print(f"\nTrack {str(track_id)[:8]}: {len(detections)} classified detections")
+
+             # Path topology analysis
+             path = self.track_paths.get(track_id, [])
+             passes_topology, topology_metrics = analyze_path_topology(path, self.params)
+
+             # Average probabilities across the track's detections
+             prob_avgs = [
+                 np.mean([d['family_probs'] for d in detections], axis=0),
+                 np.mean([d['genus_probs'] for d in detections], axis=0),
+                 np.mean([d['species_probs'] for d in detections], axis=0),
+             ]
+
+             # Hierarchical selection: pick the family first, then restrict the
+             # genus and species argmax to children of the winning parent
+             best_family_idx = np.argmax(prob_avgs[0])
+             best_family = self.family_list[best_family_idx]
+
+             family_genera = [i for i, g in enumerate(self.genus_list)
+                              if self.genus_to_family.get(g) == best_family]
+             if family_genera:
+                 best_genus_idx = family_genera[np.argmax(prob_avgs[1][family_genera])]
+             else:
+                 best_genus_idx = np.argmax(prob_avgs[1])
+             best_genus = self.genus_list[best_genus_idx]
+
+             genus_species = [i for i, s in enumerate(self.species_list)
+                              if self.species_to_genus.get(s) == best_genus]
+             if genus_species:
+                 best_species_idx = genus_species[np.argmax(prob_avgs[2][genus_species])]
+             else:
+                 best_species_idx = np.argmax(prob_avgs[2])
+             best_species = self.species_list[best_species_idx]
+
+             frame_times = [d['frame_time_seconds'] for d in detections]
+
+             result = {
+                 'track_id': track_id,
+                 'num_detections': len(detections),
+                 'first_frame_time': min(frame_times),
+                 'last_frame_time': max(frame_times),
+                 'duration': max(frame_times) - min(frame_times),
+                 'final_family': best_family,
+                 'final_genus': best_genus,
+                 'final_species': best_species,
+                 'family_confidence': float(prob_avgs[0][best_family_idx]),
+                 'genus_confidence': float(prob_avgs[1][best_genus_idx]),
+                 'species_confidence': float(prob_avgs[2][best_species_idx]),
+                 'passes_topology': passes_topology,
+                 **topology_metrics
+             }
+             results.append(result)
+
+             print(f" → {best_family} / {best_genus} / {best_species} "
+                   f"({result['species_confidence']:.1%})")
+
+         return results
+
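+ # Illustrative example (editor's note, not part of the released file) of the
+ # hierarchical selection above: if the averaged family probabilities pick
+ # "Apidae", the genus argmax is taken only over genera whose genus_to_family
+ # entry is "Apidae", and the species argmax only over species mapped to the
+ # winning genus. A genus outside "Apidae" can therefore never be selected,
+ # even if its raw averaged probability is the global maximum.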
+     def save_results(self, results, output_paths):
+         """
+         Save results to CSV and print summary.
+
+         Args:
+             results: Aggregated results list (confirmed tracks only)
+             output_paths: Dict with output file paths
+
+         Returns:
+             pd.DataFrame: Results dataframe (confirmed tracks only)
+         """
+         # Count total tracks vs confirmed
+         total_tracks = len(self.track_paths)
+         num_confirmed = len(results)
+         num_unconfirmed = total_tracks - num_confirmed
+
+         # Save confirmed tracks to results CSV
+         if results:
+             df = pd.DataFrame(results).sort_values('num_detections', ascending=False)
+             df.to_csv(output_paths["results_csv"], index=False)
+             print(f"\n📊 Confirmed results saved: {output_paths['results_csv']} ({num_confirmed} tracks)")
+         else:
+             # Create empty CSV with headers
+             df = pd.DataFrame(columns=[
+                 'track_id', 'num_detections', 'first_frame_time', 'last_frame_time',
+                 'duration', 'final_family', 'final_genus', 'final_species',
+                 'family_confidence', 'genus_confidence', 'species_confidence',
+                 'passes_topology', 'total_displacement', 'revisit_ratio',
+                 'progression_ratio', 'directional_variance'
+             ])
+             df.to_csv(output_paths["results_csv"], index=False)
+             print(f"\n📊 Results file created (empty - no confirmed tracks): {output_paths['results_csv']}")
+
+         # Save all detections (frame-by-frame, regardless of confirmation)
+         det_df = pd.DataFrame(self.all_detections)
+         det_df.to_csv(output_paths["detections_csv"], index=False)
+         print(f"📋 Frame-by-frame detections saved: {output_paths['detections_csv']}")
+
+         # Print summary
+         print("\n" + "="*60)
+         print("🐛 FINAL SUMMARY")
+         print("="*60)
+
+         if results:
+             print(f"\n✓ CONFIRMED INSECTS ({num_confirmed}):")
+             for r in results:
+                 print(f" • {r['final_species']} - {r['num_detections']} detections, "
+                       f"{r['duration']:.1f}s, {r['species_confidence']:.1%}")
+
+         if num_unconfirmed > 0:
+             print(f"\n? Unconfirmed tracks: {num_unconfirmed} (failed topology analysis)")
+
+         print(f"\n📈 Total: {total_tracks} tracks ({num_confirmed} confirmed, {num_unconfirmed} unconfirmed)")
+
+         # Bold warning if no confirmed tracks
+         if not results:
+             print("\n" + "!"*60)
+             print("⚠️ WARNING: NO CONFIRMED INSECT TRACKS DETECTED!")
+             print("!"*60)
+             print("Possible reasons:")
+             print(" • No insects present in the video")
+             print(" • Detection parameters too strict (try lowering min_area)")
+             print(" • Tracking parameters too strict (try increasing lost_track_seconds)")
+             print(" • Path topology too strict (try lowering min_displacement)")
+             print(" • Video quality/resolution issues")
+             if num_unconfirmed > 0:
+                 print(f"\nNote: {num_unconfirmed} tracks were detected but failed the topology check.")
+                 print("Consider relaxing path topology parameters if these might be valid insects.")
+             print("!"*60)
+
+         print("="*60)
+
+         return df
+
+
+ # ============================================================================
+ # VIDEO PROCESSING
+ # ============================================================================
+
+ def process_video(video_path, processor, output_paths, show_video=False, fps=None, crops_dir=None):
+     """
+     Process a video file with efficient classification (confirmed tracks only).
+
+     Pipeline:
+         1. Detection & Tracking: Process all frames, detect motion, build tracks
+         2. Topology Analysis: Determine which tracks are confirmed insects
+         3. Classification: Classify ONLY confirmed tracks (saves compute)
+         4. Render Videos: Debug (all detections) + Annotated (confirmed with classifications)
+
+     Args:
+         video_path: Input video path
+         processor: VideoInferenceProcessor instance
+         output_paths: Dict with output file paths
+         show_video: Display frames while processing (press 'q' to stop early)
+         fps: Target FPS (frames are skipped when lower than the input FPS)
+         crops_dir: Optional directory to save cropped frames for each track
+
+     Returns:
+         list: Aggregated results
+     """
+     if not os.path.exists(video_path):
+         raise FileNotFoundError(f"Video not found: {video_path}")
+
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise ValueError(f"Could not open: {video_path}")
+
+     input_fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     # Guard against streams that report 0 FPS
+     duration = total_frames / input_fps if input_fps > 0 else 0
+     print(f"\nVideo: {video_path}")
+     print(f"Properties: {total_frames} frames, {input_fps:.1f} FPS, {duration:.1f}s")
+
+     # Setup tracker
+     effective_fps = fps if (fps and fps > 0) else (input_fps if input_fps > 0 else 30)
+     track_memory = max(30, int(processor.params["lost_track_seconds"] * effective_fps))
+
+     tracker = InsectTracker(
+         image_height=height, image_width=width,
+         max_frames=track_memory, track_memory_frames=track_memory,
+         w_dist=0.6, w_area=0.4, cost_threshold=0.3, debug=False
+     )
+
+     # Frame skip
+     skip_interval = max(1, int(input_fps / fps)) if fps and fps > 0 else 1
+     if skip_interval > 1:
+         print(f"Processing every {skip_interval} frame(s)")
+
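+ # Illustrative example (editor's note, not part of the released file): with a
+ # 30 FPS input and --fps 10, skip_interval = max(1, int(30 / 10)) = 3, so
+ # every third frame is processed while frame_time still advances in input-FPS
+ # units (frame 6 -> 0.2 s).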
+     # ==========================================================================
+     # PHASE 1: Detection & Tracking (no classification)
+     # ==========================================================================
+     print("\n" + "="*60)
+     print("PHASE 1: DETECTION & TRACKING")
+     print("="*60)
+
+     frame_num = 0
+     processed = 0
+     start = time.time()
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         frame_time = frame_num / input_fps if input_fps > 0 else 0
+
+         if frame_num % skip_interval == 0:
+             fg_mask, frame_dets = processor.process_frame(frame, frame_time, tracker, frame_num)
+             processed += 1
+
+             if show_video:
+                 # Honor the documented show_video flag: live preview, 'q' quits early
+                 cv2.imshow("bplusplus inference", frame)
+                 if cv2.waitKey(1) & 0xFF == ord('q'):
+                     break
+
+             if processed % 50 == 0:
+                 print(f" Progress: {processed} frames, {len(processor.all_detections)} detections", end='\r')
+
+         frame_num += 1
+
+     cap.release()
+     if show_video:
+         cv2.destroyAllWindows()
+     elapsed = time.time() - start
+     print(f"\n✓ Phase 1 complete: {processed} frames in {elapsed:.1f}s ({processed/elapsed:.1f} FPS)")
+     print(f" Total detections: {len(processor.all_detections)}")
+     print(f" Unique tracks: {len(processor.track_paths)}")
+
+     # ==========================================================================
+     # PHASE 2: Topology Analysis (determine confirmed tracks)
+     # ==========================================================================
+     confirmed_track_ids, all_track_info = processor.analyze_tracks()
+
+     # ==========================================================================
+     # PHASE 3: Classification (confirmed tracks only)
+     # ==========================================================================
+     print("\n" + "="*60)
+     print("PHASE 3: CLASSIFICATION (Confirmed Tracks Only)")
+     print("="*60)
+
+     if confirmed_track_ids:
+         processor.classify_confirmed_tracks(video_path, confirmed_track_ids, crops_dir=crops_dir)
+         results = processor.hierarchical_aggregation(confirmed_track_ids)
+     else:
+         results = []
+
+     # ==========================================================================
+     # PHASE 4: Render Videos (if requested)
+     # ==========================================================================
+     if "annotated_video" in output_paths or "debug_video" in output_paths:
+         print("\n" + "="*60)
+         print("PHASE 4: RENDERING VIDEOS")
+         print("="*60)
+
+         # Render debug video (all detections, showing confirmed vs unconfirmed)
+         if "debug_video" in output_paths:
+             print("\nRendering debug video (all detections)...")
+             _render_debug_video(
+                 video_path, output_paths["debug_video"],
+                 processor, confirmed_track_ids, all_track_info, input_fps
+             )
+
+         # Render annotated video (confirmed tracks with classifications)
+         if "annotated_video" in output_paths:
+             print(f"\nRendering annotated video ({len(confirmed_track_ids)} confirmed tracks)...")
+             _render_annotated_video(
+                 video_path, output_paths["annotated_video"],
+                 processor, confirmed_track_ids, input_fps
+             )
+     else:
+         print("\n(Video rendering skipped)")
+
+     # Save results
+     processor.save_results(results, output_paths)
+
+     return results
+
+
+ def _render_debug_video(video_path, output_path, processor, confirmed_track_ids, all_track_info, fps):
+     """
+     Render debug video showing all detections with confirmed/unconfirmed status.
+     Shows detection boxes, track IDs, and the GMM motion mask side-by-side.
+     """
+     cap = cv2.VideoCapture(video_path)
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     # Recreate background subtractor for GMM visualization
+     back_sub = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=16, detectShadows=False)
+
+     out = cv2.VideoWriter(
+         output_path,
+         cv2.VideoWriter_fourcc(*'mp4v'),
+         fps, (width * 2, height)
+     )
+
+     # Build frame-to-detections lookup
+     frame_detections = defaultdict(list)
+     for det in processor.all_detections:
+         frame_detections[det['frame_number']].append(det)
+
+     frame_num = 0
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Get GMM mask
+         fg_mask = back_sub.apply(frame)
+         fg_display = cv2.cvtColor(fg_mask, cv2.COLOR_GRAY2BGR)
+
+         # Draw all detections with status
+         for det in frame_detections[frame_num]:
+             x1, y1, x2, y2 = [int(v) for v in det['bbox']]
+             track_id = det['track_id']
+
+             if track_id in confirmed_track_ids:
+                 # Confirmed: green box, show classification if available
+                 color = (0, 255, 0)
+                 classification = {
+                     'species': det.get('species', ''),
+                     'species_confidence': det.get('species_confidence', 0),
+                 }
+                 # Hershey fonts are ASCII-only, so mark confirmation with "OK"
+                 # rather than a check-mark glyph
+                 label = f"{str(track_id)[:6]} OK"
+                 if classification['species']:
+                     label += f" {classification['species'][:12]}"
+             else:
+                 # Unconfirmed: yellow box
+                 color = (0, 255, 255)
+                 label = str(track_id)[:6] if track_id else "NEW"
+
+             # Draw box
+             cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
+
+             # Draw label background
+             (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
+             cv2.rectangle(frame, (x1, y1 - th - 4), (x1 + tw + 4, y1), color, -1)
+             cv2.putText(frame, label, (x1 + 2, y1 - 2),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1)
+
+             # Draw on GMM mask too
+             cv2.rectangle(fg_display, (x1, y1), (x2, y2), color, 2)
+
+         # Add headers
+         cv2.putText(frame, f"Frame {frame_num} | Detections (Green=Confirmed, Yellow=Tracking)",
+                     (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+         cv2.putText(fg_display, "GMM Motion Mask",
+                     (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+
+         # Combine side-by-side
+         combined = np.hstack((frame, fg_display))
+         out.write(combined)
+
+         frame_num += 1
+         if frame_num % 100 == 0:
+             print(f" Debug: {frame_num} frames", end='\r')
+
+     cap.release()
+     out.release()
+     print(f"\n✓ Debug video saved: {output_path}")
+
+
+ def _render_annotated_video(video_path, output_path, processor, confirmed_track_ids, fps):
+     """
+     Render annotated video showing only confirmed tracks with classifications.
+     """
+     if not confirmed_track_ids:
+         # Create a video with a "No confirmed tracks" message
+         cap = cv2.VideoCapture(video_path)
+         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+         out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+
+         frame_num = 0
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             cv2.putText(frame, "No confirmed insect tracks",
+                         (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
+             out.write(frame)
+             frame_num += 1
+
+         cap.release()
+         out.release()
+         print(f"✓ Annotated video saved (no confirmed tracks): {output_path}")
+         return
+
+     cap = cv2.VideoCapture(video_path)
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+
+     # Build frame-to-detections lookup for confirmed tracks only
+     frame_detections = defaultdict(list)
+     for det in processor.all_detections:
+         if det['track_id'] in confirmed_track_ids:
+             frame_detections[det['frame_number']].append(det)
+
+     # Precompute track center points once; all_detections is already in frame
+     # order, so each per-track list is chronologically sorted
+     track_centers = defaultdict(list)
+     for det in processor.all_detections:
+         if det['track_id'] in confirmed_track_ids:
+             bbox = det['bbox']
+             track_centers[det['track_id']].append(
+                 (det['frame_number'], ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2))
+             )
+
+     frame_num = 0
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Draw paths for confirmed tracks (up to the current frame)
+         for track_id in confirmed_track_ids:
+             path_to_draw = [pt for fn, pt in track_centers[track_id] if fn <= frame_num]
+             if len(path_to_draw) > 1:
+                 FrameVisualizer.draw_path(frame, path_to_draw, track_id)
+
+         # Draw detections for this frame (confirmed tracks only, with classification)
+         for det in frame_detections[frame_num]:
+             x1, y1, x2, y2 = det['bbox']
+             track_id = det['track_id']
+
+             classification = {
+                 'family': det.get('family', ''),
+                 'genus': det.get('genus', ''),
+                 'species': det.get('species', ''),
+                 'family_confidence': det.get('family_confidence', 0),
+                 'genus_confidence': det.get('genus_confidence', 0),
+                 'species_confidence': det.get('species_confidence', 0),
+             }
+             FrameVisualizer.draw_detection(frame, x1, y1, x2, y2, track_id, classification)
+
+         # Add header
+         cv2.putText(frame, f"Confirmed Insects ({len(confirmed_track_ids)} tracks)",
+                     (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
+
+         out.write(frame)
+         frame_num += 1
+
+         if frame_num % 100 == 0:
+             print(f" Annotated: {frame_num} frames", end='\r')
+
+     cap.release()
+     out.release()
+     print(f"\n✓ Annotated video saved: {output_path}")
+
+
+ # ============================================================================
+ # MAIN ENTRY POINT
+ # ============================================================================
+
+ def inference(
+     species_list,
+     hierarchical_model_path,
+     video_path,
+     output_dir,
+     fps=None,
+     config=None,
+     backbone="resnet50",
+     crops=False,
+     save_video=True,
+     img_size=60,
+     show_video=False,
+ ):
+     """
+     Run inference on a video file.
+
+     Args:
+         species_list: List of species names for classification
+         hierarchical_model_path: Path to trained model weights
+         video_path: Input video path
+         output_dir: Output directory for all generated files
+         fps: Target processing FPS (None = use input FPS)
+         config: Detection config - can be:
+             - None: use defaults
+             - str: path to YAML/JSON config file
+             - dict: config parameters directly
+         backbone: ResNet backbone ('resnet18', 'resnet50', 'resnet101').
+             If the model checkpoint contains backbone info, that is used instead.
+         crops: If True, save cropped frames for each classified track
+         save_video: If True, save annotated and debug videos. Defaults to True.
+         img_size: Image size for classification (should match training). Default: 60.
+         show_video: If True, display frames during processing. Defaults to False.
+
+     Returns:
+         dict: Processing results with output file paths
+
+     Generated files in output_dir:
+         - {video_name}_annotated.mp4: Video with detection boxes and paths (if save_video=True)
+         - {video_name}_debug.mp4: Side-by-side with GMM motion mask (if save_video=True)
+         - {video_name}_results.csv: Aggregated track results
+         - {video_name}_detections.csv: Frame-by-frame detections
+         - {video_name}_crops/ (if crops=True): Directory with cropped frames per track
+     """
+     if not os.path.exists(video_path):
+         print(f"Error: Video not found: {video_path}")
+         return {"error": f"Video not found: {video_path}", "success": False}
+
+     # Build parameters from config
+     if config is None:
+         params = get_default_config()
+     elif isinstance(config, str):
+         params = load_config(config)
+     elif isinstance(config, dict):
+         params = get_default_config()
+         for key, value in config.items():
+             if key in params:
+                 params[key] = value
+             else:
+                 logger.warning(f"Unknown config parameter: {key}")
+     else:
+         raise ValueError("config must be None, a file path (str), or a dict")
+
+     # Setup output directory and file paths
+     os.makedirs(output_dir, exist_ok=True)
+     video_name = os.path.splitext(os.path.basename(video_path))[0]
+
+     output_paths = {
+         "results_csv": os.path.join(output_dir, f"{video_name}_results.csv"),
+         "detections_csv": os.path.join(output_dir, f"{video_name}_detections.csv"),
+     }
+
+     if save_video:
+         output_paths["annotated_video"] = os.path.join(output_dir, f"{video_name}_annotated.mp4")
+         output_paths["debug_video"] = os.path.join(output_dir, f"{video_name}_debug.mp4")
+
+     # Setup crops directory if requested
+     crops_dir = os.path.join(output_dir, f"{video_name}_crops") if crops else None
+     if crops_dir:
+         output_paths["crops_dir"] = crops_dir
+
+     print("\n" + "="*60)
+     print("BPLUSPLUS INFERENCE")
+     print("="*60)
+     print(f"Video: {video_path}")
+     print(f"Model: {hierarchical_model_path}")
+     print(f"Output directory: {output_dir}")
+     print("\nOutput files:")
+     for name, path in output_paths.items():
+         print(f" {name}: {os.path.basename(path)}")
+     print("\nDetection Parameters:")
+     for key, value in params.items():
+         print(f" {key}: {value}")
+     print("="*60)
+
+     # Process
+     processor = VideoInferenceProcessor(
+         species_list=species_list,
+         hierarchical_model_path=hierarchical_model_path,
+         params=params,
+         backbone=backbone,
+         img_size=img_size,
+     )
+
+     try:
+         results = process_video(
+             video_path=video_path,
+             processor=processor,
+             output_paths=output_paths,
+             show_video=show_video,
+             fps=fps,
+             crops_dir=crops_dir
+         )
+
+         return {
+             "video_file": os.path.basename(video_path),
+             "output_dir": output_dir,
+             "output_files": output_paths,
+             "success": True,
+             "detections": len(processor.all_detections),
+             "tracks": len(results),
+             "confirmed_tracks": len([r for r in results if r.get('passes_topology', False)]),
+         }
+     except Exception as e:
+         logger.exception("Inference failed")
+         return {"error": str(e), "success": False}
+
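+ # Illustrative example (editor's note, not part of the released file):
+ # calling the Python API directly; paths and species are placeholders.
+ #
+ #     result = inference(
+ #         species_list=["Apis mellifera", "Bombus terrestris"],
+ #         hierarchical_model_path="model.pt",
+ #         video_path="input.mp4",
+ #         output_dir="results/",
+ #         fps=10,
+ #         config={"min_area": 50},
+ #     )
+ #     if result["success"]:
+ #         print(result["output_files"]["results_csv"])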
+
+ # ============================================================================
+ # COMMAND LINE INTERFACE
+ # ============================================================================
+
+ def main():
+     """Command line interface for inference."""
+     parser = argparse.ArgumentParser(
+         description='Bplusplus Video Inference - Detect and classify insects in videos',
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   # Basic usage
+   python -m bplusplus.inference --video input.mp4 --model model.pt \\
+       --output-dir results/ --species "Apis mellifera" "Bombus terrestris"
+
+   # With config file
+   python -m bplusplus.inference --video input.mp4 --model model.pt \\
+       --output-dir results/ --species "Apis mellifera" --config detection_config.yaml
+
+ Output files generated in the output directory:
+   - {video_name}_annotated.mp4: Video with detection boxes and paths
+   - {video_name}_debug.mp4: Side-by-side view with GMM motion mask
+   - {video_name}_results.csv: Aggregated track results
+   - {video_name}_detections.csv: Frame-by-frame detections
+   - {video_name}_crops/ (with --crops): Cropped frames for each track
+ """
+     )
+
+     # Required arguments
+     parser.add_argument('--video', '-v', help='Input video path')
+     parser.add_argument('--model', '-m', help='Path to hierarchical model weights')
+     parser.add_argument('--output-dir', '-o', help='Output directory for all generated files')
+     parser.add_argument('--species', '-s', nargs='+', help='List of species names')
+
+     # Config
+     parser.add_argument('--config', '-c', help='Path to config file (YAML or JSON)')
+
+     # Processing
+     parser.add_argument('--fps', type=float, help='Target processing FPS')
+     parser.add_argument('--show', action='store_true', help='Display video while processing')
+     parser.add_argument('--backbone', '-b', default='resnet50',
+                         choices=['resnet18', 'resnet50', 'resnet101'],
+                         help='ResNet backbone (default: resnet50, overridden by checkpoint if saved)')
+     parser.add_argument('--crops', action='store_true',
+                         help='Save cropped frames for each classified track')
+     parser.add_argument('--no-video', action='store_true',
+                         help='Skip saving annotated and debug videos')
+     parser.add_argument('--img-size', type=int, default=60,
+                         help='Image size for classification (should match training, default: 60)')
+
+     # Detection parameters (override config)
+     defaults = DEFAULT_DETECTION_CONFIG
+
+     cohesive = parser.add_argument_group('Cohesiveness parameters')
+     cohesive.add_argument('--min-blob-ratio', type=float,
+                           help=f'Min largest blob ratio (default: {defaults["min_largest_blob_ratio"]})')
+     cohesive.add_argument('--max-num-blobs', type=int,
+                           help=f'Max number of blobs (default: {defaults["max_num_blobs"]})')
+
+     shape = parser.add_argument_group('Shape parameters')
+     shape.add_argument('--min-area', type=int,
+                        help=f'Min area in px² (default: {defaults["min_area"]})')
+     shape.add_argument('--max-area', type=int,
+                        help=f'Max area in px² (default: {defaults["max_area"]})')
+     shape.add_argument('--min-density', type=float,
+                        help=f'Min density (default: {defaults["min_density"]})')
+     shape.add_argument('--min-solidity', type=float,
+                        help=f'Min solidity (default: {defaults["min_solidity"]})')
+
+     tracking = parser.add_argument_group('Tracking parameters')
+     tracking.add_argument('--min-displacement', type=int,
+                           help=f'Min NET displacement in px (default: {defaults["min_displacement"]})')
+     tracking.add_argument('--min-path-points', type=int,
+                           help=f'Min path points (default: {defaults["min_path_points"]})')
+     tracking.add_argument('--max-frame-jump', type=int,
+                           help=f'Max pixels between frames (default: {defaults["max_frame_jump"]})')
+     tracking.add_argument('--lost-track-seconds', type=float,
+                           help=f'Lost track memory in seconds (default: {defaults["lost_track_seconds"]})')
+
+     topology = parser.add_argument_group('Path topology parameters')
+     topology.add_argument('--max-revisit-ratio', type=float,
+                           help=f'Max revisit ratio (default: {defaults["max_revisit_ratio"]})')
+     topology.add_argument('--min-progression-ratio', type=float,
+                           help=f'Min progression ratio (default: {defaults["min_progression_ratio"]})')
+     topology.add_argument('--max-directional-variance', type=float,
+                           help=f'Max directional variance (default: {defaults["max_directional_variance"]})')
+
+     args = parser.parse_args()
+
+     # Validate required args
+     if not all([args.video, args.model, args.output_dir, args.species]):
+         parser.error("--video, --model, --output-dir, and --species are required")
+
+     # Build config: start with the file if provided, then override with CLI args
+     cli_overrides = {
+         "min_largest_blob_ratio": args.min_blob_ratio,
+         "max_num_blobs": args.max_num_blobs,
+         "min_area": args.min_area,
+         "max_area": args.max_area,
+         "min_density": args.min_density,
+         "min_solidity": args.min_solidity,
+         "min_displacement": args.min_displacement,
+         "min_path_points": args.min_path_points,
+         "max_frame_jump": args.max_frame_jump,
+         "lost_track_seconds": args.lost_track_seconds,
+         "max_revisit_ratio": args.max_revisit_ratio,
+         "min_progression_ratio": args.min_progression_ratio,
+         "max_directional_variance": args.max_directional_variance,
+     }
+     cli_overrides = {k: v for k, v in cli_overrides.items() if v is not None}
+
+     if args.config:
+         # Load the file first so explicit CLI flags take precedence over it
+         config = load_config(args.config)
+         config.update(cli_overrides)
+     else:
+         config = cli_overrides or None
+
+     # Run inference
+     result = inference(
+         species_list=args.species,
+         hierarchical_model_path=args.model,
+         video_path=args.video,
+         output_dir=args.output_dir,
+         fps=args.fps,
+         config=config,
+         backbone=args.backbone,
+         crops=args.crops,
+         save_video=not args.no_video,
+         img_size=args.img_size,
+         show_video=args.show,
+     )
+
+     if result.get("success"):
+         print("\n✓ Inference complete!")
+         print(f" Output directory: {result['output_dir']}")
+         print(f" Detections: {result['detections']}")
+         print(f" Tracks: {result['tracks']} ({result['confirmed_tracks']} confirmed)")
+     else:
+         print(f"\n✗ Inference failed: {result.get('error')}")
+
+
+ if __name__ == "__main__":
+     main()