bplusplus 1.2.2-py3-none-any.whl → 1.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of bplusplus might be problematic.

bplusplus/inference.py ADDED
@@ -0,0 +1,929 @@
+ import cv2
+ import time
+ import os
+ import sys
+ import numpy as np
+ import json
+ import pandas as pd
+ from datetime import datetime
+ from pathlib import Path
+ from .tracker import InsectTracker
+ import torch
+ import torchvision.transforms as T
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
+ from ultralytics import YOLO
+ from torchvision import transforms
+ from PIL import Image
+ import torch.nn as nn
+ from torchvision.models import resnet50
+ import requests
+ import logging
+ from collections import defaultdict
+ import uuid
+
+ # Add this check for backwards compatibility
+ if hasattr(torch.serialization, 'add_safe_globals'):
+     torch.serialization.add_safe_globals([
+         'torch.LongTensor',
+         'torch.cuda.LongTensor',
+         'torch.FloatStorage',
+         'torch.cuda.FloatStorage',
+     ])
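+ # The allow-list above matters on newer PyTorch releases, where torch.load defaults
+ # to weights_only=True and refuses to unpickle globals that are not registered.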
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # ============================================================================
+ # UTILITY FUNCTIONS
+ # ============================================================================
+
+ def get_taxonomy(species_list):
+     """
+     Retrieves taxonomic information for a list of species from the GBIF API.
+     Creates a hierarchical taxonomy dictionary with family, genus, and species relationships.
+     """
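+     # Illustrative shape of the returned structure for two species (example values,
+     # not output of this code):
+     #   taxonomy = {1: ['Apidae', 'Coccinellidae'],
+     #               2: {'Apis': 'Apidae', 'Coccinella': 'Coccinellidae'},
+     #               3: {'Apis mellifera': 'Apis', 'Coccinella septempunctata': 'Coccinella'}}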
+     taxonomy = {1: [], 2: {}, 3: {}}
+     species_to_genus = {}
+     genus_to_family = {}
+
+     species_list_for_gbif = [s for s in species_list if s.lower() != 'unknown']
+     has_unknown = len(species_list_for_gbif) != len(species_list)
+
+     logger.info(f"Building taxonomy from GBIF for {len(species_list_for_gbif)} species")
+
+     print(f"\n{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
+     print("-" * 80)
+
+     for species_name in species_list_for_gbif:
+         url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
+         try:
+             response = requests.get(url, timeout=10)  # avoid hanging indefinitely on a stalled request
+             data = response.json()
+
+             if data.get('status') in ['ACCEPTED', 'SYNONYM']:
+                 family = data.get('family')
+                 genus = data.get('genus')
+
+                 if family and genus:
+                     print(f"{species_name:<30} {family:<20} {genus:<20} OK")
+
+                     species_to_genus[species_name] = genus
+                     genus_to_family[genus] = family
+
+                     if family not in taxonomy[1]:
+                         taxonomy[1].append(family)
+
+                     taxonomy[2][genus] = family
+                     taxonomy[3][species_name] = genus
+                 else:
+                     print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                     logger.error(f"Species '{species_name}' found but missing family/genus data")
+             else:
+                 print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                 logger.error(f"Species '{species_name}' not found in GBIF")
+
+         except Exception as e:
+             print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
+             logger.error(f"Error retrieving data for '{species_name}': {str(e)}")
+
+     if has_unknown:
+         unknown_family = "Unknown"
+         unknown_genus = "Unknown"
+         unknown_species = "unknown"
+
+         if unknown_family not in taxonomy[1]:
+             taxonomy[1].append(unknown_family)
+
+         taxonomy[2][unknown_genus] = unknown_family
+         taxonomy[3][unknown_species] = unknown_genus
+         species_to_genus[unknown_species] = unknown_genus
+         genus_to_family[unknown_genus] = unknown_family
+
+         print(f"{unknown_species:<30} {unknown_family:<20} {unknown_genus:<20} {'OK'}")
+
+     taxonomy[1] = sorted(list(set(taxonomy[1])))
+     print("-" * 80)
+
+     # Print indices
+     for level, name, items in [(1, "Family", taxonomy[1]), (2, "Genus", taxonomy[2].keys()), (3, "Species", species_list)]:
+         print(f"\n{name} indices:")
+         for i, item in enumerate(items):
+             print(f" {i}: {item}")
+
+     logger.info(f"Taxonomy built: {len(taxonomy[1])} families, {len(taxonomy[2])} genera, {len(taxonomy[3])} species")
+     return taxonomy, species_to_genus, genus_to_family
+
+ def create_mappings(taxonomy, species_list=None):
+     """Create index mappings from taxonomy"""
+     level_to_idx = {}
+     idx_to_level = {}
+
+     for level, labels in taxonomy.items():
+         if isinstance(labels, list):
+             # Level 1: Family (already sorted)
+             level_to_idx[level] = {label: idx for idx, label in enumerate(labels)}
+             idx_to_level[level] = {idx: label for idx, label in enumerate(labels)}
+         else:  # Dictionary for levels 2 and 3
+             if level == 3 and species_list is not None:
+                 # For species, the order is determined by species_list
+                 sorted_keys = species_list
+             else:
+                 # For genus, sort alphabetically
+                 sorted_keys = sorted(labels.keys())
+
+             level_to_idx[level] = {label: idx for idx, label in enumerate(sorted_keys)}
+             idx_to_level[level] = {idx: label for idx, label in enumerate(sorted_keys)}
+
+     return level_to_idx, idx_to_level
+
+ # ============================================================================
+ # MODEL CLASSES
+ # ============================================================================
+
+ class HierarchicalInsectClassifier(nn.Module):
+     def __init__(self, num_classes_per_level):
+         """
+         Args:
+             num_classes_per_level (list): Number of classes for each taxonomic level [family, genus, species]
+         """
+         super(HierarchicalInsectClassifier, self).__init__()
+
+         self.backbone = resnet50(pretrained=True)
+         backbone_output_features = self.backbone.fc.in_features
+         self.backbone.fc = nn.Identity()  # Remove the final fully connected layer
+
+         self.branches = nn.ModuleList()
+         for num_classes in num_classes_per_level:
+             branch = nn.Sequential(
+                 nn.Linear(backbone_output_features, 512),
+                 nn.ReLU(),
+                 nn.Dropout(0.5),
+                 nn.Linear(512, num_classes)
+             )
+             self.branches.append(branch)
+
+         self.num_levels = len(num_classes_per_level)
+
+     def forward(self, x):
+         R0 = self.backbone(x)
+
+         outputs = []
+         for branch in self.branches:
+             outputs.append(branch(R0))
+
+         return outputs
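+ # For a batch of N crops, HierarchicalInsectClassifier.forward returns three logit
+ # tensors, shaped (N, families), (N, genera) and (N, species): one head per
+ # taxonomic level on a shared ResNet-50 trunk. Softmax is applied downstream in
+ # process_classification_results, not inside the model.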
+
+ # ============================================================================
+ # VISUALIZATION UTILITIES
+ # ============================================================================
+
+ class FrameVisualizer:
+     """Modern, slick visualization system for insect detection and tracking"""
+
+     # Modern color palette - vibrant but professional
+     COLORS = [
+         (68, 189, 50),   # Vibrant Green
+         (255, 59, 48),   # Red
+         (0, 122, 255),   # Blue
+         (255, 149, 0),   # Orange
+         (175, 82, 222),  # Purple
+         (255, 204, 0),   # Yellow
+         (50, 173, 230),  # Light Blue
+         (255, 45, 85),   # Pink
+         (48, 209, 88),   # Light Green
+         (90, 200, 250),  # Sky Blue
+         (255, 159, 10),  # Amber
+         (191, 90, 242),  # Lavender
+     ]
+
+     @staticmethod
+     def get_track_color(track_id):
+         """Get a consistent, vibrant color for a track ID"""
+         if track_id is None:
+             return (68, 189, 50)  # Default green for untracked
+
+         # Generate a consistent index from track_id
+         try:
+             track_uuid = uuid.UUID(track_id)
+         except (ValueError, TypeError):
+             track_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(track_id))
+
+         color_index = track_uuid.int % len(FrameVisualizer.COLORS)
+         return FrameVisualizer.COLORS[color_index]
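+     # Note: the mapping above is deterministic (uuid.int mod palette size), so a
+     # given track keeps the same colour across frames; non-UUID ids are first
+     # hashed into UUID space with uuid5.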
+
+     @staticmethod
+     def draw_rounded_rectangle(frame, pt1, pt2, color, thickness, radius=8):
+         """Draw a rounded rectangle"""
+         x1, y1 = pt1
+         x2, y2 = pt2
+
+         # Draw main rectangle
+         cv2.rectangle(frame, (x1 + radius, y1), (x2 - radius, y2), color, thickness)
+         cv2.rectangle(frame, (x1, y1 + radius), (x2, y2 - radius), color, thickness)
+
+         # Draw corners
+         cv2.circle(frame, (x1 + radius, y1 + radius), radius, color, thickness)
+         cv2.circle(frame, (x2 - radius, y1 + radius), radius, color, thickness)
+         cv2.circle(frame, (x1 + radius, y2 - radius), radius, color, thickness)
+         cv2.circle(frame, (x2 - radius, y2 - radius), radius, color, thickness)
+
+     @staticmethod
+     def draw_gradient_background(frame, x1, y1, x2, y2, color, alpha=0.85):
+         """Draw a modern gradient background with rounded corners"""
+         overlay = frame.copy()
+
+         # Create gradient effect
+         height = y2 - y1
+         for i in range(height):
+             intensity = 1.0 - (i / height) * 0.3  # Gradient from top to bottom
+             gradient_color = tuple(int(c * intensity) for c in color)
+             cv2.rectangle(overlay, (x1, y1 + i), (x2, y1 + i + 1), gradient_color, -1)
+
+         # Blend with original frame
+         cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)
+
+     @staticmethod
+     def draw_detection_on_frame(frame, x1, y1, x2, y2, track_id, detection_data):
+         """Draw modern, sleek detection visualization"""
+
+         # Get colors
+         primary_color = FrameVisualizer.get_track_color(track_id)
+
+         # Simple, clean bounding box
+         box_thickness = 2
+
+         # Draw single clean rectangle
+         cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), primary_color, box_thickness)
+
+         # Prepare label content without emojis
+         if track_id is not None:
+             track_short = str(track_id)[:8]
+             track_display = f"ID: {track_short}"
+         else:
+             track_display = "NEW"
+
+         # Classification results without icons
+         classification_lines = []
+
+         for level, key in [("family", "family_confidence"),
+                            ("genus", "genus_confidence"),
+                            ("species", "species_confidence")]:
+             if detection_data.get(level):
+                 conf = detection_data.get(key, 0)
+                 name = detection_data[level]
+                 # Truncate long names
+                 if len(name) > 18:
+                     name = name[:15] + "..."
+                 level_short = level[0].upper()  # F, G, S
+                 classification_lines.append(f"{level_short}: {name}")
+                 classification_lines.append(f" {conf:.1%}")
+
+         if not classification_lines and track_id is None:
+             return
+
+         # Calculate label box dimensions with smaller, lighter font
+         font = cv2.FONT_HERSHEY_SIMPLEX
+         font_scale = 0.45
+         thickness = 1
+         padding = 8
+         line_spacing = 6
+
+         # Calculate text dimensions
+         all_lines = [track_display] + classification_lines
+         text_sizes = [cv2.getTextSize(line, font, font_scale, thickness)[0] for line in all_lines]
+         max_w = max(size[0] for size in text_sizes) if text_sizes else 100
+         text_h = text_sizes[0][1] if text_sizes else 20
+
+         total_h = len(all_lines) * (text_h + line_spacing) + padding * 2
+         label_w = max_w + padding * 2
+
+         # Position label box (above bbox, or below if no space)
+         label_x1 = max(0, int(x1))
+         label_y1 = int(y1) - total_h - 5
+         if label_y1 < 0:  # not enough room above the box, place the label below it
+             label_y1 = int(y2) + 5
+
+         label_x2 = min(frame.shape[1], label_x1 + label_w)
+         label_y2 = min(frame.shape[0], label_y1 + total_h)
+
+         # Draw modern gradient background with rounded corners
+         FrameVisualizer.draw_gradient_background(frame, label_x1, label_y1, label_x2, label_y2,
+                                                  (20, 20, 20), alpha=0.88)
+
+         # Add subtle border
+         FrameVisualizer.draw_rounded_rectangle(frame,
+                                                (label_x1, label_y1),
+                                                (label_x2, label_y2),
+                                                primary_color, 1, radius=6)
+
+         # Draw text with modern styling
+         current_y = label_y1 + padding + text_h
+
+         for i, line in enumerate(all_lines):
+             if i == 0:  # Track ID line - use primary color
+                 text_color = primary_color
+                 line_thickness = 1
+             elif "%" in line:  # Confidence lines - use lighter color
+                 text_color = (160, 160, 160)
+                 line_thickness = 1
+             else:  # Classification name lines - use white
+                 text_color = (255, 255, 255)
+                 line_thickness = 1
+
+             # Add subtle text shadow for readability
+             cv2.putText(frame, line, (label_x1 + padding + 1, current_y + 1),
+                         font, font_scale, (0, 0, 0), 1, cv2.LINE_AA)
+
+             # Main text
+             cv2.putText(frame, line, (label_x1 + padding, current_y),
+                         font, font_scale, text_color, line_thickness, cv2.LINE_AA)
+
+             current_y += text_h + line_spacing
+
+ # ============================================================================
+ # MAIN PROCESSING CLASS
+ # ============================================================================
+
+ class VideoInferenceProcessor:
+     def __init__(self, species_list, yolo_model_path, hierarchical_model_path,
+                  confidence_threshold=0.35, device_id="video_processor"):
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.species_list = species_list
+         self.confidence_threshold = confidence_threshold
+         self.device_id = device_id
+
+         print(f"Using device: {self.device}")
+
+         # Build taxonomy from species list
+         self.taxonomy, self.species_to_genus, self.genus_to_family = get_taxonomy(species_list)
+         self.level_to_idx, self.idx_to_level = create_mappings(self.taxonomy, species_list)
+         self.family_list = sorted(self.taxonomy[1])
+         self.genus_list = sorted(list(self.taxonomy[2].keys()))
+
+         # Load models
+         print(f"Loading YOLO model from {yolo_model_path}")
+         self.yolo_model = YOLO(yolo_model_path)
+
+         print(f"Loading hierarchical model from {hierarchical_model_path}")
+         checkpoint = torch.load(hierarchical_model_path, map_location='cpu')
+         state_dict = checkpoint.get("model_state_dict", checkpoint)
+
+         num_classes_per_level = [len(self.family_list), len(self.genus_list), len(self.species_list)]
+         print(f"Model architecture: {num_classes_per_level} classes per level")
+
+         self.classification_model = HierarchicalInsectClassifier(num_classes_per_level)
+         self.classification_model.load_state_dict(state_dict, strict=False)
+         self.classification_model.to(self.device)
+         self.classification_model.eval()
+
+         # Classification preprocessing
+         self.classification_transform = transforms.Compose([
+             transforms.Resize((768, 768)),
+             transforms.CenterCrop(640),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+         self.all_detections = []
+         self.frame_count = 0
+         print("Models loaded successfully!")
+
+     # ------------------------------------------------------------------------
+     # UTILITY METHODS
+     # ------------------------------------------------------------------------
+
+     def convert_bbox_to_normalized(self, x, y, x2, y2, width, height):
+         x_center = (x + x2) / 2.0 / width
+         y_center = (y + y2) / 2.0 / height
+         norm_width = (x2 - x) / width
+         norm_height = (y2 - y) / height
+         return [x_center, y_center, norm_width, norm_height]
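+     # Example: a box (50, 50, 150, 150) in a 1000x500 frame normalizes to
+     # [0.10, 0.20, 0.10, 0.20], the YOLO-style [x_center, y_center, w, h] layout.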
+
+     def store_detection(self, detection_data, timestamp, frame_time_seconds, track_id, bbox):
+         """Store detection for final aggregation"""
+         payload = {
+             "timestamp": timestamp,
+             "frame_time_seconds": frame_time_seconds,
+             "track_id": track_id,
+             "bbox": bbox,
+             **detection_data
+         }
+
+         self.all_detections.append(payload)
+
+         # Print frame-by-frame prediction (only for processed frames)
+         track_display = str(track_id)[:8] if track_id else "NEW"
+         species = payload.get('species', 'Unknown')
+         species_conf = payload.get('species_confidence', 0)
+         print(f"Processed {frame_time_seconds:6.2f}s | Track {track_display} | {species} ({species_conf:.1%})")
+
+     # ------------------------------------------------------------------------
+     # DETECTION AND CLASSIFICATION METHODS
+     # ------------------------------------------------------------------------
+
+     def process_classification_results(self, classification_outputs):
+         """Process raw classification outputs to get predictions and probabilities"""
+         # Apply softmax to get probabilities
+         family_probs = torch.softmax(classification_outputs[0], dim=1).cpu().numpy().flatten()
+         genus_probs = torch.softmax(classification_outputs[1], dim=1).cpu().numpy().flatten()
+         species_probs = torch.softmax(classification_outputs[2], dim=1).cpu().numpy().flatten()
+
+         # Get top predictions
+         family_idx = np.argmax(family_probs)
+         genus_idx = np.argmax(genus_probs)
+         species_idx = np.argmax(species_probs)
+
+         # Get names
+         family_name = self.family_list[family_idx] if family_idx < len(self.family_list) else f"Family_{family_idx}"
+         genus_name = self.genus_list[genus_idx] if genus_idx < len(self.genus_list) else f"Genus_{genus_idx}"
+         species_name = self.species_list[species_idx] if species_idx < len(self.species_list) else f"Species_{species_idx}"
+
+         detection_data = {
+             "family": family_name,
+             "genus": genus_name,
+             "species": species_name,
+             "family_confidence": float(family_probs[family_idx]),
+             "genus_confidence": float(genus_probs[genus_idx]),
+             "species_confidence": float(species_probs[species_idx]),
+             "family_probs": family_probs.tolist(),
+             "genus_probs": genus_probs.tolist(),
+             "species_probs": species_probs.tolist()
+         }
+
+         return detection_data
+
+     def _extract_yolo_detections(self, frame):
+         """Extract and validate YOLO detections from frame"""
+         with torch.no_grad():
+             results = self.yolo_model(frame, conf=self.confidence_threshold, iou=0.5, verbose=False)
+
+         detections = results[0].boxes
+         valid_detections = []
+         valid_detection_data = []
+
+         if detections is not None and len(detections) > 0:
+             height, width = frame.shape[:2]
+
+             for box in detections:
+                 xyxy = box.xyxy.cpu().numpy().flatten()
+                 confidence = box.conf.cpu().numpy().item()
+
+                 if confidence < self.confidence_threshold:
+                     continue
+
+                 x1, y1, x2, y2 = xyxy[:4]
+
+                 # Clamp coordinates
+                 x1, y1, x2, y2 = max(0, x1), max(0, y1), min(width, x2), min(height, y2)
+
+                 if x2 <= x1 or y2 <= y1:
+                     continue
+
+                 # Store detection for tracking (x1, y1, x2, y2 format)
+                 valid_detections.append([x1, y1, x2, y2])
+                 valid_detection_data.append({
+                     'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2,
+                     'confidence': confidence
+                 })
+
+         return valid_detections, valid_detection_data
+
+     def _classify_detection(self, frame, x1, y1, x2, y2):
+         """Perform hierarchical classification on a detection crop"""
+         crop = frame[int(y1):int(y2), int(x1):int(x2)]
+         crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
+         pil_img = Image.fromarray(crop_rgb)
+         input_tensor = self.classification_transform(pil_img).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             classification_outputs = self.classification_model(input_tensor)
+
+         return self.process_classification_results(classification_outputs)
+
+     def _process_single_detection(self, frame, det_data, track_id, frame_time_seconds):
+         """Process a single detection: classify, store, and visualize"""
+         x1, y1, x2, y2 = det_data['x1'], det_data['y1'], det_data['x2'], det_data['y2']
+         height, width = frame.shape[:2]
+
+         # Perform classification
+         detection_data = self._classify_detection(frame, x1, y1, x2, y2)
+
+         # Store detection for aggregation
+         bbox = self.convert_bbox_to_normalized(x1, y1, x2, y2, width, height)
+         timestamp = datetime.now().isoformat()
+         self.store_detection(detection_data, timestamp, frame_time_seconds, track_id, bbox)
+
+         # Visualization
+         FrameVisualizer.draw_detection_on_frame(frame, x1, y1, x2, y2, track_id, detection_data)
+
+     def process_frame(self, frame, frame_time_seconds, tracker, global_frame_count):
+         """Process a single frame from the video."""
+         self.frame_count += 1
+
+         # Extract YOLO detections
+         valid_detections, valid_detection_data = self._extract_yolo_detections(frame)
+
+         # Update tracker with detections
+         track_ids = tracker.update(valid_detections, global_frame_count)
+
+         # Process each detection with its track ID
+         for i, det_data in enumerate(valid_detection_data):
+             track_id = track_ids[i] if i < len(track_ids) else None
+             self._process_single_detection(frame, det_data, track_id, frame_time_seconds)
+
+         return frame
+
+     # ------------------------------------------------------------------------
+     # AGGREGATION AND RESULTS METHODS
+     # ------------------------------------------------------------------------
+
+     def hierarchical_aggregation(self):
+         """
+         Perform hierarchical aggregation of results per track ID.
+
+         CONFIDENCE CALCULATION EXPLANATION:
+         1. For each track, average the softmax probabilities across all detections
+         2. Use hierarchical selection:
+            - Find best family (highest averaged family probability)
+            - Find best genus within that family (highest averaged genus probability among genera in that family)
+            - Find best species within that genus (highest averaged species probability among species in that genus)
+         3. Final confidence = averaged probability of the selected class at each taxonomic level
+         """
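+         # Worked example with illustrative numbers: two detections whose family
+         # probabilities are [0.6, 0.4] and [0.8, 0.2] average to [0.7, 0.3], so
+         # family index 0 is selected with confidence 0.70; the genus argmax is then
+         # restricted to genera of that family, and species to species of that genus.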
+         print("\n=== Performing Hierarchical Aggregation ===")
+
+         # Group detections by track ID
+         track_detections = defaultdict(list)
+         for detection in self.all_detections:
+             if detection['track_id'] is not None:
+                 track_detections[detection['track_id']].append(detection)
+
+         aggregated_results = []
+
+         for track_id, detections in track_detections.items():
+             print(f"\nProcessing Track ID: {track_id} ({len(detections)} detections)")
+
+             # Aggregate probabilities across all detections for this track
+             prob_sums = [
+                 np.zeros(len(self.family_list)),
+                 np.zeros(len(self.genus_list)),
+                 np.zeros(len(self.species_list))
+             ]
+
+             for detection in detections:
+                 prob_sums[0] += np.array(detection['family_probs'])
+                 prob_sums[1] += np.array(detection['genus_probs'])
+                 prob_sums[2] += np.array(detection['species_probs'])
+
+             # Average the probabilities
+             prob_avgs = [prob_sum / len(detections) for prob_sum in prob_sums]
+
+             # Hierarchical selection: start with family
+             best_family_idx = np.argmax(prob_avgs[0])
+             best_family = self.family_list[best_family_idx]
+             best_family_prob = prob_avgs[0][best_family_idx]
+             print(f" Best family: {best_family} (prob: {best_family_prob:.3f})")
+
+             # Find genera belonging to this family
+             family_genera_indices = [i for i, genus in enumerate(self.genus_list)
+                                      if genus in self.genus_to_family and self.genus_to_family[genus] == best_family]
+
+             if family_genera_indices:
+                 family_genus_probs = prob_avgs[1][family_genera_indices]
+                 best_genus_idx = family_genera_indices[np.argmax(family_genus_probs)]
+             else:
+                 best_genus_idx = np.argmax(prob_avgs[1])
+
+             best_genus = self.genus_list[best_genus_idx]
+             best_genus_prob = prob_avgs[1][best_genus_idx]
+             print(f" Best genus: {best_genus} (prob: {best_genus_prob:.3f})")
+
+             # Find species belonging to this genus
+             genus_species_indices = [i for i, species in enumerate(self.species_list)
+                                      if species in self.species_to_genus and self.species_to_genus[species] == best_genus]
+
+             if genus_species_indices:
+                 genus_species_probs = prob_avgs[2][genus_species_indices]
+                 best_species_idx = genus_species_indices[np.argmax(genus_species_probs)]
+             else:
+                 best_species_idx = np.argmax(prob_avgs[2])
+
+             best_species = self.species_list[best_species_idx]
+             best_species_prob = prob_avgs[2][best_species_idx]
+             print(f" Best species: {best_species} (prob: {best_species_prob:.3f})")
+
+             # Calculate track statistics
+             frame_times = [d['frame_time_seconds'] for d in detections]
+             first_frame, last_frame = min(frame_times), max(frame_times)
+
+             aggregated_results.append({
+                 'track_id': track_id,
+                 'num_detections': len(detections),
+                 'first_frame_time': first_frame,
+                 'last_frame_time': last_frame,
+                 'duration': last_frame - first_frame,
+                 'final_family': best_family,
+                 'final_genus': best_genus,
+                 'final_species': best_species,
+                 'family_confidence': best_family_prob,
+                 'genus_confidence': best_genus_prob,
+                 'species_confidence': best_species_prob
+             })
+
+         return aggregated_results
+
+     def print_simplified_summary(self, aggregated_results):
+         """Print a simplified, readable summary of tracking results"""
+         print("\n" + "="*60)
+         print("🐛 INSECT TRACKING SUMMARY")
+         print("="*60)
+
+         if not aggregated_results:
+             print("No insects were tracked in this video.")
+             return
+
+         # Sort by number of detections (most active insects first)
+         sorted_results = sorted(aggregated_results, key=lambda x: x['num_detections'], reverse=True)
+
+         for i, result in enumerate(sorted_results, 1):
+             duration = result['duration']
+             detection_count = result['num_detections']
+
+             print(f"\n🐞 Insect {i}:")
+             print(f" Detections: {detection_count}")
+             print(f" Duration: {duration:.1f}s")
+             print(f" Family: {result['final_family']}")
+             print(f" Genus: {result['final_genus']}")
+             print(f" Species: {result['final_species']}")
+             print(f" Confidence: {result['species_confidence']:.1%}")
+
+         print(f"\n📈 Total: {len(aggregated_results)} unique insects tracked")
+         print("="*60)
+
+     def save_results_table(self, aggregated_results, output_path):
+         """Save aggregated results to CSV"""
+         df = pd.DataFrame(aggregated_results)
+         if not df.empty:  # sort_values would raise on an empty frame with no columns
+             df = df.sort_values('track_id')
+         csv_path = str(output_path).replace('.mp4', '_results.csv')
+         df.to_csv(csv_path, index=False)
+
+         print(f"\n📊 Results saved to: {csv_path}")
+
+         # Print simplified summary
+         self.print_simplified_summary(aggregated_results)
+
+         return df
+
+ # ============================================================================
+ # VIDEO PROCESSING FUNCTIONS
+ # ============================================================================
+
+ def process_video(video_path, processor, output_video_path=None, show_video=False, tracker_max_frames=30, fps=None):
+     """Process an MP4 video file frame by frame.
+
+     Args:
+         fps (float, optional): FPS for processing and output. If provided, frames will be skipped
+             to match this rate and the output video will use this FPS. If None, all frames are processed.
+     """
+
+     if not os.path.exists(video_path):
+         raise FileNotFoundError(f"Video file not found: {video_path}")
+
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise ValueError(f"Could not open video file: {video_path}")
+
+     # Get video properties
+     input_fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     duration = total_frames / input_fps if input_fps > 0 else 0
+
+     print(f"Processing video: {video_path}")
+     print(f"Properties: {total_frames} frames, {input_fps:.2f} FPS, {duration:.2f}s duration")
+
+     # Initialize tracker and output writer
+     tracker = InsectTracker(height, width, max_frames=tracker_max_frames, debug=False)
+     out = None
+
+     if output_video_path:
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+         write_fps = fps if fps is not None else input_fps
+         out = cv2.VideoWriter(output_video_path, fourcc, write_fps, (width, height))
+         print(f"Output video: {output_video_path} at {write_fps:.2f} FPS")
+
+     frame_number = 0
+     processed_frame_count = 0
+     start_time = time.time()
+
+     # Calculate frame skip interval if fps is specified
+     frame_skip_interval = None
+     if fps is not None and fps > 0 and input_fps > 0:
+         frame_skip_interval = max(1, int(input_fps / fps))
+         print(f"Processing every {frame_skip_interval} frame(s) to achieve ~{fps:.2f} FPS")
+         print(f"Tracker will only receive the processed frames, not all {total_frames} frames")
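+     # Example: a 30 FPS input with fps=5.0 gives an interval of 6, so frames
+     # 0, 6, 12, ... are processed and written; played back at 5 FPS the output
+     # keeps the original duration.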
+
+     try:
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+
+             frame_time_seconds = frame_number / input_fps if input_fps > 0 else 0
+
+             # Check if we should process this frame
+             should_process = True
+             if frame_skip_interval is not None:
+                 should_process = (frame_number % frame_skip_interval == 0)
+
+             if should_process:
+                 # Process the frame
+                 processed_frame = processor.process_frame(frame, frame_time_seconds, tracker, frame_number)
+                 processed_frame_count += 1
+                 last_processed_frame = processed_frame.copy()  # Keep for frame duplication
+
+                 # Show video if requested
+                 if show_video:
+                     cv2.imshow('Video Inference', processed_frame)
+                     if cv2.waitKey(1) & 0xFF == ord('q'):
+                         print("User requested quit")
+                         break
+
+                 # If we're skipping frames, skip ahead to save time
+                 if frame_skip_interval is not None and frame_skip_interval > 1:
+                     # Skip the next (frame_skip_interval - 1) frames
+                     for _ in range(frame_skip_interval - 1):
+                         ret, _ = cap.read()
+                         if not ret:
+                             break
+                         frame_number += 1
+             else:
+                 # Use the last processed frame to maintain video length
+                 if 'last_processed_frame' in locals():
+                     processed_frame = last_processed_frame.copy()
+                 else:
+                     processed_frame = frame  # Fallback for first frames
+
+             # Write to output video if requested (always write to maintain duration)
+             if out:
+                 out.write(processed_frame)
+
+             frame_number += 1
+
+             # Progress update based on processed frames only
+             if should_process and processed_frame_count % 25 == 0:
+                 if frame_skip_interval is not None:
+                     estimated_total_processed = total_frames // frame_skip_interval
+                     progress = (processed_frame_count / estimated_total_processed) * 100 if estimated_total_processed > 0 else 0
+                     print(f"Processed: {processed_frame_count}/{estimated_total_processed} frames ({progress:.1f}%)")
+                 else:
+                     progress = (frame_number / total_frames) * 100 if total_frames > 0 else 0
+                     print(f"Processed: {processed_frame_count}/{total_frames} frames ({progress:.1f}%)")
+
+     finally:
+         cap.release()
+         if out:
+             out.release()
+         if show_video:
+             cv2.destroyAllWindows()
+
+     processing_time = time.time() - start_time
+     print(f"\nProcessing complete! {processed_frame_count}/{frame_number} frames processed in {processing_time:.2f}s")
+     if processed_frame_count > 0:
+         print(f"Processing speed: {processed_frame_count/processing_time:.2f} FPS, Detections: {len(processor.all_detections)}")
+     else:
+         print(f"No frames processed, Detections: {len(processor.all_detections)}")
+
+     # Perform hierarchical aggregation and save results
+     aggregated_results = processor.hierarchical_aggregation()
+     if output_video_path:
+         processor.save_results_table(aggregated_results, output_video_path)
+
+     return aggregated_results
+
+ # ============================================================================
+ # MAIN ENTRY POINT FUNCTIONS
+ # ============================================================================
+
+ def inference(species_list, yolo_model_path, hierarchical_model_path, confidence_threshold,
+               video_path, output_path, tracker_max_frames, fps=None):
+     """
+     Run inference on a single video file.
+
+     Args:
+         species_list (list): List of species names for classification
+         yolo_model_path (str): Path to YOLO model weights
+         hierarchical_model_path (str): Path to hierarchical classification model weights
+         confidence_threshold (float): Confidence threshold for detections
+         video_path (str): Path to input video file
+         output_path (str): Path for output video file (including filename)
+         tracker_max_frames (int): Maximum frames for tracker context
+         fps (float, optional): Processing and output FPS. If provided, frames will be skipped
+             to match this rate and the output video will use this FPS. If None, all frames are processed.
+
+     Returns:
+         dict: Summary of processing results
+     """
+     # Check if input video exists
+     if not os.path.exists(video_path):
+         error_msg = f"Video file not found: {video_path}"
+         print(f"Error: {error_msg}")
+         return {"error": error_msg}
+
+     # Create output directory if it doesn't exist
+     output_dir = os.path.dirname(output_path)
+     if output_dir:
+         os.makedirs(output_dir, exist_ok=True)
+
+     print(f"Processing single video: {video_path}")
+
+     # Create processor instance
+     print("Initializing models...")
+     processor = VideoInferenceProcessor(
+         species_list=species_list,
+         yolo_model_path=yolo_model_path,
+         hierarchical_model_path=hierarchical_model_path,
+         confidence_threshold=confidence_threshold
+     )
+
+     # Track processing results
+     processing_results = {
+         "video_file": os.path.basename(video_path),
+         "output_path": output_path,
+         "success": False,
+         "detections": 0,
+         "tracks": 0,
+         "error": None
+     }
+
+     print(f"\n{'='*20}\nProcessing: {video_path}\n{'='*20}")
+
+     try:
+         # Process the video
+         aggregated_results = process_video(
+             video_path=video_path,
+             processor=processor,
+             output_video_path=output_path,
+             show_video=False,
+             tracker_max_frames=tracker_max_frames,
+             fps=fps
+         )
+
+         processing_results.update({
+             "success": True,
+             "detections": len(processor.all_detections),
+             "tracks": len(aggregated_results)
+         })
+
+         print(f"Finished processing. Output saved to {output_path}")
+
+     except Exception as e:
+         error_msg = f"Failed to process {os.path.basename(video_path)}: {str(e)}"
+         print(f"Error: {error_msg}")
+         processing_results["error"] = error_msg
+
+     if processing_results["success"]:
+         print("\nProcessing complete!")
+         print(f"Total detections: {processing_results['detections']}")
+         print(f"Total tracks: {processing_results['tracks']}")
+
+     return processing_results
+
+
+ def main():
+     """Example usage for processing a single video."""
+     # Define your species list (replace with your actual species)
+     species_list = [
+         "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
+         "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
+         "Eristalis tenax", "unknown"
+     ]
+
+     # Paths (replace with your actual paths)
+     video_path = "input_videos/sample_video.mp4"
+     output_path = "output_videos/sample_video_predictions.mp4"
+     yolo_model_path = "weights/yolo_model.pt"
+     hierarchical_model_path = "weights/hierarchical_model.pth"
+
+     # Run inference
+     results = inference(
+         species_list=species_list,
+         yolo_model_path=yolo_model_path,
+         hierarchical_model_path=hierarchical_model_path,
+         confidence_threshold=0.35,
+         video_path=video_path,
+         output_path=output_path,
+         tracker_max_frames=60,
+         fps=None  # Set to e.g. 5.0 to process only 5 frames per second
+     )
+
+     return results
+
+
+ if __name__ == "__main__":
+     main()