bplusplus 1.2.2__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of bplusplus might be problematic.

bplusplus/inference.py ADDED
@@ -0,0 +1,891 @@
+ import cv2
+ import time
+ import os
+ import sys
+ import numpy as np
+ import json
+ import pandas as pd
+ from datetime import datetime
+ from pathlib import Path
+ from .tracker import InsectTracker
+ import torch
+ from ultralytics import YOLO
+ from torchvision import transforms
+ from PIL import Image
+ import torch.nn as nn
+ from torchvision.models import resnet50
+ import requests
+ import logging
+ from collections import defaultdict
+ import uuid
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # ============================================================================
+ # UTILITY FUNCTIONS
+ # ============================================================================
+
+ def get_taxonomy(species_list):
+     """
+     Retrieves taxonomic information for a list of species from the GBIF API.
+     Creates a hierarchical taxonomy dictionary with family, genus, and species relationships.
+     """
+     taxonomy = {1: [], 2: {}, 3: {}}
+     species_to_genus = {}
+     genus_to_family = {}
+
+     logger.info(f"Building taxonomy from GBIF for {len(species_list)} species")
+
+     print(f"\n{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
+     print("-" * 80)
+
+     for species_name in species_list:
+         url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
+         try:
+             # A timeout keeps a slow GBIF response from hanging the whole run
+             response = requests.get(url, timeout=10)
+             data = response.json()
+
+             if data.get('status') in ['ACCEPTED', 'SYNONYM']:
+                 family = data.get('family')
+                 genus = data.get('genus')
+
+                 if family and genus:
+                     print(f"{species_name:<30} {family:<20} {genus:<20} OK")
+
+                     species_to_genus[species_name] = genus
+                     genus_to_family[genus] = family
+
+                     if family not in taxonomy[1]:
+                         taxonomy[1].append(family)
+
+                     taxonomy[2][genus] = family
+                     taxonomy[3][species_name] = genus
+                 else:
+                     print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                     logger.error(f"Species '{species_name}' found but missing family/genus data")
+             else:
+                 print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                 logger.error(f"Species '{species_name}' not found in GBIF")
+
+         except Exception as e:
+             print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
+             logger.error(f"Error retrieving data for '{species_name}': {str(e)}")
+
+     taxonomy[1] = sorted(list(set(taxonomy[1])))
+     print("-" * 80)
+
+     # Print indices
+     for level, name, items in [(1, "Family", taxonomy[1]), (2, "Genus", taxonomy[2].keys()), (3, "Species", species_list)]:
+         print(f"\n{name} indices:")
+         for i, item in enumerate(items):
+             print(f"  {i}: {item}")
+
+     logger.info(f"Taxonomy built: {len(taxonomy[1])} families, {len(taxonomy[2])} genera, {len(taxonomy[3])} species")
+     return taxonomy, species_to_genus, genus_to_family
+
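+ # For example, with species_list = ["Apis mellifera", "Bombus terrestris"] the GBIF
+ # lookups would (assuming both names resolve) yield a structure shaped like:
+ #   taxonomy         = {1: ['Apidae'], 2: {'Apis': 'Apidae', 'Bombus': 'Apidae'},
+ #                       3: {'Apis mellifera': 'Apis', 'Bombus terrestris': 'Bombus'}}
+ #   species_to_genus = {'Apis mellifera': 'Apis', 'Bombus terrestris': 'Bombus'}
+ #   genus_to_family  = {'Apis': 'Apidae', 'Bombus': 'Apidae'}
+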
+ def create_mappings(taxonomy):
+     """Create index mappings from taxonomy"""
+     level_to_idx = {}
+     idx_to_level = {}
+
+     for level, labels in taxonomy.items():
+         if isinstance(labels, list):
+             level_to_idx[level] = {label: idx for idx, label in enumerate(labels)}
+             idx_to_level[level] = {idx: label for idx, label in enumerate(labels)}
+         else:  # Dictionary
+             level_to_idx[level] = {label: idx for idx, label in enumerate(labels.keys())}
+             idx_to_level[level] = {idx: label for idx, label in enumerate(labels.keys())}
+
+     return level_to_idx, idx_to_level
+
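+ # E.g. with taxonomy[1] = ['Apidae', 'Syrphidae'] this gives
+ # level_to_idx[1] = {'Apidae': 0, 'Syrphidae': 1} and idx_to_level[1] = {0: 'Apidae', 1: 'Syrphidae'};
+ # dict-valued levels (genus, species) are indexed in key-insertion order.
+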
+ # ============================================================================
+ # MODEL CLASSES
+ # ============================================================================
+
+ class HierarchicalInsectClassifier(nn.Module):
+     def __init__(self, num_classes_per_level):
+         """
+         Args:
+             num_classes_per_level (list): Number of classes for each taxonomic level [family, genus, species]
+         """
+         super(HierarchicalInsectClassifier, self).__init__()
+
+         self.backbone = resnet50(pretrained=True)
+         backbone_output_features = self.backbone.fc.in_features
+         self.backbone.fc = nn.Identity()  # Remove the final fully connected layer
+
+         self.branches = nn.ModuleList()
+         for num_classes in num_classes_per_level:
+             branch = nn.Sequential(
+                 nn.Linear(backbone_output_features, 512),
+                 nn.ReLU(),
+                 nn.Dropout(0.5),
+                 nn.Linear(512, num_classes)
+             )
+             self.branches.append(branch)
+
+         self.num_levels = len(num_classes_per_level)
+
+     def forward(self, x):
+         R0 = self.backbone(x)
+
+         outputs = []
+         for branch in self.branches:
+             outputs.append(branch(R0))
+
+         return outputs
+
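+ # A forward pass returns one raw-logits tensor per taxonomic level, all computed
+ # from the shared ResNet-50 features: for a batch of size B and, say, 5 families,
+ # 8 genera and 9 species, the output is [(B, 5), (B, 8), (B, 9)]; softmax is
+ # applied downstream in process_classification_results().
+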
+ # ============================================================================
+ # VISUALIZATION UTILITIES
+ # ============================================================================
+
+ class FrameVisualizer:
+     """Modern, slick visualization system for insect detection and tracking"""
+
+     # Modern color palette - vibrant but professional
+     COLORS = [
+         (68, 189, 50),   # Vibrant Green
+         (255, 59, 48),   # Red
+         (0, 122, 255),   # Blue
+         (255, 149, 0),   # Orange
+         (175, 82, 222),  # Purple
+         (255, 204, 0),   # Yellow
+         (50, 173, 230),  # Light Blue
+         (255, 45, 85),   # Pink
+         (48, 209, 88),   # Light Green
+         (90, 200, 250),  # Sky Blue
+         (255, 159, 10),  # Amber
+         (191, 90, 242),  # Lavender
+     ]
+
+     @staticmethod
+     def get_track_color(track_id):
+         """Get a consistent, vibrant color for a track ID"""
+         if track_id is None:
+             return (68, 189, 50)  # Default green for untracked
+
+         # Generate a consistent index from track_id
+         try:
+             track_uuid = uuid.UUID(track_id)
+         except (ValueError, TypeError):
+             track_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(track_id))
+
+         color_index = track_uuid.int % len(FrameVisualizer.COLORS)
+         return FrameVisualizer.COLORS[color_index]
+
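+     # The UUID hash makes the colour assignment deterministic: the same track_id
+     # always lands on the same palette entry (uuid.uuid5 provides a stable
+     # fallback for non-UUID ids), so a tracked insect keeps one colour for its
+     # whole track.
+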
+     @staticmethod
+     def draw_rounded_rectangle(frame, pt1, pt2, color, thickness, radius=8):
+         """Draw a rounded rectangle"""
+         x1, y1 = pt1
+         x2, y2 = pt2
+
+         # Draw main rectangle
+         cv2.rectangle(frame, (x1 + radius, y1), (x2 - radius, y2), color, thickness)
+         cv2.rectangle(frame, (x1, y1 + radius), (x2, y2 - radius), color, thickness)
+
+         # Draw corners
+         cv2.circle(frame, (x1 + radius, y1 + radius), radius, color, thickness)
+         cv2.circle(frame, (x2 - radius, y1 + radius), radius, color, thickness)
+         cv2.circle(frame, (x1 + radius, y2 - radius), radius, color, thickness)
+         cv2.circle(frame, (x2 - radius, y2 - radius), radius, color, thickness)
+
+     @staticmethod
+     def draw_gradient_background(frame, x1, y1, x2, y2, color, alpha=0.85):
+         """Draw a modern gradient background with rounded corners"""
+         overlay = frame.copy()
+
+         # Create gradient effect
+         height = y2 - y1
+         for i in range(height):
+             intensity = 1.0 - (i / height) * 0.3  # Gradient from top to bottom
+             gradient_color = tuple(int(c * intensity) for c in color)
+             cv2.rectangle(overlay, (x1, y1 + i), (x2, y1 + i + 1), gradient_color, -1)
+
+         # Blend with original frame
+         cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)
+
+     @staticmethod
+     def draw_detection_on_frame(frame, x1, y1, x2, y2, track_id, detection_data):
+         """Draw modern, sleek detection visualization"""
+
+         # Get colors
+         primary_color = FrameVisualizer.get_track_color(track_id)
+
+         # Simple, clean bounding box
+         box_thickness = 2
+
+         # Draw single clean rectangle
+         cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), primary_color, box_thickness)
+
+         # Prepare label content without emojis
+         if track_id is not None:
+             track_short = str(track_id)[:8]
+             track_display = f"ID: {track_short}"
+         else:
+             track_display = "NEW"
+
+         # Classification results without icons
+         classification_lines = []
+
+         for level, key in [("family", "family_confidence"),
+                            ("genus", "genus_confidence"),
+                            ("species", "species_confidence")]:
+             if detection_data.get(level):
+                 conf = detection_data.get(key, 0)
+                 name = detection_data[level]
+                 # Truncate long names
+                 if len(name) > 18:
+                     name = name[:15] + "..."
+                 level_short = level[0].upper()  # F, G, S
+                 classification_lines.append(f"{level_short}: {name}")
+                 classification_lines.append(f"  {conf:.1%}")
+
+         if not classification_lines and track_id is None:
+             return
+
+         # Calculate label box dimensions with smaller, lighter font
+         font = cv2.FONT_HERSHEY_SIMPLEX
+         font_scale = 0.45
+         thickness = 1
+         padding = 8
+         line_spacing = 6
+
+         # Calculate text dimensions
+         all_lines = [track_display] + classification_lines
+         text_sizes = [cv2.getTextSize(line, font, font_scale, thickness)[0] for line in all_lines]
+         max_w = max(size[0] for size in text_sizes) if text_sizes else 100
+         text_h = text_sizes[0][1] if text_sizes else 20
+
+         total_h = len(all_lines) * (text_h + line_spacing) + padding * 2
+         label_w = max_w + padding * 2
+
+         # Position label box above the bbox, or below it if there is no room above
+         label_x1 = max(0, int(x1))
+         label_y1 = int(y1) - total_h - 5
+         if label_y1 < 0:
+             label_y1 = int(y2) + 5
+
+         label_x2 = min(frame.shape[1], label_x1 + label_w)
+         label_y2 = min(frame.shape[0], label_y1 + total_h)
+
+         # Draw modern gradient background with rounded corners
+         FrameVisualizer.draw_gradient_background(frame, label_x1, label_y1, label_x2, label_y2,
+                                                  (20, 20, 20), alpha=0.88)
+
+         # Add subtle border
+         FrameVisualizer.draw_rounded_rectangle(frame,
+                                                (label_x1, label_y1),
+                                                (label_x2, label_y2),
+                                                primary_color, 1, radius=6)
+
+         # Draw text with modern styling
+         current_y = label_y1 + padding + text_h
+
+         for i, line in enumerate(all_lines):
+             if i == 0:  # Track ID line - use primary color
+                 text_color = primary_color
+                 line_thickness = 1
+             elif "%" in line:  # Confidence lines - use lighter color
+                 text_color = (160, 160, 160)
+                 line_thickness = 1
+             else:  # Classification name lines - use white
+                 text_color = (255, 255, 255)
+                 line_thickness = 1
+
+             # Add subtle text shadow for readability
+             cv2.putText(frame, line, (label_x1 + padding + 1, current_y + 1),
+                         font, font_scale, (0, 0, 0), 1, cv2.LINE_AA)
+
+             # Main text
+             cv2.putText(frame, line, (label_x1 + padding, current_y),
+                         font, font_scale, text_color, line_thickness, cv2.LINE_AA)
+
+             current_y += text_h + line_spacing
+
+ # ============================================================================
+ # MAIN PROCESSING CLASS
+ # ============================================================================
+
+ class VideoInferenceProcessor:
+     def __init__(self, species_list, yolo_model_path, hierarchical_model_path,
+                  confidence_threshold=0.35, device_id="video_processor"):
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.species_list = species_list
+         self.confidence_threshold = confidence_threshold
+         self.device_id = device_id
+
+         print(f"Using device: {self.device}")
+
+         # Build taxonomy from species list
+         self.taxonomy, self.species_to_genus, self.genus_to_family = get_taxonomy(species_list)
+         self.level_to_idx, self.idx_to_level = create_mappings(self.taxonomy)
+         self.family_list = self.taxonomy[1]
+         self.genus_list = list(self.taxonomy[2].keys())
+
+         # Load models
+         print(f"Loading YOLO model from {yolo_model_path}")
+         self.yolo_model = YOLO(yolo_model_path)
+
+         print(f"Loading hierarchical model from {hierarchical_model_path}")
+         checkpoint = torch.load(hierarchical_model_path, map_location='cpu')
+         state_dict = checkpoint.get("model_state_dict", checkpoint)
+
+         num_classes_per_level = [len(self.family_list), len(self.genus_list), len(self.species_list)]
+         print(f"Model architecture: {num_classes_per_level} classes per level")
+
+         self.classification_model = HierarchicalInsectClassifier(num_classes_per_level)
+         # strict=False tolerates missing/unexpected keys in the checkpoint
+         self.classification_model.load_state_dict(state_dict, strict=False)
+         self.classification_model.to(self.device)
+         self.classification_model.eval()
+
+         # Classification preprocessing
+         self.classification_transform = transforms.Compose([
+             transforms.Resize((768, 768)),
+             transforms.CenterCrop(640),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+         self.all_detections = []
+         self.frame_count = 0
+         print("Models loaded successfully!")
+
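+     # Note: each detection crop is resized to 768x768, centre-cropped to 640x640,
+     # and normalized with the standard ImageNet mean/std that the ResNet-50
+     # backbone was pretrained on.
+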
+     # ------------------------------------------------------------------------
+     # UTILITY METHODS
+     # ------------------------------------------------------------------------
+
+     def convert_bbox_to_normalized(self, x, y, x2, y2, width, height):
+         x_center = (x + x2) / 2.0 / width
+         y_center = (y + y2) / 2.0 / height
+         norm_width = (x2 - x) / width
+         norm_height = (y2 - y) / height
+         return [x_center, y_center, norm_width, norm_height]
+
+     def store_detection(self, detection_data, timestamp, frame_time_seconds, track_id, bbox):
+         """Store detection for final aggregation"""
+         payload = {
+             "timestamp": timestamp,
+             "frame_time_seconds": frame_time_seconds,
+             "track_id": track_id,
+             "bbox": bbox,
+             **detection_data
+         }
+
+         self.all_detections.append(payload)
+
+         # Print frame-by-frame prediction (only for processed frames)
+         track_display = str(track_id)[:8] if track_id else "NEW"
+         species = payload.get('species', 'Unknown')
+         species_conf = payload.get('species_confidence', 0)
+         print(f"Processed {frame_time_seconds:6.2f}s | Track {track_display} | {species} ({species_conf:.1%})")
+
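+     # Worked example: a box (x=100, y=50, x2=300, y2=250) in a 1000x500 frame
+     # normalizes to [0.2, 0.3, 0.2, 0.4] (centre-x, centre-y, width, height),
+     # i.e. the usual YOLO-style normalized bbox format.
+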
+     # ------------------------------------------------------------------------
+     # DETECTION AND CLASSIFICATION METHODS
+     # ------------------------------------------------------------------------
+
+     def process_classification_results(self, classification_outputs):
+         """Process raw classification outputs to get predictions and probabilities"""
+         # Apply softmax to get probabilities
+         family_probs = torch.softmax(classification_outputs[0], dim=1).cpu().numpy().flatten()
+         genus_probs = torch.softmax(classification_outputs[1], dim=1).cpu().numpy().flatten()
+         species_probs = torch.softmax(classification_outputs[2], dim=1).cpu().numpy().flatten()
+
+         # Get top predictions
+         family_idx = np.argmax(family_probs)
+         genus_idx = np.argmax(genus_probs)
+         species_idx = np.argmax(species_probs)
+
+         # Get names
+         family_name = self.family_list[family_idx] if family_idx < len(self.family_list) else f"Family_{family_idx}"
+         genus_name = self.genus_list[genus_idx] if genus_idx < len(self.genus_list) else f"Genus_{genus_idx}"
+         species_name = self.species_list[species_idx] if species_idx < len(self.species_list) else f"Species_{species_idx}"
+
+         detection_data = {
+             "family": family_name,
+             "genus": genus_name,
+             "species": species_name,
+             "family_confidence": float(family_probs[family_idx]),
+             "genus_confidence": float(genus_probs[genus_idx]),
+             "species_confidence": float(species_probs[species_idx]),
+             "family_probs": family_probs.tolist(),
+             "genus_probs": genus_probs.tolist(),
+             "species_probs": species_probs.tolist()
+         }
+
+         return detection_data
+
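+     # The returned dict carries both the argmax names and confidences (e.g.
+     # {'species': 'Apis mellifera', 'species_confidence': 0.87, ...}) and the
+     # full per-level probability vectors, which hierarchical_aggregation()
+     # averages per track later on.
+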
+     def _extract_yolo_detections(self, frame):
+         """Extract and validate YOLO detections from frame"""
+         with torch.no_grad():
+             results = self.yolo_model(frame, conf=self.confidence_threshold, iou=0.5, verbose=False)
+
+         detections = results[0].boxes
+         valid_detections = []
+         valid_detection_data = []
+
+         if detections is not None and len(detections) > 0:
+             height, width = frame.shape[:2]
+
+             for box in detections:
+                 xyxy = box.xyxy.cpu().numpy().flatten()
+                 confidence = box.conf.cpu().numpy().item()
+
+                 if confidence < self.confidence_threshold:
+                     continue
+
+                 x1, y1, x2, y2 = xyxy[:4]
+
+                 # Clamp coordinates
+                 x1, y1, x2, y2 = max(0, x1), max(0, y1), min(width, x2), min(height, y2)
+
+                 if x2 <= x1 or y2 <= y1:
+                     continue
+
+                 # Store detection for tracking (x1, y1, x2, y2 format)
+                 valid_detections.append([x1, y1, x2, y2])
+                 valid_detection_data.append({
+                     'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2,
+                     'confidence': confidence
+                 })
+
+         return valid_detections, valid_detection_data
+
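+     # Returns parallel lists: valid_detections holds plain [x1, y1, x2, y2]
+     # boxes for the tracker, while valid_detection_data pairs each box with its
+     # YOLO confidence for downstream classification.
+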
+     def _classify_detection(self, frame, x1, y1, x2, y2):
+         """Perform hierarchical classification on a detection crop"""
+         crop = frame[int(y1):int(y2), int(x1):int(x2)]
+         crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
+         pil_img = Image.fromarray(crop_rgb)
+         input_tensor = self.classification_transform(pil_img).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             classification_outputs = self.classification_model(input_tensor)
+
+         return self.process_classification_results(classification_outputs)
+
+     def _process_single_detection(self, frame, det_data, track_id, frame_time_seconds):
+         """Process a single detection: classify, store, and visualize"""
+         x1, y1, x2, y2 = det_data['x1'], det_data['y1'], det_data['x2'], det_data['y2']
+         height, width = frame.shape[:2]
+
+         # Perform classification
+         detection_data = self._classify_detection(frame, x1, y1, x2, y2)
+
+         # Store detection for aggregation
+         bbox = self.convert_bbox_to_normalized(x1, y1, x2, y2, width, height)
+         timestamp = datetime.now().isoformat()
+         self.store_detection(detection_data, timestamp, frame_time_seconds, track_id, bbox)
+
+         # Visualization
+         FrameVisualizer.draw_detection_on_frame(frame, x1, y1, x2, y2, track_id, detection_data)
+
+     def process_frame(self, frame, frame_time_seconds, tracker, global_frame_count):
+         """Process a single frame from the video."""
+         self.frame_count += 1
+
+         # Extract YOLO detections
+         valid_detections, valid_detection_data = self._extract_yolo_detections(frame)
+
+         # Update tracker with detections
+         track_ids = tracker.update(valid_detections, global_frame_count)
+
+         # Process each detection with its track ID
+         for i, det_data in enumerate(valid_detection_data):
+             track_id = track_ids[i] if i < len(track_ids) else None
+             self._process_single_detection(frame, det_data, track_id, frame_time_seconds)
+
+         return frame
+
+     # ------------------------------------------------------------------------
+     # AGGREGATION AND RESULTS METHODS
+     # ------------------------------------------------------------------------
+
+     def hierarchical_aggregation(self):
+         """
+         Perform hierarchical aggregation of results per track ID.
+
+         CONFIDENCE CALCULATION EXPLANATION:
+         1. For each track, average the softmax probabilities across all detections
+         2. Use hierarchical selection:
+            - Find best family (highest averaged family probability)
+            - Find best genus within that family (highest averaged genus probability among genera in that family)
+            - Find best species within that genus (highest averaged species probability among species in that genus)
+         3. Final confidence = averaged probability of the selected class at each taxonomic level
+         """
+         print("\n=== Performing Hierarchical Aggregation ===")
+
+         # Group detections by track ID
+         track_detections = defaultdict(list)
+         for detection in self.all_detections:
+             if detection['track_id'] is not None:
+                 track_detections[detection['track_id']].append(detection)
+
+         aggregated_results = []
+
+         for track_id, detections in track_detections.items():
+             print(f"\nProcessing Track ID: {track_id} ({len(detections)} detections)")
+
+             # Aggregate probabilities across all detections for this track
+             prob_sums = [
+                 np.zeros(len(self.family_list)),
+                 np.zeros(len(self.genus_list)),
+                 np.zeros(len(self.species_list))
+             ]
+
+             for detection in detections:
+                 prob_sums[0] += np.array(detection['family_probs'])
+                 prob_sums[1] += np.array(detection['genus_probs'])
+                 prob_sums[2] += np.array(detection['species_probs'])
+
+             # Average the probabilities
+             prob_avgs = [prob_sum / len(detections) for prob_sum in prob_sums]
+
+             # Hierarchical selection: start with family
+             best_family_idx = np.argmax(prob_avgs[0])
+             best_family = self.family_list[best_family_idx]
+             best_family_prob = prob_avgs[0][best_family_idx]
+             print(f"  Best family: {best_family} (prob: {best_family_prob:.3f})")
+
+             # Find genera belonging to this family
+             family_genera_indices = [i for i, genus in enumerate(self.genus_list)
+                                      if genus in self.genus_to_family and self.genus_to_family[genus] == best_family]
+
+             if family_genera_indices:
+                 family_genus_probs = prob_avgs[1][family_genera_indices]
+                 best_genus_idx = family_genera_indices[np.argmax(family_genus_probs)]
+             else:
+                 best_genus_idx = np.argmax(prob_avgs[1])
+
+             best_genus = self.genus_list[best_genus_idx]
+             best_genus_prob = prob_avgs[1][best_genus_idx]
+             print(f"  Best genus: {best_genus} (prob: {best_genus_prob:.3f})")
+
+             # Find species belonging to this genus
+             genus_species_indices = [i for i, species in enumerate(self.species_list)
+                                      if species in self.species_to_genus and self.species_to_genus[species] == best_genus]
+
+             if genus_species_indices:
+                 genus_species_probs = prob_avgs[2][genus_species_indices]
+                 best_species_idx = genus_species_indices[np.argmax(genus_species_probs)]
+             else:
+                 best_species_idx = np.argmax(prob_avgs[2])
+
+             best_species = self.species_list[best_species_idx]
+             best_species_prob = prob_avgs[2][best_species_idx]
+             print(f"  Best species: {best_species} (prob: {best_species_prob:.3f})")
+
+             # Calculate track statistics
+             frame_times = [d['frame_time_seconds'] for d in detections]
+             first_frame, last_frame = min(frame_times), max(frame_times)
+
+             aggregated_results.append({
+                 'track_id': track_id,
+                 'num_detections': len(detections),
+                 'first_frame_time': first_frame,
+                 'last_frame_time': last_frame,
+                 'duration': last_frame - first_frame,
+                 'final_family': best_family,
+                 'final_genus': best_genus,
+                 'final_species': best_species,
+                 'family_confidence': best_family_prob,
+                 'genus_confidence': best_genus_prob,
+                 'species_confidence': best_species_prob
+             })
+
+         return aggregated_results
+
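+     # Toy example of the constrained selection: if a track's averaged family
+     # probabilities pick Syrphidae, the genus step only compares genera mapped to
+     # Syrphidae (e.g. Eupeodes vs. Episyrphus), even if a genus from another
+     # family happens to have a higher raw averaged score.
+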
+     def print_simplified_summary(self, aggregated_results):
+         """Print a simplified, readable summary of tracking results"""
+         print("\n" + "="*60)
+         print("🐛 INSECT TRACKING SUMMARY")
+         print("="*60)
+
+         if not aggregated_results:
+             print("No insects were tracked in this video.")
+             return
+
+         # Sort by number of detections (most active insects first)
+         sorted_results = sorted(aggregated_results, key=lambda x: x['num_detections'], reverse=True)
+
+         for i, result in enumerate(sorted_results, 1):
+             duration = result['duration']
+             detection_count = result['num_detections']
+
+             print(f"\n🐞 Insect {i}:")
+             print(f"   Detections: {detection_count}")
+             print(f"   Duration: {duration:.1f}s")
+             print(f"   Family: {result['final_family']}")
+             print(f"   Genus: {result['final_genus']}")
+             print(f"   Species: {result['final_species']}")
+             print(f"   Confidence: {result['species_confidence']:.1%}")
+
+         print(f"\n📈 Total: {len(aggregated_results)} unique insects tracked")
+         print("="*60)
+
+     def save_results_table(self, aggregated_results, output_path):
+         """Save aggregated results to CSV"""
+         df = pd.DataFrame(aggregated_results)
+         if not df.empty:
+             df = df.sort_values('track_id')
+         # Derive the CSV name from the output video path (assumes a .mp4 path)
+         csv_path = str(output_path).replace('.mp4', '_results.csv')
+         df.to_csv(csv_path, index=False)
+
+         print(f"\n📊 Results saved to: {csv_path}")
+
+         # Print simplified summary
+         self.print_simplified_summary(aggregated_results)
+
+         return df
+
+ # ============================================================================
+ # VIDEO PROCESSING FUNCTIONS
+ # ============================================================================
+
+ def process_video(video_path, processor, output_video_path=None, show_video=False, tracker_max_frames=30, fps=None):
+     """Process an MP4 video file frame by frame.
+
+     Args:
+         fps (float, optional): FPS for processing and output. If provided, frames will be skipped
+             to match this rate and output video will use this FPS. If None, all frames are processed.
+     """
+
+     if not os.path.exists(video_path):
+         raise FileNotFoundError(f"Video file not found: {video_path}")
+
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise ValueError(f"Could not open video file: {video_path}")
+
+     # Get video properties
+     input_fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     duration = total_frames / input_fps if input_fps > 0 else 0
+
+     print(f"Processing video: {video_path}")
+     print(f"Properties: {total_frames} frames, {input_fps:.2f} FPS, {duration:.2f}s duration")
+
+     # Initialize tracker and output writer
+     tracker = InsectTracker(height, width, max_frames=tracker_max_frames, debug=False)
+     out = None
+
+     if output_video_path:
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+         write_fps = fps if fps is not None else input_fps
+         out = cv2.VideoWriter(output_video_path, fourcc, write_fps, (width, height))
+         print(f"Output video: {output_video_path} at {write_fps:.2f} FPS")
+
+     frame_number = 0
+     processed_frame_count = 0
+     start_time = time.time()
+
+     # Calculate frame skip interval if fps is specified
+     frame_skip_interval = None
+     if fps is not None and fps > 0 and input_fps > 0:
+         frame_skip_interval = max(1, int(input_fps / fps))
+         print(f"Processing every {frame_skip_interval} frame(s) to achieve ~{fps:.2f} FPS")
+         print(f"Tracker will only receive the processed frames, not all {total_frames} frames")
+
+     try:
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+
+             frame_time_seconds = frame_number / input_fps if input_fps > 0 else 0
+
+             # Check if we should process this frame
+             should_process = True
+             if frame_skip_interval is not None:
+                 should_process = (frame_number % frame_skip_interval == 0)
+
+             if should_process:
+                 # Process the frame
+                 processed_frame = processor.process_frame(frame, frame_time_seconds, tracker, frame_number)
+                 processed_frame_count += 1
+                 last_processed_frame = processed_frame.copy()  # Keep for frame duplication
+
+                 # Show video if requested
+                 if show_video:
+                     cv2.imshow('Video Inference', processed_frame)
+                     if cv2.waitKey(1) & 0xFF == ord('q'):
+                         print("User requested quit")
+                         break
+
+                 # If we're skipping frames, skip ahead to save time
+                 if frame_skip_interval is not None and frame_skip_interval > 1:
+                     # Skip the next (frame_skip_interval - 1) frames
+                     for _ in range(frame_skip_interval - 1):
+                         ret, _ = cap.read()
+                         if not ret:
+                             break
+                         frame_number += 1
+             else:
+                 # Use the last processed frame to maintain video length
+                 if 'last_processed_frame' in locals():
+                     processed_frame = last_processed_frame.copy()
+                 else:
+                     processed_frame = frame  # Fallback for first frames
+
+             # Write to output video if requested (always write to maintain duration)
+             if out:
+                 out.write(processed_frame)
+
+             frame_number += 1
+
+             # Progress update based on processed frames only
+             if should_process and processed_frame_count % 25 == 0:
+                 if frame_skip_interval is not None:
+                     estimated_total_processed = total_frames // frame_skip_interval
+                     progress = (processed_frame_count / estimated_total_processed) * 100 if estimated_total_processed > 0 else 0
+                     print(f"Processed: {processed_frame_count}/{estimated_total_processed} frames ({progress:.1f}%)")
+                 else:
+                     progress = (frame_number / total_frames) * 100 if total_frames > 0 else 0
+                     print(f"Processed: {processed_frame_count}/{total_frames} frames ({progress:.1f}%)")
+
+     finally:
+         cap.release()
+         if out:
+             out.release()
+         if show_video:
+             cv2.destroyAllWindows()
+
+     processing_time = time.time() - start_time
+     print(f"\nProcessing complete! {processed_frame_count}/{frame_number} frames processed in {processing_time:.2f}s")
+     if processed_frame_count > 0:
+         print(f"Processing speed: {processed_frame_count/processing_time:.2f} FPS, Detections: {len(processor.all_detections)}")
+     else:
+         print(f"No frames processed, Detections: {len(processor.all_detections)}")
+
+     # Perform hierarchical aggregation and save results
+     aggregated_results = processor.hierarchical_aggregation()
+     if output_video_path:
+         processor.save_results_table(aggregated_results, output_video_path)
+
+     return aggregated_results
+
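+ # Frame-skip arithmetic: with a 30 FPS input and fps=5.0, frame_skip_interval =
+ # max(1, int(30 / 5.0)) = 6, so frames 0, 6, 12, ... are processed and the
+ # writer runs at 5 FPS, keeping the output duration close to the input's.
+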
+ # ============================================================================
+ # MAIN ENTRY POINT FUNCTIONS
+ # ============================================================================
+
+ def inference(species_list, yolo_model_path, hierarchical_model_path, confidence_threshold,
+               video_path, output_path, tracker_max_frames, fps=None):
+     """
+     Run inference on a single video file.
+
+     Args:
+         species_list (list): List of species names for classification
+         yolo_model_path (str): Path to YOLO model weights
+         hierarchical_model_path (str): Path to hierarchical classification model weights
+         confidence_threshold (float): Confidence threshold for detections
+         video_path (str): Path to input video file
+         output_path (str): Path for output video file (including filename)
+         tracker_max_frames (int): Maximum frames for tracker context
+         fps (float, optional): Processing and output FPS. If provided, frames will be skipped
+             to match this rate and output video will use this FPS. If None, all frames are processed.
+
+     Returns:
+         dict: Summary of processing results
+     """
+     # Check if input video exists
+     if not os.path.exists(video_path):
+         error_msg = f"Video file not found: {video_path}"
+         print(f"Error: {error_msg}")
+         return {"error": error_msg}
+
+     # Create output directory if it doesn't exist
+     output_dir = os.path.dirname(output_path)
+     if output_dir:
+         os.makedirs(output_dir, exist_ok=True)
+
+     print(f"Processing single video: {video_path}")
+
+     # Create processor instance
+     print("Initializing models...")
+     processor = VideoInferenceProcessor(
+         species_list=species_list,
+         yolo_model_path=yolo_model_path,
+         hierarchical_model_path=hierarchical_model_path,
+         confidence_threshold=confidence_threshold
+     )
+
+     # Track processing results
+     processing_results = {
+         "video_file": os.path.basename(video_path),
+         "output_path": output_path,
+         "success": False,
+         "detections": 0,
+         "tracks": 0,
+         "error": None
+     }
+
+     print(f"\n{'='*20}\nProcessing: {video_path}\n{'='*20}")
+
+     try:
+         # Process the video
+         aggregated_results = process_video(
+             video_path=video_path,
+             processor=processor,
+             output_video_path=output_path,
+             show_video=False,
+             tracker_max_frames=tracker_max_frames,
+             fps=fps
+         )
+
+         processing_results.update({
+             "success": True,
+             "detections": len(processor.all_detections),
+             "tracks": len(aggregated_results)
+         })
+
+         print(f"Finished processing. Output saved to {output_path}")
+
+     except Exception as e:
+         error_msg = f"Failed to process {os.path.basename(video_path)}: {str(e)}"
+         print(f"Error: {error_msg}")
+         processing_results["error"] = error_msg
+
+     if processing_results["success"]:
+         print("\nProcessing complete!")
+         print(f"Total detections: {processing_results['detections']}")
+         print(f"Total tracks: {processing_results['tracks']}")
+
+     return processing_results
+
+
+ def main():
+     """Example usage for processing a single video."""
+     # Define your species list (replace with your actual species)
+     species_list = [
+         "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
+         "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
+         "Eristalis tenax"
+     ]
+
+     # Paths (replace with your actual paths)
+     video_path = "input_videos/sample_video.mp4"
+     output_path = "output_videos/sample_video_predictions.mp4"
+     yolo_model_path = "weights/yolo_model.pt"
+     hierarchical_model_path = "weights/hierarchical_model.pth"
+
+     # Run inference
+     results = inference(
+         species_list=species_list,
+         yolo_model_path=yolo_model_path,
+         hierarchical_model_path=hierarchical_model_path,
+         confidence_threshold=0.35,
+         video_path=video_path,
+         output_path=output_path,
+         tracker_max_frames=60,
+         fps=None  # Set to e.g. 5.0 to process only 5 frames per second
+     )
+
+     return results
+
+
+ if __name__ == "__main__":
+     main()