bplusplus 1.2.2-py3-none-any.whl → 1.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bplusplus might be problematic.
- bplusplus/__init__.py +3 -5
- bplusplus/inference.py +891 -0
- bplusplus/prepare.py +419 -652
- bplusplus/{hierarchical/test.py → test.py} +22 -9
- bplusplus/tracker.py +261 -0
- bplusplus/{hierarchical/train.py → train.py} +1 -1
- bplusplus-1.2.3.dist-info/METADATA +101 -0
- bplusplus-1.2.3.dist-info/RECORD +11 -0
- {bplusplus-1.2.2.dist-info → bplusplus-1.2.3.dist-info}/WHEEL +1 -1
- bplusplus/resnet/test.py +0 -473
- bplusplus/resnet/train.py +0 -329
- bplusplus/train_validate.py +0 -11
- bplusplus-1.2.2.dist-info/METADATA +0 -260
- bplusplus-1.2.2.dist-info/RECORD +0 -12
- {bplusplus-1.2.2.dist-info → bplusplus-1.2.3.dist-info}/LICENSE +0 -0
bplusplus/inference.py
ADDED
@@ -0,0 +1,891 @@
import cv2
import time
import os
import sys
import numpy as np
import json
import pandas as pd
from datetime import datetime
from pathlib import Path
from .tracker import InsectTracker
import torch
from ultralytics import YOLO
from torchvision import transforms
from PIL import Image
import torch.nn as nn
from torchvision.models import resnet50
import requests
import logging
from collections import defaultdict
import uuid

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def get_taxonomy(species_list):
    """
    Retrieves taxonomic information for a list of species from GBIF API.
    Creates a hierarchical taxonomy dictionary with family, genus, and species relationships.
    """
    taxonomy = {1: [], 2: {}, 3: {}}
    species_to_genus = {}
    genus_to_family = {}

    logger.info(f"Building taxonomy from GBIF for {len(species_list)} species")

    print(f"\n{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
    print("-" * 80)

    for species_name in species_list:
        url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
        try:
            response = requests.get(url)
            data = response.json()

            if data.get('status') in ['ACCEPTED', 'SYNONYM']:
                family = data.get('family')
                genus = data.get('genus')

                if family and genus:
                    print(f"{species_name:<30} {family:<20} {genus:<20} OK")

                    species_to_genus[species_name] = genus
                    genus_to_family[genus] = family

                    if family not in taxonomy[1]:
                        taxonomy[1].append(family)

                    taxonomy[2][genus] = family
                    taxonomy[3][species_name] = genus
                else:
                    print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                    logger.error(f"Species '{species_name}' found but missing family/genus data")
            else:
                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                logger.error(f"Species '{species_name}' not found in GBIF")

        except Exception as e:
            print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
            logger.error(f"Error retrieving data for '{species_name}': {str(e)}")

    taxonomy[1] = sorted(list(set(taxonomy[1])))
    print("-" * 80)

    # Print indices
    for level, name, items in [(1, "Family", taxonomy[1]), (2, "Genus", taxonomy[2].keys()), (3, "Species", species_list)]:
        print(f"\n{name} indices:")
        for i, item in enumerate(items):
            print(f"  {i}: {item}")

    logger.info(f"Taxonomy built: {len(taxonomy[1])} families, {len(taxonomy[2])} genera, {len(taxonomy[3])} species")
    return taxonomy, species_to_genus, genus_to_family

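# For illustration (values assumed, not from the package): a two-species list such as
# ["Apis mellifera", "Bombus terrestris"] would, if both GBIF lookups succeed, yield
# structures shaped like:
#   taxonomy         = {1: ['Apidae'],
#                       2: {'Apis': 'Apidae', 'Bombus': 'Apidae'},
#                       3: {'Apis mellifera': 'Apis', 'Bombus terrestris': 'Bombus'}}
#   species_to_genus = {'Apis mellifera': 'Apis', 'Bombus terrestris': 'Bombus'}
#   genus_to_family  = {'Apis': 'Apidae', 'Bombus': 'Apidae'}
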
def create_mappings(taxonomy):
    """Create index mappings from taxonomy"""
    level_to_idx = {}
    idx_to_level = {}

    for level, labels in taxonomy.items():
        if isinstance(labels, list):
            level_to_idx[level] = {label: idx for idx, label in enumerate(labels)}
            idx_to_level[level] = {idx: label for idx, label in enumerate(labels)}
        else:  # Dictionary
            level_to_idx[level] = {label: idx for idx, label in enumerate(labels.keys())}
            idx_to_level[level] = {idx: label for idx, label in enumerate(labels.keys())}

    return level_to_idx, idx_to_level

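# Continuing the illustrative taxonomy above, the resulting mappings would look like:
#   level_to_idx[3] = {'Apis mellifera': 0, 'Bombus terrestris': 1}
#   idx_to_level[3] = {0: 'Apis mellifera', 1: 'Bombus terrestris'}
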
# ============================================================================
# MODEL CLASSES
# ============================================================================

class HierarchicalInsectClassifier(nn.Module):
    def __init__(self, num_classes_per_level):
        """
        Args:
            num_classes_per_level (list): Number of classes for each taxonomic level [family, genus, species]
        """
        super(HierarchicalInsectClassifier, self).__init__()

        self.backbone = resnet50(pretrained=True)
        backbone_output_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # Remove the final fully connected layer

        self.branches = nn.ModuleList()
        for num_classes in num_classes_per_level:
            branch = nn.Sequential(
                nn.Linear(backbone_output_features, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, num_classes)
            )
            self.branches.append(branch)

        self.num_levels = len(num_classes_per_level)

    def forward(self, x):
        R0 = self.backbone(x)

        outputs = []
        for branch in self.branches:
            outputs.append(branch(R0))

        return outputs

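# Note: forward() returns a list of raw logit tensors, one per taxonomic level. For a
# batch of size B and, say, 5 families, 8 genera and 9 species, the shapes would be
# [B, 5], [B, 8] and [B, 9]; softmax is applied later in
# VideoInferenceProcessor.process_classification_results.
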
# ============================================================================
# VISUALIZATION UTILITIES
# ============================================================================

class FrameVisualizer:
    """Modern, slick visualization system for insect detection and tracking"""

    # Modern color palette - vibrant but professional
    COLORS = [
        (68, 189, 50),    # Vibrant Green
        (255, 59, 48),    # Red
        (0, 122, 255),    # Blue
        (255, 149, 0),    # Orange
        (175, 82, 222),   # Purple
        (255, 204, 0),    # Yellow
        (50, 173, 230),   # Light Blue
        (255, 45, 85),    # Pink
        (48, 209, 88),    # Light Green
        (90, 200, 250),   # Sky Blue
        (255, 159, 10),   # Amber
        (191, 90, 242),   # Lavender
    ]

    @staticmethod
    def get_track_color(track_id):
        """Get a consistent, vibrant color for a track ID"""
        if track_id is None:
            return (68, 189, 50)  # Default green for untracked

        # Generate consistent index from track_id
        try:
            track_uuid = uuid.UUID(track_id)
        except (ValueError, TypeError):
            track_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(track_id))

        color_index = track_uuid.int % len(FrameVisualizer.COLORS)
        return FrameVisualizer.COLORS[color_index]

    @staticmethod
    def draw_rounded_rectangle(frame, pt1, pt2, color, thickness, radius=8):
        """Draw a rounded rectangle"""
        x1, y1 = pt1
        x2, y2 = pt2

        # Draw main rectangle
        cv2.rectangle(frame, (x1 + radius, y1), (x2 - radius, y2), color, thickness)
        cv2.rectangle(frame, (x1, y1 + radius), (x2, y2 - radius), color, thickness)

        # Draw corners
        cv2.circle(frame, (x1 + radius, y1 + radius), radius, color, thickness)
        cv2.circle(frame, (x2 - radius, y1 + radius), radius, color, thickness)
        cv2.circle(frame, (x1 + radius, y2 - radius), radius, color, thickness)
        cv2.circle(frame, (x2 - radius, y2 - radius), radius, color, thickness)

    @staticmethod
    def draw_gradient_background(frame, x1, y1, x2, y2, color, alpha=0.85):
        """Draw a modern gradient background with rounded corners"""
        overlay = frame.copy()

        # Create gradient effect
        height = y2 - y1
        for i in range(height):
            intensity = 1.0 - (i / height) * 0.3  # Gradient from top to bottom
            gradient_color = tuple(int(c * intensity) for c in color)
            cv2.rectangle(overlay, (x1, y1 + i), (x2, y1 + i + 1), gradient_color, -1)

        # Blend with original frame
        cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)

    @staticmethod
    def draw_detection_on_frame(frame, x1, y1, x2, y2, track_id, detection_data):
        """Draw modern, sleek detection visualization"""

        # Get colors
        primary_color = FrameVisualizer.get_track_color(track_id)

        # Simple, clean bounding box
        box_thickness = 2

        # Draw single clean rectangle
        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), primary_color, box_thickness)

        # Prepare label content without emojis
        if track_id is not None:
            track_short = str(track_id)[:8]
            track_display = f"ID: {track_short}"
        else:
            track_display = "NEW"

        # Classification results without icons
        classification_lines = []

        for level, key in [("family", "family_confidence"),
                           ("genus", "genus_confidence"),
                           ("species", "species_confidence")]:
            if detection_data.get(level):
                conf = detection_data.get(key, 0)
                name = detection_data[level]
                # Truncate long names
                if len(name) > 18:
                    name = name[:15] + "..."
                level_short = level[0].upper()  # F, G, S
                classification_lines.append(f"{level_short}: {name}")
                classification_lines.append(f"   {conf:.1%}")

        if not classification_lines and track_id is None:
            return

        # Calculate label box dimensions with smaller, lighter font
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.45
        thickness = 1
        padding = 8
        line_spacing = 6

        # Calculate text dimensions
        all_lines = [track_display] + classification_lines
        text_sizes = [cv2.getTextSize(line, font, font_scale, thickness)[0] for line in all_lines]
        max_w = max(size[0] for size in text_sizes) if text_sizes else 100
        text_h = text_sizes[0][1] if text_sizes else 20

        total_h = len(all_lines) * (text_h + line_spacing) + padding * 2
        label_w = max_w + padding * 2

        # Position label box (above bbox, or below if no space)
        label_x1 = max(0, int(x1))
        label_y1 = max(0, int(y1) - total_h - 5)
        if label_y1 < 0:
            label_y1 = int(y2) + 5

        label_x2 = min(frame.shape[1], label_x1 + label_w)
        label_y2 = min(frame.shape[0], label_y1 + total_h)

        # Draw modern gradient background with rounded corners
        FrameVisualizer.draw_gradient_background(frame, label_x1, label_y1, label_x2, label_y2,
                                                 (20, 20, 20), alpha=0.88)

        # Add subtle border
        FrameVisualizer.draw_rounded_rectangle(frame,
                                               (label_x1, label_y1),
                                               (label_x2, label_y2),
                                               primary_color, 1, radius=6)

        # Draw text with modern styling
        current_y = label_y1 + padding + text_h

        for i, line in enumerate(all_lines):
            if i == 0:  # Track ID line - use primary color
                text_color = primary_color
                line_thickness = 1
            elif "%" in line:  # Confidence lines - use lighter color
                text_color = (160, 160, 160)
                line_thickness = 1
            else:  # Classification name lines - use white
                text_color = (255, 255, 255)
                line_thickness = 1

            # Add subtle text shadow for readability
            cv2.putText(frame, line, (label_x1 + padding + 1, current_y + 1),
                        font, font_scale, (0, 0, 0), 1, cv2.LINE_AA)

            # Main text
            cv2.putText(frame, line, (label_x1 + padding, current_y),
                        font, font_scale, text_color, line_thickness, cv2.LINE_AA)

            current_y += text_h + line_spacing

# ============================================================================
# MAIN PROCESSING CLASS
# ============================================================================

class VideoInferenceProcessor:
    def __init__(self, species_list, yolo_model_path, hierarchical_model_path,
                 confidence_threshold=0.35, device_id="video_processor"):

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.species_list = species_list
        self.confidence_threshold = confidence_threshold
        self.device_id = device_id

        print(f"Using device: {self.device}")

        # Build taxonomy from species list
        self.taxonomy, self.species_to_genus, self.genus_to_family = get_taxonomy(species_list)
        self.level_to_idx, self.idx_to_level = create_mappings(self.taxonomy)
        self.family_list = self.taxonomy[1]
        self.genus_list = list(self.taxonomy[2].keys())

        # Load models
        print(f"Loading YOLO model from {yolo_model_path}")
        self.yolo_model = YOLO(yolo_model_path)

        print(f"Loading hierarchical model from {hierarchical_model_path}")
        checkpoint = torch.load(hierarchical_model_path, map_location='cpu')
        state_dict = checkpoint.get("model_state_dict", checkpoint)

        num_classes_per_level = [len(self.family_list), len(self.genus_list), len(self.species_list)]
        print(f"Model architecture: {num_classes_per_level} classes per level")

        self.classification_model = HierarchicalInsectClassifier(num_classes_per_level)
        self.classification_model.load_state_dict(state_dict, strict=False)
        self.classification_model.to(self.device)
        self.classification_model.eval()

        # Classification preprocessing
        self.classification_transform = transforms.Compose([
            transforms.Resize((768, 768)),
            transforms.CenterCrop(640),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.all_detections = []
        self.frame_count = 0
        print("Models loaded successfully!")

    # ------------------------------------------------------------------------
    # UTILITY METHODS
    # ------------------------------------------------------------------------

    def convert_bbox_to_normalized(self, x, y, x2, y2, width, height):
        x_center = (x + x2) / 2.0 / width
        y_center = (y + y2) / 2.0 / height
        norm_width = (x2 - x) / width
        norm_height = (y2 - y) / height
        return [x_center, y_center, norm_width, norm_height]

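    # Example: a box (x=100, y=200, x2=300, y2=400) in a 1920x1080 frame becomes
    # approximately [0.104, 0.278, 0.104, 0.185] (center-x, center-y, width, height,
    # each as a fraction of the frame size), i.e. YOLO-style normalized coordinates.
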
    def store_detection(self, detection_data, timestamp, frame_time_seconds, track_id, bbox):
        """Store detection for final aggregation"""
        payload = {
            "timestamp": timestamp,
            "frame_time_seconds": frame_time_seconds,
            "track_id": track_id,
            "bbox": bbox,
            **detection_data
        }

        self.all_detections.append(payload)

        # Print frame-by-frame prediction (only for processed frames)
        track_display = str(track_id)[:8] if track_id else "NEW"
        species = payload.get('species', 'Unknown')
        species_conf = payload.get('species_confidence', 0)
        print(f"Processed {frame_time_seconds:6.2f}s | Track {track_display} | {species} ({species_conf:.1%})")

    # ------------------------------------------------------------------------
    # DETECTION AND CLASSIFICATION METHODS
    # ------------------------------------------------------------------------

    def process_classification_results(self, classification_outputs):
        """Process raw classification outputs to get predictions and probabilities"""
        family_output = classification_outputs[0].cpu().numpy().flatten()
        genus_output = classification_outputs[1].cpu().numpy().flatten()
        species_output = classification_outputs[2].cpu().numpy().flatten()

        # Apply softmax to get probabilities
        family_probs = torch.softmax(classification_outputs[0], dim=1).cpu().numpy().flatten()
        genus_probs = torch.softmax(classification_outputs[1], dim=1).cpu().numpy().flatten()
        species_probs = torch.softmax(classification_outputs[2], dim=1).cpu().numpy().flatten()

        # Get top predictions
        family_idx = np.argmax(family_probs)
        genus_idx = np.argmax(genus_probs)
        species_idx = np.argmax(species_probs)

        # Get names
        family_name = self.family_list[family_idx] if family_idx < len(self.family_list) else f"Family_{family_idx}"
        genus_name = self.genus_list[genus_idx] if genus_idx < len(self.genus_list) else f"Genus_{genus_idx}"
        species_name = self.species_list[species_idx] if species_idx < len(self.species_list) else f"Species_{species_idx}"

        detection_data = {
            "family": family_name,
            "genus": genus_name,
            "species": species_name,
            "family_confidence": float(family_probs[family_idx]),
            "genus_confidence": float(genus_probs[genus_idx]),
            "species_confidence": float(species_probs[species_idx]),
            "family_probs": family_probs.tolist(),
            "genus_probs": genus_probs.tolist(),
            "species_probs": species_probs.tolist()
        }

        return detection_data

    def _extract_yolo_detections(self, frame):
        """Extract and validate YOLO detections from frame"""
        with torch.no_grad():
            results = self.yolo_model(frame, conf=self.confidence_threshold, iou=0.5, verbose=False)

        detections = results[0].boxes
        valid_detections = []
        valid_detection_data = []

        if detections is not None and len(detections) > 0:
            height, width = frame.shape[:2]

            for box in detections:
                xyxy = box.xyxy.cpu().numpy().flatten()
                confidence = box.conf.cpu().numpy().item()

                if confidence < self.confidence_threshold:
                    continue

                x1, y1, x2, y2 = xyxy[:4]

                # Clamp coordinates
                x1, y1, x2, y2 = max(0, x1), max(0, y1), min(width, x2), min(height, y2)

                if x2 <= x1 or y2 <= y1:
                    continue

                # Store detection for tracking (x1, y1, x2, y2 format)
                valid_detections.append([x1, y1, x2, y2])
                valid_detection_data.append({
                    'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2,
                    'confidence': confidence
                })

        return valid_detections, valid_detection_data

    def _classify_detection(self, frame, x1, y1, x2, y2):
        """Perform hierarchical classification on a detection crop"""
        crop = frame[int(y1):int(y2), int(x1):int(x2)]
        crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(crop_rgb)
        input_tensor = self.classification_transform(pil_img).unsqueeze(0).to(self.device)

        with torch.no_grad():
            classification_outputs = self.classification_model(input_tensor)

        return self.process_classification_results(classification_outputs)

    def _process_single_detection(self, frame, det_data, track_id, frame_time_seconds):
        """Process a single detection: classify, store, and visualize"""
        x1, y1, x2, y2 = det_data['x1'], det_data['y1'], det_data['x2'], det_data['y2']
        height, width = frame.shape[:2]

        # Perform classification
        detection_data = self._classify_detection(frame, x1, y1, x2, y2)

        # Store detection for aggregation
        bbox = self.convert_bbox_to_normalized(x1, y1, x2, y2, width, height)
        timestamp = datetime.now().isoformat()
        self.store_detection(detection_data, timestamp, frame_time_seconds, track_id, bbox)

        # Visualization
        FrameVisualizer.draw_detection_on_frame(frame, x1, y1, x2, y2, track_id, detection_data)

    def process_frame(self, frame, frame_time_seconds, tracker, global_frame_count):
        """Process a single frame from the video."""
        self.frame_count += 1

        # Extract YOLO detections
        valid_detections, valid_detection_data = self._extract_yolo_detections(frame)

        # Update tracker with detections
        track_ids = tracker.update(valid_detections, global_frame_count)

        # Process each detection with its track ID
        for i, det_data in enumerate(valid_detection_data):
            track_id = track_ids[i] if i < len(track_ids) else None
            self._process_single_detection(frame, det_data, track_id, frame_time_seconds)

        return frame

    # ------------------------------------------------------------------------
    # AGGREGATION AND RESULTS METHODS
    # ------------------------------------------------------------------------

    def hierarchical_aggregation(self):
        """
        Perform hierarchical aggregation of results per track ID.

        CONFIDENCE CALCULATION EXPLANATION:
        1. For each track, average the softmax probabilities across all detections
        2. Use hierarchical selection:
           - Find best family (highest averaged family probability)
           - Find best genus within that family (highest averaged genus probability among genera in that family)
           - Find best species within that genus (highest averaged species probability among species in that genus)
        3. Final confidence = averaged probability of the selected class at each taxonomic level
        """
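        # Worked example (illustrative numbers): with averaged family probabilities
        # {Apidae: 0.7, Syrphidae: 0.3}, Apidae is selected; the genus is then chosen
        # only among genera mapped to Apidae (e.g. Apis 0.4 vs Bombus 0.5 -> Bombus),
        # and the species only among species mapped to Bombus, even if a species
        # outside Bombus happens to have a higher averaged probability.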
print("\n=== Performing Hierarchical Aggregation ===")
|
|
522
|
+
|
|
523
|
+
# Group detections by track ID
|
|
524
|
+
track_detections = defaultdict(list)
|
|
525
|
+
for detection in self.all_detections:
|
|
526
|
+
if detection['track_id'] is not None:
|
|
527
|
+
track_detections[detection['track_id']].append(detection)
|
|
528
|
+
|
|
529
|
+
aggregated_results = []
|
|
530
|
+
|
|
531
|
+
for track_id, detections in track_detections.items():
|
|
532
|
+
print(f"\nProcessing Track ID: {track_id} ({len(detections)} detections)")
|
|
533
|
+
|
|
534
|
+
# Aggregate probabilities across all detections for this track
|
|
535
|
+
prob_sums = [
|
|
536
|
+
np.zeros(len(self.family_list)),
|
|
537
|
+
np.zeros(len(self.genus_list)),
|
|
538
|
+
np.zeros(len(self.species_list))
|
|
539
|
+
]
|
|
540
|
+
|
|
541
|
+
for detection in detections:
|
|
542
|
+
prob_sums[0] += np.array(detection['family_probs'])
|
|
543
|
+
prob_sums[1] += np.array(detection['genus_probs'])
|
|
544
|
+
prob_sums[2] += np.array(detection['species_probs'])
|
|
545
|
+
|
|
546
|
+
# Average the probabilities
|
|
547
|
+
prob_avgs = [prob_sum / len(detections) for prob_sum in prob_sums]
|
|
548
|
+
|
|
549
|
+
# Hierarchical selection: Start with family
|
|
550
|
+
best_family_idx = np.argmax(prob_avgs[0])
|
|
551
|
+
best_family = self.family_list[best_family_idx]
|
|
552
|
+
best_family_prob = prob_avgs[0][best_family_idx]
|
|
553
|
+
print(f" Best family: {best_family} (prob: {best_family_prob:.3f})")
|
|
554
|
+
|
|
555
|
+
# Find genera belonging to this family
|
|
556
|
+
family_genera_indices = [i for i, genus in enumerate(self.genus_list)
|
|
557
|
+
if genus in self.genus_to_family and self.genus_to_family[genus] == best_family]
|
|
558
|
+
|
|
559
|
+
if family_genera_indices:
|
|
560
|
+
family_genus_probs = prob_avgs[1][family_genera_indices]
|
|
561
|
+
best_genus_idx = family_genera_indices[np.argmax(family_genus_probs)]
|
|
562
|
+
else:
|
|
563
|
+
best_genus_idx = np.argmax(prob_avgs[1])
|
|
564
|
+
|
|
565
|
+
best_genus = self.genus_list[best_genus_idx]
|
|
566
|
+
best_genus_prob = prob_avgs[1][best_genus_idx]
|
|
567
|
+
print(f" Best genus: {best_genus} (prob: {best_genus_prob:.3f})")
|
|
568
|
+
|
|
569
|
+
# Find species belonging to this genus
|
|
570
|
+
genus_species_indices = [i for i, species in enumerate(self.species_list)
|
|
571
|
+
if species in self.species_to_genus and self.species_to_genus[species] == best_genus]
|
|
572
|
+
|
|
573
|
+
if genus_species_indices:
|
|
574
|
+
genus_species_probs = prob_avgs[2][genus_species_indices]
|
|
575
|
+
best_species_idx = genus_species_indices[np.argmax(genus_species_probs)]
|
|
576
|
+
else:
|
|
577
|
+
best_species_idx = np.argmax(prob_avgs[2])
|
|
578
|
+
|
|
579
|
+
best_species = self.species_list[best_species_idx]
|
|
580
|
+
best_species_prob = prob_avgs[2][best_species_idx]
|
|
581
|
+
print(f" Best species: {best_species} (prob: {best_species_prob:.3f})")
|
|
582
|
+
|
|
583
|
+
# Calculate track statistics
|
|
584
|
+
frame_times = [d['frame_time_seconds'] for d in detections]
|
|
585
|
+
first_frame, last_frame = min(frame_times), max(frame_times)
|
|
586
|
+
|
|
587
|
+
aggregated_results.append({
|
|
588
|
+
'track_id': track_id,
|
|
589
|
+
'num_detections': len(detections),
|
|
590
|
+
'first_frame_time': first_frame,
|
|
591
|
+
'last_frame_time': last_frame,
|
|
592
|
+
'duration': last_frame - first_frame,
|
|
593
|
+
'final_family': best_family,
|
|
594
|
+
'final_genus': best_genus,
|
|
595
|
+
'final_species': best_species,
|
|
596
|
+
'family_confidence': best_family_prob,
|
|
597
|
+
'genus_confidence': best_genus_prob,
|
|
598
|
+
'species_confidence': best_species_prob
|
|
599
|
+
})
|
|
600
|
+
|
|
601
|
+
return aggregated_results
|
|
602
|
+
|
|
603
|
+
def print_simplified_summary(self, aggregated_results):
|
|
604
|
+
"""Print a simplified, readable summary of tracking results"""
|
|
605
|
+
print("\n" + "="*60)
|
|
606
|
+
print("🐛 INSECT TRACKING SUMMARY")
|
|
607
|
+
print("="*60)
|
|
608
|
+
|
|
609
|
+
if not aggregated_results:
|
|
610
|
+
print("No insects were tracked in this video.")
|
|
611
|
+
return
|
|
612
|
+
|
|
613
|
+
# Sort by number of detections (most active insects first)
|
|
614
|
+
sorted_results = sorted(aggregated_results, key=lambda x: x['num_detections'], reverse=True)
|
|
615
|
+
|
|
616
|
+
for i, result in enumerate(sorted_results, 1):
|
|
617
|
+
duration = result['duration']
|
|
618
|
+
detection_count = result['num_detections']
|
|
619
|
+
|
|
620
|
+
print(f"\n🐞 Insect {i}:")
|
|
621
|
+
print(f" Detections: {detection_count}")
|
|
622
|
+
print(f" Duration: {duration:.1f}s")
|
|
623
|
+
print(f" Family: {result['final_family']}")
|
|
624
|
+
print(f" Genus: {result['final_genus']}")
|
|
625
|
+
print(f" Species: {result['final_species']}")
|
|
626
|
+
print(f" Confidence: {result['species_confidence']:.1%}")
|
|
627
|
+
|
|
628
|
+
print(f"\n📈 Total: {len(aggregated_results)} unique insects tracked")
|
|
629
|
+
print("="*60)
|
|
630
|
+
|
|
631
|
+
def save_results_table(self, aggregated_results, output_path):
|
|
632
|
+
"""Save aggregated results to CSV"""
|
|
633
|
+
df = pd.DataFrame(aggregated_results).sort_values('track_id')
|
|
634
|
+
csv_path = str(output_path).replace('.mp4', '_results.csv')
|
|
635
|
+
df.to_csv(csv_path, index=False)
|
|
636
|
+
|
|
637
|
+
print(f"\n📊 Results saved to: {csv_path}")
|
|
638
|
+
|
|
639
|
+
# Print simplified summary
|
|
640
|
+
self.print_simplified_summary(aggregated_results)
|
|
641
|
+
|
|
642
|
+
return df
|
|
643
|
+
|
|
644
|
+
# ============================================================================
|
|
645
|
+
# VIDEO PROCESSING FUNCTIONS
|
|
646
|
+
# ============================================================================
|
|
647
|
+
|
|
648
|
+
def process_video(video_path, processor, output_video_path=None, show_video=False, tracker_max_frames=30, fps=None):
|
|
649
|
+
"""Process an MP4 video file frame by frame.
|
|
650
|
+
|
|
651
|
+
Args:
|
|
652
|
+
fps (float, optional): FPS for processing and output. If provided, frames will be skipped
|
|
653
|
+
to match this rate and output video will use this FPS. If None, all frames are processed.
|
|
654
|
+
"""
|
|
655
|
+
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")

    # Get video properties
    input_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    duration = total_frames / input_fps if input_fps > 0 else 0

    print(f"Processing video: {video_path}")
    print(f"Properties: {total_frames} frames, {input_fps:.2f} FPS, {duration:.2f}s duration")

    # Initialize tracker and output writer
    tracker = InsectTracker(height, width, max_frames=tracker_max_frames, debug=False)
    out = None

    if output_video_path:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        write_fps = fps if fps is not None else input_fps
        out = cv2.VideoWriter(output_video_path, fourcc, write_fps, (width, height))
        print(f"Output video: {output_video_path} at {write_fps:.2f} FPS")

    frame_number = 0
    processed_frame_count = 0
    start_time = time.time()

    # Calculate frame skip interval if fps is specified
    frame_skip_interval = None
    if fps is not None and fps > 0 and input_fps > 0:
        frame_skip_interval = max(1, int(input_fps / fps))
        print(f"Processing every {frame_skip_interval} frame(s) to achieve ~{fps:.2f} FPS")
        print(f"Tracker will only receive the processed frames, not all {total_frames} frames")

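    # For example, a 30 FPS input with fps=5.0 gives frame_skip_interval = 6: only
    # every 6th frame runs through detection/classification, and skipped frames are
    # filled with the last processed frame so the output video keeps its length.
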
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_time_seconds = frame_number / input_fps if input_fps > 0 else 0

            # Check if we should process this frame
            should_process = True
            if frame_skip_interval is not None:
                should_process = (frame_number % frame_skip_interval == 0)

            if should_process:
                # Process the frame
                processed_frame = processor.process_frame(frame, frame_time_seconds, tracker, frame_number)
                processed_frame_count += 1
                last_processed_frame = processed_frame.copy()  # Keep for frame duplication

                # Show video if requested
                if show_video:
                    cv2.imshow('Video Inference', processed_frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        print("User requested quit")
                        break

                # If we're skipping frames, skip ahead to save time
                if frame_skip_interval is not None and frame_skip_interval > 1:
                    # Skip the next (frame_skip_interval - 1) frames
                    for _ in range(frame_skip_interval - 1):
                        ret, _ = cap.read()
                        if not ret:
                            break
                        frame_number += 1
            else:
                # Use the last processed frame to maintain video length
                if 'last_processed_frame' in locals():
                    processed_frame = last_processed_frame.copy()
                else:
                    processed_frame = frame  # Fallback for first frames

            # Write to output video if requested (always write to maintain duration)
            if out:
                out.write(processed_frame)

            frame_number += 1

            # Progress update based on processed frames only
            if should_process and processed_frame_count % 25 == 0:
                if frame_skip_interval is not None:
                    estimated_total_processed = total_frames // frame_skip_interval
                    progress = (processed_frame_count / estimated_total_processed) * 100 if estimated_total_processed > 0 else 0
                    print(f"Processed: {processed_frame_count}/{estimated_total_processed} frames ({progress:.1f}%)")
                else:
                    progress = (frame_number / total_frames) * 100 if total_frames > 0 else 0
                    print(f"Processed: {processed_frame_count}/{total_frames} frames ({progress:.1f}%)")

    finally:
        cap.release()
        if out:
            out.release()
        if show_video:
            cv2.destroyAllWindows()

    processing_time = time.time() - start_time
    print(f"\nProcessing complete! {processed_frame_count}/{frame_number} frames processed in {processing_time:.2f}s")
    if processed_frame_count > 0:
        print(f"Processing speed: {processed_frame_count/processing_time:.2f} FPS, Detections: {len(processor.all_detections)}")
    else:
        print(f"No frames processed, Detections: {len(processor.all_detections)}")

    # Perform hierarchical aggregation and save results
    aggregated_results = processor.hierarchical_aggregation()
    if output_video_path:
        processor.save_results_table(aggregated_results, output_video_path)

    return aggregated_results

# ============================================================================
# MAIN ENTRY POINT FUNCTIONS
# ============================================================================

def inference(species_list, yolo_model_path, hierarchical_model_path, confidence_threshold,
              video_path, output_path, tracker_max_frames, fps=None):
    """
    Run inference on a single video file.

    Args:
        species_list (list): List of species names for classification
        yolo_model_path (str): Path to YOLO model weights
        hierarchical_model_path (str): Path to hierarchical classification model weights
        confidence_threshold (float): Confidence threshold for detections
        video_path (str): Path to input video file
        output_path (str): Path for output video file (including filename)
        tracker_max_frames (int): Maximum frames for tracker context
        fps (float, optional): Processing and output FPS. If provided, frames will be skipped
            to match this rate and output video will use this FPS. If None, all frames are processed.

    Returns:
        dict: Summary of processing results
    """
    # Check if input video exists
    if not os.path.exists(video_path):
        error_msg = f"Video file not found: {video_path}"
        print(f"Error: {error_msg}")
        return {"error": error_msg}

    # Create output directory if it doesn't exist
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Processing single video: {video_path}")

    # Create processor instance
    print("Initializing models...")
    processor = VideoInferenceProcessor(
        species_list=species_list,
        yolo_model_path=yolo_model_path,
        hierarchical_model_path=hierarchical_model_path,
        confidence_threshold=confidence_threshold
    )

    # Track processing results
    processing_results = {
        "video_file": os.path.basename(video_path),
        "output_path": output_path,
        "success": False,
        "detections": 0,
        "tracks": 0,
        "error": None
    }

    print(f"\n{'='*20}\nProcessing: {video_path}\n{'='*20}")

    try:
        # Process the video
        aggregated_results = process_video(
            video_path=video_path,
            processor=processor,
            output_video_path=output_path,
            show_video=False,
            tracker_max_frames=tracker_max_frames,
            fps=fps
        )

        processing_results.update({
            "success": True,
            "detections": len(processor.all_detections),
            "tracks": len(aggregated_results)
        })

        print(f"Finished processing. Output saved to {output_path}")

    except Exception as e:
        error_msg = f"Failed to process {os.path.basename(video_path)}: {str(e)}"
        print(f"Error: {error_msg}")
        processing_results["error"] = error_msg

    if processing_results["success"]:
        print(f"\nProcessing complete!")
        print(f"Total detections: {processing_results['detections']}")
        print(f"Total tracks: {processing_results['tracks']}")

    return processing_results


def main():
    """Example usage for processing a single video."""
    # Define your species list (replace with your actual species)
    species_list = [
        "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
        "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
        "Eristalis tenax"
    ]

    # Paths (replace with your actual paths)
    video_path = "input_videos/sample_video.mp4"
    output_path = "output_videos/sample_video_predictions.mp4"
    yolo_model_path = "weights/yolo_model.pt"
    hierarchical_model_path = "weights/hierarchical_model.pth"

    # Run inference
    results = inference(
        species_list=species_list,
        yolo_model_path=yolo_model_path,
        hierarchical_model_path=hierarchical_model_path,
        confidence_threshold=0.35,
        video_path=video_path,
        output_path=output_path,
        tracker_max_frames=60,
        fps=None  # Set to e.g. 5.0 to process only 5 frames per second
    )

    return results


if __name__ == "__main__":
    main()