bplusplus-2.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bplusplus/__init__.py +15 -0
- bplusplus/collect.py +523 -0
- bplusplus/detector.py +376 -0
- bplusplus/inference.py +1337 -0
- bplusplus/prepare.py +706 -0
- bplusplus/tracker.py +261 -0
- bplusplus/train.py +913 -0
- bplusplus/validation.py +580 -0
- bplusplus-2.0.4.dist-info/LICENSE +21 -0
- bplusplus-2.0.4.dist-info/METADATA +259 -0
- bplusplus-2.0.4.dist-info/RECORD +12 -0
- bplusplus-2.0.4.dist-info/WHEEL +4 -0
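
The inference.py module shown below exposes a single high-level entry point, inference(). A minimal usage sketch (illustrative only: the model path, video path and output directory are placeholders, and it assumes the package re-exports inference from __init__.py; otherwise import it from bplusplus.inference):

    import bplusplus

    result = bplusplus.inference(
        species_list=["Apis mellifera", "Bombus terrestris"],
        hierarchical_model_path="model.pt",
        video_path="input.mp4",
        output_dir="results/",
        save_video=True,
    )
    if result["success"]:
        print(result["output_files"])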
bplusplus/inference.py
ADDED
@@ -0,0 +1,1337 @@
import cv2
import time
import os
import yaml
import json
import numpy as np
import pandas as pd
from datetime import datetime
from collections import defaultdict
import uuid
import argparse

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import requests
import logging

from .tracker import InsectTracker
from .detector import (
    DEFAULT_DETECTION_CONFIG,
    get_default_config,
    build_detection_params,
    extract_motion_detections,
    analyze_path_topology,
    check_track_consistency,
)

# Torch serialization compatibility
if hasattr(torch.serialization, 'add_safe_globals'):
    torch.serialization.add_safe_globals([
        'torch.LongTensor',
        'torch.cuda.LongTensor',
        'torch.FloatStorage',
        'torch.cuda.FloatStorage',
    ])

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ============================================================================
# CONFIGURATION
# ============================================================================

def load_config(config_path):
    """
    Load detection configuration from YAML or JSON file.

    Args:
        config_path: Path to config file (.yaml, .yml, or .json)

    Returns:
        dict: Configuration parameters
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    ext = os.path.splitext(config_path)[1].lower()

    with open(config_path, 'r') as f:
        if ext in ['.yaml', '.yml']:
            config = yaml.safe_load(f)
        elif ext == '.json':
            config = json.load(f)
        else:
            raise ValueError(f"Unsupported config format: {ext}")

    # Merge with defaults
    params = get_default_config()
    for key, value in config.items():
        if key in params:
            params[key] = value
        else:
            logger.warning(f"Unknown config parameter ignored: {key}")

    return params
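
# --- Example config (editorial sketch, not part of the original source) ---
# load_config() accepts any subset of the keys in DEFAULT_DETECTION_CONFIG
# (the known parameter names appear in the CLI section near the end of this
# file); unknown keys are logged and ignored. A minimal YAML file might look
# like the following, with illustrative values rather than package defaults:
#
#   min_area: 50
#   max_area: 5000
#   min_displacement: 40
#   lost_track_seconds: 2.0
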
# ============================================================================
# TAXONOMY UTILITIES
# ============================================================================

def get_taxonomy(species_list):
    """
    Retrieve taxonomic information from GBIF API.

    Args:
        species_list: List of species names

    Returns:
        tuple: (taxonomy_dict, species_to_genus, genus_to_family)
    """
    taxonomy = {1: [], 2: {}, 3: {}}
    species_to_genus = {}
    genus_to_family = {}

    species_for_gbif = [s for s in species_list if s.lower() != 'unknown']
    has_unknown = len(species_for_gbif) != len(species_list)

    logger.info(f"Building taxonomy from GBIF for {len(species_for_gbif)} species")
    print(f"\n{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
    print("-" * 80)

    for species_name in species_for_gbif:
        url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
        try:
            response = requests.get(url)
            data = response.json()

            if data.get('status') in ['ACCEPTED', 'SYNONYM']:
                family = data.get('family')
                genus = data.get('genus')

                if family and genus:
                    print(f"{species_name:<30} {family:<20} {genus:<20} OK")
                    species_to_genus[species_name] = genus
                    genus_to_family[genus] = family
                    if family not in taxonomy[1]:
                        taxonomy[1].append(family)
                    taxonomy[2][genus] = family
                    taxonomy[3][species_name] = genus
                else:
                    print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                    logger.error(f"Species '{species_name}' missing family/genus")
            else:
                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                logger.error(f"Species '{species_name}' not found in GBIF")
        except Exception as e:
            print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
            logger.error(f"Error for '{species_name}': {e}")

    if has_unknown:
        if "Unknown" not in taxonomy[1]:
            taxonomy[1].append("Unknown")
        taxonomy[2]["Unknown"] = "Unknown"
        taxonomy[3]["unknown"] = "Unknown"
        species_to_genus["unknown"] = "Unknown"
        genus_to_family["Unknown"] = "Unknown"
        print(f"{'unknown':<30} {'Unknown':<20} {'Unknown':<20} OK")

    taxonomy[1] = sorted(set(taxonomy[1]))
    print("-" * 80)

    for level, name, items in [(1, "Family", taxonomy[1]),
                               (2, "Genus", taxonomy[2].keys()),
                               (3, "Species", species_list)]:
        print(f"\n{name} indices:")
        for i, item in enumerate(items):
            print(f" {i}: {item}")

    logger.info(f"Taxonomy: {len(taxonomy[1])} families, {len(taxonomy[2])} genera, {len(taxonomy[3])} species")
    return taxonomy, species_to_genus, genus_to_family


def create_mappings(taxonomy, species_list=None):
    """Create index mappings from taxonomy."""
    level_to_idx = {}
    idx_to_level = {}

    for level, labels in taxonomy.items():
        if isinstance(labels, list):
            level_to_idx[level] = {label: idx for idx, label in enumerate(labels)}
            idx_to_level[level] = {idx: label for idx, label in enumerate(labels)}
        else:
            sorted_keys = species_list if level == 3 and species_list else sorted(labels.keys())
            level_to_idx[level] = {label: idx for idx, label in enumerate(sorted_keys)}
            idx_to_level[level] = {idx: label for idx, label in enumerate(sorted_keys)}

    return level_to_idx, idx_to_level


# ============================================================================
# MODEL
# ============================================================================

class HierarchicalInsectClassifier(nn.Module):
    """Hierarchical classifier with ResNet backbone and multi-branch heads."""

    def __init__(self, num_classes_per_level, backbone: str = "resnet50"):
        super().__init__()
        self.backbone = self._build_backbone(backbone)
        self.backbone_name = backbone
        backbone_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(backbone_features, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, num_classes)
            ) for num_classes in num_classes_per_level
        ])
        self.num_levels = len(num_classes_per_level)

    @staticmethod
    def _build_backbone(backbone: str):
        """Build ResNet backbone by name."""
        name = backbone.lower()
        if name == "resnet18":
            return models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        if name == "resnet50":
            return models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        if name == "resnet101":
            return models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
        raise ValueError(f"Unsupported backbone '{backbone}'. Choose from 'resnet18', 'resnet50', 'resnet101'.")

    def forward(self, x):
        features = self.backbone(x)
        return [branch(features) for branch in self.branches]
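
# Example (editorial sketch, not part of the original source): the forward pass
# returns one logit tensor per taxonomic level. With, say, 4 families, 9 genera
# and 15 species and a batch of one 60x60 crop:
#   model = HierarchicalInsectClassifier([4, 9, 15], backbone="resnet18")
#   family_logits, genus_logits, species_logits = model(torch.randn(1, 3, 60, 60))
#   # shapes: (1, 4), (1, 9), (1, 15)
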
# ============================================================================
# VISUALIZATION
# ============================================================================

class FrameVisualizer:
    """Visualization utilities for detection overlay."""

    COLORS = [
        (68, 189, 50), (255, 59, 48), (0, 122, 255), (255, 149, 0),
        (175, 82, 222), (255, 204, 0), (50, 173, 230), (255, 45, 85),
        (48, 209, 88), (90, 200, 250), (255, 159, 10), (191, 90, 242),
    ]

    @staticmethod
    def get_track_color(track_id):
        if track_id is None:
            return (68, 189, 50)
        try:
            track_uuid = uuid.UUID(track_id)
        except (ValueError, TypeError):
            track_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(track_id))
        return FrameVisualizer.COLORS[track_uuid.int % len(FrameVisualizer.COLORS)]

    @staticmethod
    def draw_path(frame, path, track_id):
        """Draw track path on frame."""
        if len(path) < 2:
            return
        color = FrameVisualizer.get_track_color(track_id)
        path_points = np.array(path, dtype=np.int32)
        cv2.polylines(frame, [path_points], False, color, 2)
        # Draw center point
        cx, cy = path[-1]
        cv2.circle(frame, (int(cx), int(cy)), 4, color, -1)

    @staticmethod
    def draw_detection(frame, x1, y1, x2, y2, track_id, detection_data):
        """Draw bounding box and classification label on frame."""
        color = FrameVisualizer.get_track_color(track_id)
        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

        track_display = f"ID: {str(track_id)[:8]}" if track_id else "NEW"
        lines = [track_display]

        for level, conf_key in [("family", "family_confidence"),
                                ("genus", "genus_confidence"),
                                ("species", "species_confidence")]:
            if detection_data.get(level):
                name = detection_data[level]
                conf = detection_data.get(conf_key, 0)
                name = name[:15] + "..." if len(name) > 18 else name
                lines.append(f"{level[0].upper()}: {name}")
                lines.append(f" {conf:.1%}")

        if not lines[1:] and track_id is None:
            return

        font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.45, 1
        padding, spacing = 8, 6
        text_sizes = [cv2.getTextSize(line, font, scale, thickness)[0] for line in lines]
        max_w = max(s[0] for s in text_sizes)
        text_h = text_sizes[0][1]

        total_h = len(lines) * (text_h + spacing) + padding * 2
        label_x1 = max(0, int(x1))
        label_y1 = max(0, int(y1) - total_h - 5)
        if label_y1 < 0:
            label_y1 = int(y2) + 5
        label_x2 = min(frame.shape[1], label_x1 + max_w + padding * 2)
        label_y2 = min(frame.shape[0], label_y1 + total_h)

        overlay = frame.copy()
        cv2.rectangle(overlay, (label_x1, label_y1), (label_x2, label_y2), (20, 20, 20), -1)
        cv2.addWeighted(overlay, 0.85, frame, 0.15, 0, frame)
        cv2.rectangle(frame, (label_x1, label_y1), (label_x2, label_y2), color, 1)

        y = label_y1 + padding + text_h
        for i, line in enumerate(lines):
            text_color = color if i == 0 else ((160, 160, 160) if "%" in line else (255, 255, 255))
            cv2.putText(frame, line, (label_x1 + padding, y), font, scale, text_color, thickness, cv2.LINE_AA)
            y += text_h + spacing

# ============================================================================
# VIDEO PROCESSOR
# ============================================================================

class VideoInferenceProcessor:
    """
    Processes video frames for insect detection and classification.

    Combines motion-based detection with hierarchical classification
    and track-based prediction aggregation.
    """

    def __init__(self, species_list, hierarchical_model_path, params, backbone="resnet50", img_size=60):
        """
        Initialize the processor.

        Args:
            species_list: List of species names for classification
            hierarchical_model_path: Path to trained model weights
            params: Detection parameters dict
            backbone: ResNet backbone ('resnet18', 'resnet50', 'resnet101')
            img_size: Image size for classification (should match training)
        """
        self.img_size = img_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.species_list = species_list
        self.params = params

        print(f"Using device: {self.device}")

        # Build taxonomy
        self.taxonomy, self.species_to_genus, self.genus_to_family = get_taxonomy(species_list)
        self.level_to_idx, self.idx_to_level = create_mappings(self.taxonomy, species_list)
        self.family_list = sorted(self.taxonomy[1])
        self.genus_list = sorted(self.taxonomy[2].keys())

        # Motion detection setup
        self.back_sub = cv2.createBackgroundSubtractorMOG2(
            history=500, varThreshold=16, detectShadows=False
        )
        self.morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

        # Load classification model
        print(f"Loading hierarchical model from {hierarchical_model_path}")
        checkpoint = torch.load(hierarchical_model_path, map_location='cpu')
        state_dict = checkpoint.get("model_state_dict", checkpoint)

        # Use backbone from checkpoint if available, otherwise use provided
        model_backbone = checkpoint.get("backbone", backbone)
        if model_backbone != backbone:
            print(f"Note: Using backbone '{model_backbone}' from checkpoint (overrides '{backbone}')")

        num_classes = [len(self.family_list), len(self.genus_list), len(self.species_list)]
        print(f"Model architecture: {num_classes} classes per level, backbone: {model_backbone}")

        self.model = HierarchicalInsectClassifier(num_classes, backbone=model_backbone)
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(self.device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Track state
        self.all_detections = []
        self.track_paths = defaultdict(list)
        self.track_areas = defaultdict(list)

        print("Processor initialized successfully!")

    def _extract_detections(self, frame):
        """Extract motion-based detections."""
        detections, fg_mask = extract_motion_detections(
            frame, self.back_sub, self.morph_kernel, self.params
        )
        return detections, fg_mask

    def _classify(self, frame, x1, y1, x2, y2):
        """Classify a detection crop."""
        crop = frame[int(y1):int(y2), int(x1):int(x2)]
        crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
        tensor = self.transform(Image.fromarray(crop_rgb)).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(tensor)

        probs = [torch.softmax(o, dim=1).cpu().numpy().flatten() for o in outputs]
        idxs = [np.argmax(p) for p in probs]

        return {
            "family": self.family_list[idxs[0]] if idxs[0] < len(self.family_list) else f"Family_{idxs[0]}",
            "genus": self.genus_list[idxs[1]] if idxs[1] < len(self.genus_list) else f"Genus_{idxs[1]}",
            "species": self.species_list[idxs[2]] if idxs[2] < len(self.species_list) else f"Species_{idxs[2]}",
            "family_confidence": float(probs[0][idxs[0]]),
            "genus_confidence": float(probs[1][idxs[1]]),
            "species_confidence": float(probs[2][idxs[2]]),
            "family_probs": probs[0].tolist(),
            "genus_probs": probs[1].tolist(),
            "species_probs": probs[2].tolist(),
        }

    def process_frame(self, frame, frame_time, tracker, frame_number):
        """
        Process a single frame: detect and track only (no classification).
        Classification happens later for confirmed tracks only.

        Args:
            frame: BGR image frame
            frame_time: Time in seconds
            tracker: InsectTracker instance
            frame_number: Frame index

        Returns:
            tuple: (foreground_mask, list of detections with track_ids)
        """
        detections, fg_mask = self._extract_detections(frame)
        track_ids = tracker.update(detections, frame_number)

        height, width = frame.shape[:2]
        frame_detections = []

        for i, det in enumerate(detections):
            x1, y1, x2, y2 = det
            track_id = track_ids[i] if i < len(track_ids) else None

            # Track consistency check
            if track_id:
                cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
                area = (x2 - x1) * (y2 - y1)

                if self.track_paths[track_id]:
                    prev_pos = self.track_paths[track_id][-1]
                    prev_area = self.track_areas[track_id][-1] if self.track_areas[track_id] else area

                    if not check_track_consistency(
                        prev_pos, (cx, cy), prev_area, area,
                        self.params["max_frame_jump"]
                    ):
                        # Reset track
                        self.track_paths[track_id] = []
                        self.track_areas[track_id] = []

                self.track_paths[track_id].append((cx, cy))
                self.track_areas[track_id].append(area)

            # Store detection WITHOUT classification
            detection_data = {
                "timestamp": datetime.now().isoformat(),
                "frame_number": frame_number,
                "frame_time_seconds": frame_time,
                "track_id": track_id,
                "bbox": [x1, y1, x2, y2],
                "bbox_normalized": [
                    (x1 + x2) / (2 * width), (y1 + y2) / (2 * height),
                    (x2 - x1) / width, (y2 - y1) / height
                ],
            }
            self.all_detections.append(detection_data)
            frame_detections.append(detection_data)

            # Log (detection only, no classification yet)
            track_display = str(track_id)[:8] if track_id else "NEW"
            print(f"Frame {frame_time:6.2f}s | Track {track_display} | Detected")

        return fg_mask, frame_detections

    def classify_confirmed_tracks(self, video_path, confirmed_track_ids, crops_dir=None):
        """
        Classify only the confirmed tracks by re-reading relevant frames.

        Args:
            video_path: Path to original video
            confirmed_track_ids: Set of track IDs that passed topology analysis
            crops_dir: Optional directory to save cropped frames

        Returns:
            dict: track_id -> list of classifications
        """
        if not confirmed_track_ids:
            print("No confirmed tracks to classify.")
            return {}

        print(f"\nClassifying {len(confirmed_track_ids)} confirmed tracks...")

        # Setup crops directory if requested
        if crops_dir:
            os.makedirs(crops_dir, exist_ok=True)
            # Create subdirectory for each track
            for track_id in confirmed_track_ids:
                track_dir = os.path.join(crops_dir, str(track_id)[:8])
                os.makedirs(track_dir, exist_ok=True)
            print(f" Saving crops to: {crops_dir}")

        # Group detections by frame for confirmed tracks
        frames_to_classify = defaultdict(list)
        for det in self.all_detections:
            if det['track_id'] in confirmed_track_ids:
                frames_to_classify[det['frame_number']].append(det)

        if not frames_to_classify:
            return {}

        cap = cv2.VideoCapture(video_path)
        track_classifications = defaultdict(list)

        frame_numbers = sorted(frames_to_classify.keys())
        current_frame = 0
        classified_count = 0

        for target_frame in frame_numbers:
            # Seek to frame
            while current_frame < target_frame:
                cap.read()
                current_frame += 1

            ret, frame = cap.read()
            if not ret:
                break
            current_frame += 1

            # Classify each detection in this frame
            for det in frames_to_classify[target_frame]:
                x1, y1, x2, y2 = det['bbox']
                classification = self._classify(frame, x1, y1, x2, y2)

                # Update detection with classification
                det.update(classification)

                track_classifications[det['track_id']].append(classification)
                classified_count += 1

                # Save crop if requested
                if crops_dir:
                    track_id = det['track_id']
                    track_dir = os.path.join(crops_dir, str(track_id)[:8])
                    crop = frame[int(y1):int(y2), int(x1):int(x2)]
                    if crop.size > 0:
                        crop_path = os.path.join(track_dir, f"frame_{target_frame:06d}.jpg")
                        cv2.imwrite(crop_path, crop)

                if classified_count % 20 == 0:
                    print(f" Classified {classified_count} detections...", end='\r')

        cap.release()
        print(f"\n✓ Classified {classified_count} detections from {len(confirmed_track_ids)} tracks")
        if crops_dir:
            print(f"✓ Saved {classified_count} crops to {crops_dir}")

        return track_classifications

    def analyze_tracks(self):
        """
        Analyze all tracks to determine which pass topology (before classification).

        Returns:
            tuple: (confirmed_track_ids set, all_track_info dict)
        """
        print("\n" + "="*60)
        print("TRACK TOPOLOGY ANALYSIS")
        print("="*60)

        track_detections = defaultdict(list)
        for det in self.all_detections:
            if det['track_id']:
                track_detections[det['track_id']].append(det)

        confirmed_track_ids = set()
        all_track_info = {}

        for track_id, detections in track_detections.items():
            # Path topology analysis
            path = self.track_paths.get(track_id, [])
            passes_topology, topology_metrics = analyze_path_topology(path, self.params)

            frame_times = [d['frame_time_seconds'] for d in detections]

            track_info = {
                'track_id': track_id,
                'num_detections': len(detections),
                'first_frame_time': min(frame_times),
                'last_frame_time': max(frame_times),
                'duration': max(frame_times) - min(frame_times),
                'passes_topology': passes_topology,
                **topology_metrics
            }
            all_track_info[track_id] = track_info

            status = "✓ CONFIRMED" if passes_topology else "? unconfirmed"
            print(f"Track {str(track_id)[:8]}: {len(detections)} detections, "
                  f"{track_info['duration']:.1f}s - {status}")

            if passes_topology:
                confirmed_track_ids.add(track_id)

        print(f"\n✓ {len(confirmed_track_ids)} confirmed / {len(track_detections)} total tracks")
        return confirmed_track_ids, all_track_info

    def hierarchical_aggregation(self, confirmed_track_ids):
        """
        Aggregate predictions for confirmed tracks using hierarchical selection.
        Must be called AFTER classify_confirmed_tracks().

        Args:
            confirmed_track_ids: Set of confirmed track IDs

        Returns:
            list: Aggregated results for confirmed tracks only
        """
        print("\n" + "="*60)
        print("HIERARCHICAL AGGREGATION (Confirmed Tracks)")
        print("="*60)

        track_detections = defaultdict(list)
        for det in self.all_detections:
            if det['track_id'] in confirmed_track_ids:
                track_detections[det['track_id']].append(det)

        results = []
        for track_id, detections in track_detections.items():
            # Check if classifications exist
            if 'family_probs' not in detections[0]:
                print(f"Warning: Track {str(track_id)[:8]} has no classifications, skipping")
                continue

            print(f"\nTrack {str(track_id)[:8]}: {len(detections)} classified detections")

            # Path topology analysis
            path = self.track_paths.get(track_id, [])
            passes_topology, topology_metrics = analyze_path_topology(path, self.params)

            # Average probabilities
            prob_avgs = [
                np.mean([d['family_probs'] for d in detections], axis=0),
                np.mean([d['genus_probs'] for d in detections], axis=0),
                np.mean([d['species_probs'] for d in detections], axis=0),
            ]

            # Hierarchical selection
            best_family_idx = np.argmax(prob_avgs[0])
            best_family = self.family_list[best_family_idx]

            family_genera = [i for i, g in enumerate(self.genus_list)
                             if self.genus_to_family.get(g) == best_family]
            if family_genera:
                best_genus_idx = family_genera[np.argmax(prob_avgs[1][family_genera])]
            else:
                best_genus_idx = np.argmax(prob_avgs[1])
            best_genus = self.genus_list[best_genus_idx]

            genus_species = [i for i, s in enumerate(self.species_list)
                             if self.species_to_genus.get(s) == best_genus]
            if genus_species:
                best_species_idx = genus_species[np.argmax(prob_avgs[2][genus_species])]
            else:
                best_species_idx = np.argmax(prob_avgs[2])
            best_species = self.species_list[best_species_idx]

            frame_times = [d['frame_time_seconds'] for d in detections]

            result = {
                'track_id': track_id,
                'num_detections': len(detections),
                'first_frame_time': min(frame_times),
                'last_frame_time': max(frame_times),
                'duration': max(frame_times) - min(frame_times),
                'final_family': best_family,
                'final_genus': best_genus,
                'final_species': best_species,
                'family_confidence': float(prob_avgs[0][best_family_idx]),
                'genus_confidence': float(prob_avgs[1][best_genus_idx]),
                'species_confidence': float(prob_avgs[2][best_species_idx]),
                'passes_topology': passes_topology,
                **topology_metrics
            }
            results.append(result)

            print(f" → {best_family} / {best_genus} / {best_species} "
                  f"({result['species_confidence']:.1%})")

        return results

    def save_results(self, results, output_paths):
        """
        Save results to CSV and print summary.

        Args:
            results: Aggregated results list (confirmed tracks only)
            output_paths: Dict with output file paths

        Returns:
            pd.DataFrame: Results dataframe (confirmed tracks only)
        """
        # Count total tracks vs confirmed
        total_tracks = len(self.track_paths)
        num_confirmed = len(results)
        num_unconfirmed = total_tracks - num_confirmed

        # Save confirmed tracks to results CSV
        if results:
            df = pd.DataFrame(results).sort_values('num_detections', ascending=False)
            df.to_csv(output_paths["results_csv"], index=False)
            print(f"\n📊 Confirmed results saved: {output_paths['results_csv']} ({num_confirmed} tracks)")
        else:
            # Create empty CSV with headers
            df = pd.DataFrame(columns=[
                'track_id', 'num_detections', 'first_frame_time', 'last_frame_time',
                'duration', 'final_family', 'final_genus', 'final_species',
                'family_confidence', 'genus_confidence', 'species_confidence',
                'passes_topology', 'total_displacement', 'revisit_ratio',
                'progression_ratio', 'directional_variance'
            ])
            df.to_csv(output_paths["results_csv"], index=False)
            print(f"\n📊 Results file created (empty - no confirmed tracks): {output_paths['results_csv']}")

        # Save all detections (frame-by-frame, regardless of confirmation)
        det_df = pd.DataFrame(self.all_detections)
        det_df.to_csv(output_paths["detections_csv"], index=False)
        print(f"📋 Frame-by-frame detections saved: {output_paths['detections_csv']}")

        # Print summary
        print("\n" + "="*60)
        print("🐛 FINAL SUMMARY")
        print("="*60)

        if results:
            print(f"\n✓ CONFIRMED INSECTS ({num_confirmed}):")
            for r in results:
                print(f" • {r['final_species']} - {r['num_detections']} detections, "
                      f"{r['duration']:.1f}s, {r['species_confidence']:.1%}")

        if num_unconfirmed > 0:
            print(f"\n? Unconfirmed tracks: {num_unconfirmed} (failed topology analysis)")

        print(f"\n📈 Total: {total_tracks} tracks ({num_confirmed} confirmed, {num_unconfirmed} unconfirmed)")

        # Bold warning if no confirmed tracks
        if not results:
            print("\n" + "!"*60)
            print("⚠️ WARNING: NO CONFIRMED INSECT TRACKS DETECTED!")
            print("!"*60)
            print("Possible reasons:")
            print(" • No insects present in the video")
            print(" • Detection parameters too strict (try lowering min_area)")
            print(" • Tracking parameters too strict (try increasing lost_track_seconds)")
            print(" • Path topology too strict (try lowering min_displacement)")
            print(" • Video quality/resolution issues")
            if num_unconfirmed > 0:
                print(f"\nNote: {num_unconfirmed} tracks were detected but failed topology check.")
                print("Consider relaxing path topology parameters if these might be valid insects.")
            print("!"*60)

        print("="*60)

        return df

# ============================================================================
# VIDEO PROCESSING
# ============================================================================

def process_video(video_path, processor, output_paths, show_video=False, fps=None, crops_dir=None):
    """
    Process video file with efficient classification (confirmed tracks only).

    Pipeline:
    1. Detection & Tracking: Process all frames, detect motion, build tracks
    2. Topology Analysis: Determine which tracks are confirmed insects
    3. Classification: Classify ONLY confirmed tracks (saves compute)
    4. Render Videos: Debug (all detections) + Annotated (confirmed with classifications)

    Args:
        video_path: Input video path
        processor: VideoInferenceProcessor instance
        output_paths: Dict with output file paths
        show_video: Display video while processing
        fps: Target FPS (skip frames if lower than input)
        crops_dir: Optional directory to save cropped frames for each track

    Returns:
        list: Aggregated results
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open: {video_path}")

    input_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    print(f"\nVideo: {video_path}")
    print(f"Properties: {total_frames} frames, {input_fps:.1f} FPS, {total_frames/input_fps:.1f}s")

    # Setup tracker
    effective_fps = fps if fps and fps > 0 else input_fps if input_fps > 0 else 30
    track_memory = max(30, int(processor.params["lost_track_seconds"] * effective_fps))

    tracker = InsectTracker(
        image_height=height, image_width=width,
        max_frames=track_memory, track_memory_frames=track_memory,
        w_dist=0.6, w_area=0.4, cost_threshold=0.3, debug=False
    )

    # Frame skip
    skip_interval = max(1, int(input_fps / fps)) if fps and fps > 0 else 1
    if skip_interval > 1:
        print(f"Processing every {skip_interval} frame(s)")

    # ==========================================================================
    # PHASE 1: Detection & Tracking (no classification)
    # ==========================================================================
    print("\n" + "="*60)
    print("PHASE 1: DETECTION & TRACKING")
    print("="*60)

    frame_num = 0
    processed = 0
    start = time.time()

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_time = frame_num / input_fps if input_fps > 0 else 0

        if frame_num % skip_interval == 0:
            fg_mask, frame_dets = processor.process_frame(frame, frame_time, tracker, frame_num)
            processed += 1

            if processed % 50 == 0:
                print(f" Progress: {processed} frames, {len(processor.all_detections)} detections", end='\r')

        frame_num += 1

    cap.release()
    elapsed = time.time() - start
    print(f"\n✓ Phase 1 complete: {processed} frames in {elapsed:.1f}s ({processed/elapsed:.1f} FPS)")
    print(f" Total detections: {len(processor.all_detections)}")
    print(f" Unique tracks: {len(processor.track_paths)}")

    # ==========================================================================
    # PHASE 2: Topology Analysis (determine confirmed tracks)
    # ==========================================================================
    confirmed_track_ids, all_track_info = processor.analyze_tracks()

    # ==========================================================================
    # PHASE 3: Classification (confirmed tracks only)
    # ==========================================================================
    print("\n" + "="*60)
    print("PHASE 3: CLASSIFICATION (Confirmed Tracks Only)")
    print("="*60)

    if confirmed_track_ids:
        processor.classify_confirmed_tracks(video_path, confirmed_track_ids, crops_dir=crops_dir)
        results = processor.hierarchical_aggregation(confirmed_track_ids)
    else:
        results = []

    # ==========================================================================
    # PHASE 4: Render Videos
    # ==========================================================================
    # Render videos if requested
    if "annotated_video" in output_paths or "debug_video" in output_paths:
        print("\n" + "="*60)
        print("PHASE 4: RENDERING VIDEOS")
        print("="*60)

        # Render debug video (all detections, showing confirmed vs unconfirmed)
        if "debug_video" in output_paths:
            print(f"\nRendering debug video (all detections)...")
            _render_debug_video(
                video_path, output_paths["debug_video"],
                processor, confirmed_track_ids, all_track_info, input_fps
            )

        # Render annotated video (confirmed tracks with classifications)
        if "annotated_video" in output_paths:
            print(f"\nRendering annotated video ({len(confirmed_track_ids)} confirmed tracks)...")
            _render_annotated_video(
                video_path, output_paths["annotated_video"],
                processor, confirmed_track_ids, input_fps
            )
    else:
        print("\n(Video rendering skipped)")

    # Save results
    processor.save_results(results, output_paths)

    return results

def _render_debug_video(video_path, output_path, processor, confirmed_track_ids, all_track_info, fps):
    """
    Render debug video showing all detections with confirmed/unconfirmed status.
    Shows detection boxes, track IDs, and GMM motion mask side-by-side.
    """
    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Recreate background subtractor for GMM visualization
    back_sub = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=16, detectShadows=False)

    out = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps, (width * 2, height)
    )

    # Build frame-to-detections lookup
    frame_detections = defaultdict(list)
    for det in processor.all_detections:
        frame_detections[det['frame_number']].append(det)

    frame_num = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Get GMM mask
        fg_mask = back_sub.apply(frame)
        fg_display = cv2.cvtColor(fg_mask, cv2.COLOR_GRAY2BGR)

        # Draw all detections with status
        for det in frame_detections[frame_num]:
            x1, y1, x2, y2 = [int(v) for v in det['bbox']]
            track_id = det['track_id']

            if track_id in confirmed_track_ids:
                # Confirmed: Green box, show classification if available
                color = (0, 255, 0)
                status = "CONFIRMED"
                classification = {
                    'species': det.get('species', ''),
                    'species_confidence': det.get('species_confidence', 0),
                }
                label = f"{str(track_id)[:6]} ✓"
                if classification['species']:
                    label += f" {classification['species'][:12]}"
            else:
                # Unconfirmed: Yellow box
                color = (0, 255, 255)
                status = "tracking..."
                label = f"{str(track_id)[:6] if track_id else 'NEW'}"

            # Draw box
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            # Draw label background
            (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
            cv2.rectangle(frame, (x1, y1 - th - 4), (x1 + tw + 4, y1), color, -1)
            cv2.putText(frame, label, (x1 + 2, y1 - 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1)

            # Draw on GMM mask too
            cv2.rectangle(fg_display, (x1, y1), (x2, y2), color, 2)

        # Add headers
        cv2.putText(frame, f"Frame {frame_num} | Detections (Green=Confirmed, Yellow=Tracking)",
                    (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        cv2.putText(fg_display, "GMM Motion Mask",
                    (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Combine side-by-side
        combined = np.hstack((frame, fg_display))
        out.write(combined)

        frame_num += 1
        if frame_num % 100 == 0:
            print(f" Debug: {frame_num} frames", end='\r')

    cap.release()
    out.release()
    print(f"\n✓ Debug video saved: {output_path}")

def _render_annotated_video(video_path, output_path, processor, confirmed_track_ids, fps):
    """
    Render annotated video showing only confirmed tracks with classifications.
    """
    if not confirmed_track_ids:
        # Create video with "No confirmed tracks" message
        cap = cv2.VideoCapture(video_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        frame_num = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, "No confirmed insect tracks",
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            out.write(frame)
            frame_num += 1

        cap.release()
        out.release()
        print(f"✓ Annotated video saved (no confirmed tracks): {output_path}")
        return

    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

    # Build frame-to-detections lookup for confirmed tracks only
    frame_detections = defaultdict(list)
    for det in processor.all_detections:
        if det['track_id'] in confirmed_track_ids:
            frame_detections[det['frame_number']].append(det)

    frame_num = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Draw paths for confirmed tracks (up to current frame)
        for track_id in confirmed_track_ids:
            path_to_draw = []
            for det in processor.all_detections:
                if det['track_id'] == track_id and det['frame_number'] <= frame_num:
                    bbox = det['bbox']
                    cx = (bbox[0] + bbox[2]) / 2
                    cy = (bbox[1] + bbox[3]) / 2
                    path_to_draw.append((cx, cy))

            if len(path_to_draw) > 1:
                FrameVisualizer.draw_path(frame, path_to_draw, track_id)

        # Draw detections for this frame (confirmed tracks only, with classification)
        for det in frame_detections[frame_num]:
            x1, y1, x2, y2 = det['bbox']
            track_id = det['track_id']

            classification = {
                'family': det.get('family', ''),
                'genus': det.get('genus', ''),
                'species': det.get('species', ''),
                'family_confidence': det.get('family_confidence', 0),
                'genus_confidence': det.get('genus_confidence', 0),
                'species_confidence': det.get('species_confidence', 0),
            }
            FrameVisualizer.draw_detection(frame, x1, y1, x2, y2, track_id, classification)

        # Add header
        cv2.putText(frame, f"Confirmed Insects ({len(confirmed_track_ids)} tracks)",
                    (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        out.write(frame)
        frame_num += 1

        if frame_num % 100 == 0:
            print(f" Annotated: {frame_num} frames", end='\r')

    cap.release()
    out.release()
    print(f"\n✓ Annotated video saved: {output_path}")

# ============================================================================
# MAIN ENTRY POINT
# ============================================================================

def inference(
    species_list,
    hierarchical_model_path,
    video_path,
    output_dir,
    fps=None,
    config=None,
    backbone="resnet50",
    crops=False,
    save_video=True,
    img_size=60,
):
    """
    Run inference on a video file.

    Args:
        species_list: List of species names for classification
        hierarchical_model_path: Path to trained model weights
        video_path: Input video path
        output_dir: Output directory for all generated files
        fps: Target processing FPS (None = use input FPS)
        config: Detection config - can be:
            - None: use defaults
            - str: path to YAML/JSON config file
            - dict: config parameters directly
        backbone: ResNet backbone ('resnet18', 'resnet50', 'resnet101').
            If model checkpoint contains backbone info, it will be used instead.
        crops: If True, save cropped frames for each classified track
        save_video: If True, save annotated and debug videos. Defaults to True.
        img_size: Image size for classification (should match training). Default: 60.

    Returns:
        dict: Processing results with output file paths

    Generated files in output_dir:
        - {video_name}_annotated.mp4: Video with detection boxes and paths (if save_video=True)
        - {video_name}_debug.mp4: Side-by-side with GMM motion mask (if save_video=True)
        - {video_name}_results.csv: Aggregated track results
        - {video_name}_detections.csv: Frame-by-frame detections
        - {video_name}_crops/ (if crops=True): Directory with cropped frames per track
    """
    if not os.path.exists(video_path):
        print(f"Error: Video not found: {video_path}")
        return {"error": f"Video not found: {video_path}", "success": False}

    # Build parameters from config
    if config is None:
        params = get_default_config()
    elif isinstance(config, str):
        params = load_config(config)
    elif isinstance(config, dict):
        params = get_default_config()
        for key, value in config.items():
            if key in params:
                params[key] = value
            else:
                logger.warning(f"Unknown config parameter: {key}")
    else:
        raise ValueError("config must be None, a file path (str), or a dict")

    # Setup output directory and file paths
    os.makedirs(output_dir, exist_ok=True)
    video_name = os.path.splitext(os.path.basename(video_path))[0]

    output_paths = {
        "results_csv": os.path.join(output_dir, f"{video_name}_results.csv"),
        "detections_csv": os.path.join(output_dir, f"{video_name}_detections.csv"),
    }

    if save_video:
        output_paths["annotated_video"] = os.path.join(output_dir, f"{video_name}_annotated.mp4")
        output_paths["debug_video"] = os.path.join(output_dir, f"{video_name}_debug.mp4")

    # Setup crops directory if requested
    crops_dir = os.path.join(output_dir, f"{video_name}_crops") if crops else None
    if crops_dir:
        output_paths["crops_dir"] = crops_dir

    print("\n" + "="*60)
    print("BPLUSPLUS INFERENCE")
    print("="*60)
    print(f"Video: {video_path}")
    print(f"Model: {hierarchical_model_path}")
    print(f"Output directory: {output_dir}")
    print("\nOutput files:")
    for name, path in output_paths.items():
        print(f" {name}: {os.path.basename(path)}")
    print("\nDetection Parameters:")
    for key, value in params.items():
        print(f" {key}: {value}")
    print("="*60)

    # Process
    processor = VideoInferenceProcessor(
        species_list=species_list,
        hierarchical_model_path=hierarchical_model_path,
        params=params,
        backbone=backbone,
        img_size=img_size,
    )

    try:
        results = process_video(
            video_path=video_path,
            processor=processor,
            output_paths=output_paths,
            fps=fps,
            crops_dir=crops_dir
        )

        return {
            "video_file": os.path.basename(video_path),
            "output_dir": output_dir,
            "output_files": output_paths,
            "success": True,
            "detections": len(processor.all_detections),
            "tracks": len(results),
            "confirmed_tracks": len([r for r in results if r.get('passes_topology', False)]),
        }
    except Exception as e:
        logger.exception("Inference failed")
        return {"error": str(e), "success": False}

# ============================================================================
# COMMAND LINE INTERFACE
# ============================================================================

def main():
    """Command line interface for inference."""
    parser = argparse.ArgumentParser(
        description='Bplusplus Video Inference - Detect and classify insects in videos',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Basic usage
python -m bplusplus.inference --video input.mp4 --model model.pt \\
--output-dir results/ --species "Apis mellifera" "Bombus terrestris"

# With config file
python -m bplusplus.inference --video input.mp4 --model model.pt \\
--output-dir results/ --species "Apis mellifera" --config detection_config.yaml

Output files generated in output directory:
- {video_name}_annotated.mp4: Video with detection boxes and paths
- {video_name}_debug.mp4: Side-by-side view with GMM motion mask
- {video_name}_results.csv: Aggregated track results
- {video_name}_detections.csv: Frame-by-frame detections
- {video_name}_crops/ (with --crops): Cropped frames for each track
"""
    )

    # Required arguments
    parser.add_argument('--video', '-v', help='Input video path')
    parser.add_argument('--model', '-m', help='Path to hierarchical model weights')
    parser.add_argument('--output-dir', '-o', help='Output directory for all generated files')
    parser.add_argument('--species', '-s', nargs='+', help='List of species names')

    # Config
    parser.add_argument('--config', '-c', help='Path to config file (YAML or JSON)')

    # Processing
    parser.add_argument('--fps', type=float, help='Target processing FPS')
    parser.add_argument('--show', action='store_true', help='Display video while processing')
    parser.add_argument('--backbone', '-b', default='resnet50',
                        choices=['resnet18', 'resnet50', 'resnet101'],
                        help='ResNet backbone (default: resnet50, overridden by checkpoint if saved)')
    parser.add_argument('--crops', action='store_true',
                        help='Save cropped frames for each classified track')
    parser.add_argument('--no-video', action='store_true',
                        help='Skip saving annotated and debug videos')
    parser.add_argument('--img-size', type=int, default=60,
                        help='Image size for classification (should match training, default: 60)')

    # Detection parameters (override config)
    defaults = DEFAULT_DETECTION_CONFIG

    cohesive = parser.add_argument_group('Cohesiveness parameters')
    cohesive.add_argument('--min-blob-ratio', type=float,
                          help=f'Min largest blob ratio (default: {defaults["min_largest_blob_ratio"]})')
    cohesive.add_argument('--max-num-blobs', type=int,
                          help=f'Max number of blobs (default: {defaults["max_num_blobs"]})')

    shape = parser.add_argument_group('Shape parameters')
    shape.add_argument('--min-area', type=int,
                       help=f'Min area in px² (default: {defaults["min_area"]})')
    shape.add_argument('--max-area', type=int,
                       help=f'Max area in px² (default: {defaults["max_area"]})')
    shape.add_argument('--min-density', type=float,
                       help=f'Min density (default: {defaults["min_density"]})')
    shape.add_argument('--min-solidity', type=float,
                       help=f'Min solidity (default: {defaults["min_solidity"]})')

    tracking = parser.add_argument_group('Tracking parameters')
    tracking.add_argument('--min-displacement', type=int,
                          help=f'Min NET displacement in px (default: {defaults["min_displacement"]})')
    tracking.add_argument('--min-path-points', type=int,
                          help=f'Min path points (default: {defaults["min_path_points"]})')
    tracking.add_argument('--max-frame-jump', type=int,
                          help=f'Max pixels between frames (default: {defaults["max_frame_jump"]})')
    tracking.add_argument('--lost-track-seconds', type=float,
                          help=f'Lost track memory in seconds (default: {defaults["lost_track_seconds"]})')

    topology = parser.add_argument_group('Path topology parameters')
    topology.add_argument('--max-revisit-ratio', type=float,
                          help=f'Max revisit ratio (default: {defaults["max_revisit_ratio"]})')
    topology.add_argument('--min-progression-ratio', type=float,
                          help=f'Min progression ratio (default: {defaults["min_progression_ratio"]})')
    topology.add_argument('--max-directional-variance', type=float,
                          help=f'Max directional variance (default: {defaults["max_directional_variance"]})')

    args = parser.parse_args()

    # Validate required args
    if not all([args.video, args.model, args.output_dir, args.species]):
        parser.error("--video, --model, --output-dir, and --species are required")

    # Build config: use the config file if provided, otherwise collect CLI overrides
    if args.config:
        config = args.config  # Pass path, inference() will load it
    else:
        # Build dict from CLI args (only non-None values)
        cli_overrides = {
            "min_largest_blob_ratio": args.min_blob_ratio,
            "max_num_blobs": args.max_num_blobs,
            "min_area": args.min_area,
            "max_area": args.max_area,
            "min_density": args.min_density,
            "min_solidity": args.min_solidity,
            "min_displacement": args.min_displacement,
            "min_path_points": args.min_path_points,
            "max_frame_jump": args.max_frame_jump,
            "lost_track_seconds": args.lost_track_seconds,
            "max_revisit_ratio": args.max_revisit_ratio,
            "min_progression_ratio": args.min_progression_ratio,
            "max_directional_variance": args.max_directional_variance,
        }
        config = {k: v for k, v in cli_overrides.items() if v is not None} or None

    # Run inference
    result = inference(
        species_list=args.species,
        hierarchical_model_path=args.model,
        video_path=args.video,
        output_dir=args.output_dir,
        fps=args.fps,
        config=config,
        backbone=args.backbone,
        crops=args.crops,
        save_video=not args.no_video,
        img_size=args.img_size,
    )

    if result.get("success"):
        print(f"\n✓ Inference complete!")
        print(f" Output directory: {result['output_dir']}")
        print(f" Detections: {result['detections']}")
        print(f" Tracks: {result['tracks']} ({result['confirmed_tracks']} confirmed)")
    else:
        print(f"\n✗ Inference failed: {result.get('error')}")


if __name__ == "__main__":
    main()
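
After a run, the per-track results land in {video_name}_results.csv with the columns listed in save_results() above. A short loading sketch (the file name "clip_results.csv" is a placeholder for whatever video name you processed):

    import pandas as pd

    results = pd.read_csv("clip_results.csv")
    print(results[["final_species", "num_detections", "duration", "species_confidence"]])

The companion {video_name}_detections.csv holds the frame-by-frame detections for all tracks, confirmed or not.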