bplusplus 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bplusplus might be problematic. Click here for more details.
- bplusplus/__init__.py +3 -5
- bplusplus/collect.py +2 -0
- bplusplus/inference.py +891 -0
- bplusplus/prepare.py +429 -540
- bplusplus/{hierarchical/test.py → test.py} +99 -88
- bplusplus/tracker.py +261 -0
- bplusplus/{hierarchical/train.py → train.py} +29 -29
- bplusplus-1.2.3.dist-info/METADATA +101 -0
- bplusplus-1.2.3.dist-info/RECORD +11 -0
- {bplusplus-1.2.1.dist-info → bplusplus-1.2.3.dist-info}/WHEEL +1 -1
- bplusplus/resnet/test.py +0 -473
- bplusplus/resnet/train.py +0 -329
- bplusplus/train_validate.py +0 -11
- bplusplus-1.2.1.dist-info/METADATA +0 -252
- bplusplus-1.2.1.dist-info/RECORD +0 -12
- {bplusplus-1.2.1.dist-info → bplusplus-1.2.3.dist-info}/LICENSE +0 -0
|
@@ -23,7 +23,7 @@ import sys
|
|
|
23
23
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
24
24
|
logger = logging.getLogger(__name__)
|
|
25
25
|
|
|
26
|
-
def
|
|
26
|
+
def test(species_list, test_set, yolo_weights, hierarchical_weights, output_dir="."):
|
|
27
27
|
"""
|
|
28
28
|
Run the two-stage classifier on a test set.
|
|
29
29
|
|
|
@@ -115,17 +115,17 @@ class HierarchicalInsectClassifier(nn.Module):
|
|
|
115
115
|
def get_taxonomy(species_list):
|
|
116
116
|
"""
|
|
117
117
|
Retrieves taxonomic information for a list of species from GBIF API.
|
|
118
|
-
Creates a hierarchical taxonomy dictionary with
|
|
118
|
+
Creates a hierarchical taxonomy dictionary with family, genus, and species relationships.
|
|
119
119
|
"""
|
|
120
120
|
taxonomy = {1: [], 2: {}, 3: {}}
|
|
121
|
-
|
|
122
|
-
|
|
121
|
+
species_to_genus = {}
|
|
122
|
+
genus_to_family = {}
|
|
123
123
|
|
|
124
124
|
logger.info(f"Building taxonomy from GBIF for {len(species_list)} species")
|
|
125
125
|
|
|
126
126
|
print("\nTaxonomy Results:")
|
|
127
127
|
print("-" * 80)
|
|
128
|
-
print(f"{'Species':<30} {'
|
|
128
|
+
print(f"{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
|
|
129
129
|
print("-" * 80)
|
|
130
130
|
|
|
131
131
|
for species_name in species_list:
|
|
@@ -136,23 +136,23 @@ def get_taxonomy(species_list):
|
|
|
136
136
|
|
|
137
137
|
if data.get('status') == 'ACCEPTED' or data.get('status') == 'SYNONYM':
|
|
138
138
|
family = data.get('family')
|
|
139
|
-
|
|
139
|
+
genus = data.get('genus')
|
|
140
140
|
|
|
141
|
-
if family and
|
|
141
|
+
if family and genus:
|
|
142
142
|
status = "OK"
|
|
143
143
|
|
|
144
|
-
print(f"{species_name:<30} {
|
|
144
|
+
print(f"{species_name:<30} {family:<20} {genus:<20} {status}")
|
|
145
145
|
|
|
146
|
-
|
|
147
|
-
|
|
146
|
+
species_to_genus[species_name] = genus
|
|
147
|
+
genus_to_family[genus] = family
|
|
148
148
|
|
|
149
|
-
if
|
|
150
|
-
taxonomy[1].append(
|
|
149
|
+
if family not in taxonomy[1]:
|
|
150
|
+
taxonomy[1].append(family)
|
|
151
151
|
|
|
152
|
-
taxonomy[2][
|
|
153
|
-
taxonomy[3][species_name] =
|
|
152
|
+
taxonomy[2][genus] = family
|
|
153
|
+
taxonomy[3][species_name] = genus
|
|
154
154
|
else:
|
|
155
|
-
error_msg = f"Species '{species_name}' found in GBIF but family and
|
|
155
|
+
error_msg = f"Species '{species_name}' found in GBIF but family and genus not found, could be spelling error in species, check GBIF"
|
|
156
156
|
logger.error(error_msg)
|
|
157
157
|
print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
|
|
158
158
|
print(f"Error: {error_msg}")
|
|
@@ -174,24 +174,24 @@ def get_taxonomy(species_list):
|
|
|
174
174
|
taxonomy[1] = sorted(list(set(taxonomy[1])))
|
|
175
175
|
print("-" * 80)
|
|
176
176
|
|
|
177
|
-
|
|
178
|
-
|
|
177
|
+
num_families = len(taxonomy[1])
|
|
178
|
+
num_genera = len(taxonomy[2])
|
|
179
179
|
num_species = len(taxonomy[3])
|
|
180
180
|
|
|
181
|
-
print("\nOrder indices:")
|
|
182
|
-
for i, order in enumerate(taxonomy[1]):
|
|
183
|
-
print(f" {i}: {order}")
|
|
184
|
-
|
|
185
181
|
print("\nFamily indices:")
|
|
186
|
-
for i, family in enumerate(taxonomy[
|
|
182
|
+
for i, family in enumerate(taxonomy[1]):
|
|
187
183
|
print(f" {i}: {family}")
|
|
188
184
|
|
|
185
|
+
print("\nGenus indices:")
|
|
186
|
+
for i, genus in enumerate(taxonomy[2].keys()):
|
|
187
|
+
print(f" {i}: {genus}")
|
|
188
|
+
|
|
189
189
|
print("\nSpecies indices:")
|
|
190
190
|
for i, species in enumerate(species_list):
|
|
191
191
|
print(f" {i}: {species}")
|
|
192
192
|
|
|
193
|
-
logger.info(f"Taxonomy built: {
|
|
194
|
-
return taxonomy,
|
|
193
|
+
logger.info(f"Taxonomy built: {num_families} families, {num_genera} genera, {num_species} species")
|
|
194
|
+
return taxonomy, species_to_genus, genus_to_family
|
|
195
195
|
|
|
196
196
|
def create_mappings(taxonomy):
|
|
197
197
|
"""Create index mappings from taxonomy"""
|
|
@@ -243,13 +243,20 @@ class TestTwoStage:
|
|
|
243
243
|
if "species_list" in checkpoint:
|
|
244
244
|
saved_species = checkpoint["species_list"]
|
|
245
245
|
print(f"Saved model was trained on: {', '.join(saved_species)}")
|
|
246
|
-
|
|
247
|
-
taxonomy
|
|
246
|
+
|
|
247
|
+
# Use saved taxonomy mappings if available
|
|
248
|
+
if "species_to_genus" in checkpoint and "genus_to_family" in checkpoint:
|
|
249
|
+
species_to_genus = checkpoint["species_to_genus"]
|
|
250
|
+
genus_to_family = checkpoint["genus_to_family"]
|
|
251
|
+
else:
|
|
252
|
+
# Fallback: fetch from GBIF but this may cause index mismatches
|
|
253
|
+
print("Warning: No taxonomy mappings in checkpoint, fetching from GBIF")
|
|
254
|
+
_, species_to_genus, genus_to_family = get_taxonomy(species_names)
|
|
248
255
|
else:
|
|
249
|
-
taxonomy,
|
|
256
|
+
taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
|
|
250
257
|
else:
|
|
251
258
|
state_dict = checkpoint
|
|
252
|
-
taxonomy,
|
|
259
|
+
taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
|
|
253
260
|
|
|
254
261
|
level_to_idx, idx_to_level = create_mappings(taxonomy)
|
|
255
262
|
|
|
@@ -259,8 +266,6 @@ class TestTwoStage:
|
|
|
259
266
|
if hasattr(taxonomy, "items"):
|
|
260
267
|
num_classes_per_level = [len(classes) if isinstance(classes, list) else len(classes.keys())
|
|
261
268
|
for level, classes in taxonomy.items()]
|
|
262
|
-
else:
|
|
263
|
-
num_classes_per_level = [4, 5, 9] # Example values, adjust as needed
|
|
264
269
|
|
|
265
270
|
print(f"Using model with class counts: {num_classes_per_level}")
|
|
266
271
|
|
|
@@ -287,8 +292,6 @@ class TestTwoStage:
|
|
|
287
292
|
self.classification_model.eval()
|
|
288
293
|
|
|
289
294
|
self.classification_transform = transforms.Compose([
|
|
290
|
-
transforms.Resize((768, 768)), # Fixed size for all validation images
|
|
291
|
-
transforms.CenterCrop(640),
|
|
292
295
|
transforms.ToTensor(),
|
|
293
296
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
294
297
|
])
|
|
@@ -296,8 +299,8 @@ class TestTwoStage:
|
|
|
296
299
|
print("Model successfully loaded")
|
|
297
300
|
print(f"Using species: {', '.join(species_names)}")
|
|
298
301
|
|
|
299
|
-
self.
|
|
300
|
-
self.
|
|
302
|
+
self.species_to_genus = species_to_genus
|
|
303
|
+
self.genus_to_family = genus_to_family
|
|
301
304
|
|
|
302
305
|
def get_frames(self, test_dir):
|
|
303
306
|
image_dir = os.path.join(test_dir, "images")
|
|
@@ -305,10 +308,10 @@ class TestTwoStage:
|
|
|
305
308
|
|
|
306
309
|
predicted_frames = []
|
|
307
310
|
predicted_family_frames = []
|
|
308
|
-
|
|
311
|
+
predicted_genus_frames = []
|
|
309
312
|
true_species_frames = []
|
|
310
313
|
true_family_frames = []
|
|
311
|
-
|
|
314
|
+
true_genus_frames = []
|
|
312
315
|
image_names = []
|
|
313
316
|
|
|
314
317
|
start_time = time.time() # Start timing
|
|
@@ -326,7 +329,7 @@ class TestTwoStage:
|
|
|
326
329
|
detections = results[0].boxes
|
|
327
330
|
predicted_frame = []
|
|
328
331
|
predicted_family_frame = []
|
|
329
|
-
|
|
332
|
+
predicted_genus_frame = []
|
|
330
333
|
|
|
331
334
|
if detections:
|
|
332
335
|
for box in detections:
|
|
@@ -346,13 +349,13 @@ class TestTwoStage:
|
|
|
346
349
|
outputs = self.classification_model(input_tensor)
|
|
347
350
|
|
|
348
351
|
# Get all taxonomic level predictions
|
|
349
|
-
|
|
350
|
-
|
|
352
|
+
family_output = outputs[0] # First output is family (level 1)
|
|
353
|
+
genus_output = outputs[1] # Second output is genus (level 2)
|
|
351
354
|
species_output = outputs[2] # Third output is species (level 3)
|
|
352
355
|
|
|
353
356
|
# Get prediction indices
|
|
354
|
-
order_idx = order_output.argmax(dim=1).item()
|
|
355
357
|
family_idx = family_output.argmax(dim=1).item()
|
|
358
|
+
genus_idx = genus_output.argmax(dim=1).item()
|
|
356
359
|
species_idx = species_output.argmax(dim=1).item()
|
|
357
360
|
|
|
358
361
|
img_height, img_width, _ = frame.shape
|
|
@@ -367,15 +370,15 @@ class TestTwoStage:
|
|
|
367
370
|
# Add predictions for each taxonomic level
|
|
368
371
|
predicted_frame.append([species_idx] + box_coords)
|
|
369
372
|
predicted_family_frame.append([family_idx] + box_coords)
|
|
370
|
-
|
|
373
|
+
predicted_genus_frame.append([genus_idx] + box_coords)
|
|
371
374
|
|
|
372
375
|
predicted_frames.append(predicted_frame if predicted_frame else [])
|
|
373
376
|
predicted_family_frames.append(predicted_family_frame if predicted_family_frame else [])
|
|
374
|
-
|
|
377
|
+
predicted_genus_frames.append(predicted_genus_frame if predicted_genus_frame else [])
|
|
375
378
|
|
|
376
379
|
true_species_frame = []
|
|
377
380
|
true_family_frame = []
|
|
378
|
-
|
|
381
|
+
true_genus_frame = []
|
|
379
382
|
|
|
380
383
|
if os.path.exists(label_path) and os.path.getsize(label_path) > 0:
|
|
381
384
|
with open(label_path, 'r') as f:
|
|
@@ -389,22 +392,22 @@ class TestTwoStage:
|
|
|
389
392
|
if species_idx < len(self.species_names):
|
|
390
393
|
species_name = self.species_names[species_idx]
|
|
391
394
|
|
|
392
|
-
if species_name in self.
|
|
393
|
-
|
|
394
|
-
# Get the index of the
|
|
395
|
-
if 2 in self.level_to_idx and
|
|
396
|
-
|
|
397
|
-
|
|
395
|
+
if species_name in self.species_to_genus:
|
|
396
|
+
genus_name = self.species_to_genus[species_name]
|
|
397
|
+
# Get the index of the genus in the level_to_idx mapping
|
|
398
|
+
if 2 in self.level_to_idx and genus_name in self.level_to_idx[2]:
|
|
399
|
+
genus_idx = self.level_to_idx[2][genus_name]
|
|
400
|
+
true_genus_frame.append([genus_idx] + box_coords)
|
|
398
401
|
|
|
399
|
-
if
|
|
400
|
-
|
|
401
|
-
if 1 in self.level_to_idx and
|
|
402
|
-
|
|
403
|
-
|
|
402
|
+
if genus_name in self.genus_to_family:
|
|
403
|
+
family_name = self.genus_to_family[genus_name]
|
|
404
|
+
if 1 in self.level_to_idx and family_name in self.level_to_idx[1]:
|
|
405
|
+
family_idx = self.level_to_idx[1][family_name]
|
|
406
|
+
true_family_frame.append([family_idx] + box_coords)
|
|
404
407
|
|
|
405
408
|
true_species_frames.append(true_species_frame if true_species_frame else [])
|
|
406
409
|
true_family_frames.append(true_family_frame if true_family_frame else [])
|
|
407
|
-
|
|
410
|
+
true_genus_frames.append(true_genus_frame if true_genus_frame else [])
|
|
408
411
|
|
|
409
412
|
end_time = time.time() # End timing
|
|
410
413
|
|
|
@@ -416,42 +419,42 @@ class TestTwoStage:
|
|
|
416
419
|
writer.writerow([
|
|
417
420
|
"Image Name",
|
|
418
421
|
"True Species Detections",
|
|
422
|
+
"True Genus Detections",
|
|
419
423
|
"True Family Detections",
|
|
420
|
-
"True Order Detections",
|
|
421
424
|
"Species Detections",
|
|
422
|
-
"
|
|
423
|
-
"
|
|
425
|
+
"Genus Detections",
|
|
426
|
+
"Family Detections"
|
|
424
427
|
])
|
|
425
428
|
|
|
426
|
-
for image_name, true_species,
|
|
429
|
+
for image_name, true_species, true_genus, true_family, species_pred, genus_pred, family_pred in zip(
|
|
427
430
|
image_names,
|
|
428
431
|
true_species_frames,
|
|
432
|
+
true_genus_frames,
|
|
429
433
|
true_family_frames,
|
|
430
|
-
true_order_frames,
|
|
431
434
|
predicted_frames,
|
|
432
|
-
|
|
433
|
-
|
|
435
|
+
predicted_genus_frames,
|
|
436
|
+
predicted_family_frames
|
|
434
437
|
):
|
|
435
438
|
writer.writerow([
|
|
436
439
|
image_name,
|
|
437
440
|
true_species,
|
|
441
|
+
true_genus,
|
|
438
442
|
true_family,
|
|
439
|
-
true_order,
|
|
440
443
|
species_pred,
|
|
441
|
-
|
|
442
|
-
|
|
444
|
+
genus_pred,
|
|
445
|
+
family_pred
|
|
443
446
|
])
|
|
444
447
|
|
|
445
448
|
print(f"Results saved to {output_file}")
|
|
446
|
-
return predicted_frames, true_species_frames, end_time - start_time, predicted_family_frames,
|
|
449
|
+
return predicted_frames, true_species_frames, end_time - start_time, predicted_genus_frames, predicted_family_frames, true_genus_frames, true_family_frames
|
|
447
450
|
|
|
448
451
|
def run(self, test_dir):
|
|
449
452
|
results = self.get_frames(test_dir)
|
|
450
453
|
predicted_frames, true_species_frames, total_time = results[0], results[1], results[2]
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
454
|
+
predicted_genus_frames = results[3]
|
|
455
|
+
predicted_family_frames = results[4]
|
|
456
|
+
true_genus_frames = results[5]
|
|
457
|
+
true_family_frames = results[6]
|
|
455
458
|
|
|
456
459
|
num_frames = len(os.listdir(os.path.join(test_dir, 'images')))
|
|
457
460
|
avg_time_per_frame = total_time / num_frames
|
|
@@ -461,29 +464,37 @@ class TestTwoStage:
|
|
|
461
464
|
|
|
462
465
|
self.calculate_metrics(
|
|
463
466
|
predicted_frames, true_species_frames,
|
|
464
|
-
|
|
465
|
-
|
|
467
|
+
predicted_genus_frames, true_genus_frames,
|
|
468
|
+
predicted_family_frames, true_family_frames
|
|
466
469
|
)
|
|
467
470
|
|
|
468
471
|
def calculate_metrics(self, predicted_species_frames, true_species_frames,
|
|
469
|
-
|
|
470
|
-
|
|
472
|
+
predicted_genus_frames, true_genus_frames,
|
|
473
|
+
predicted_family_frames, true_family_frames):
|
|
471
474
|
"""Calculate metrics at all taxonomic levels"""
|
|
472
|
-
# Get list of species, families and
|
|
475
|
+
# Get list of species, families and genera using the same order as model training
|
|
473
476
|
species_list = self.species_names
|
|
474
|
-
|
|
475
|
-
|
|
477
|
+
|
|
478
|
+
# Use the index mappings from the model to ensure consistency
|
|
479
|
+
if 1 in self.idx_to_level and 2 in self.idx_to_level:
|
|
480
|
+
family_list = [self.idx_to_level[1][i] for i in sorted(self.idx_to_level[1].keys())]
|
|
481
|
+
genus_list = [self.idx_to_level[2][i] for i in sorted(self.idx_to_level[2].keys())]
|
|
482
|
+
else:
|
|
483
|
+
# Fallback to sorted lists (may cause issues)
|
|
484
|
+
print("Warning: Using fallback sorted lists for taxonomy - this may cause index mismatches")
|
|
485
|
+
genus_list = sorted(list(set(self.species_to_genus.values())))
|
|
486
|
+
family_list = sorted(list(set(self.genus_to_family.values())))
|
|
476
487
|
|
|
477
488
|
# Print the index mappings we're using for evaluation
|
|
478
489
|
print("\nUsing the following index mappings for evaluation:")
|
|
479
|
-
print("\nOrder indices:")
|
|
480
|
-
for i, order in enumerate(order_list):
|
|
481
|
-
print(f" {i}: {order}")
|
|
482
|
-
|
|
483
490
|
print("\nFamily indices:")
|
|
484
491
|
for i, family in enumerate(family_list):
|
|
485
492
|
print(f" {i}: {family}")
|
|
486
493
|
|
|
494
|
+
print("\nGenus indices:")
|
|
495
|
+
for i, genus in enumerate(genus_list):
|
|
496
|
+
print(f" {i}: {genus}")
|
|
497
|
+
|
|
487
498
|
print("\nSpecies indices:")
|
|
488
499
|
for i, species in enumerate(species_list):
|
|
489
500
|
print(f" {i}: {species}")
|
|
@@ -491,11 +502,11 @@ class TestTwoStage:
|
|
|
491
502
|
# Dictionary to track prediction category counts for debugging
|
|
492
503
|
prediction_counts = {
|
|
493
504
|
"true_species_boxes": sum(len(frame) for frame in true_species_frames),
|
|
505
|
+
"true_genus_boxes": sum(len(frame) for frame in true_genus_frames),
|
|
494
506
|
"true_family_boxes": sum(len(frame) for frame in true_family_frames),
|
|
495
|
-
"true_order_boxes": sum(len(frame) for frame in true_order_frames),
|
|
496
507
|
"predicted_species": sum(len(frame) for frame in predicted_species_frames),
|
|
497
|
-
"
|
|
498
|
-
"
|
|
508
|
+
"predicted_genus": sum(len(frame) for frame in predicted_genus_frames),
|
|
509
|
+
"predicted_family": sum(len(frame) for frame in predicted_family_frames)
|
|
499
510
|
}
|
|
500
511
|
|
|
501
512
|
print(f"Prediction counts: {prediction_counts}")
|
|
@@ -504,11 +515,11 @@ class TestTwoStage:
|
|
|
504
515
|
print("\n=== Species-level Metrics ===")
|
|
505
516
|
self.get_metrics(predicted_species_frames, true_species_frames, species_list)
|
|
506
517
|
|
|
518
|
+
print("\n=== Genus-level Metrics ===")
|
|
519
|
+
self.get_metrics(predicted_genus_frames, true_genus_frames, genus_list)
|
|
520
|
+
|
|
507
521
|
print("\n=== Family-level Metrics ===")
|
|
508
522
|
self.get_metrics(predicted_family_frames, true_family_frames, family_list)
|
|
509
|
-
|
|
510
|
-
print("\n=== Order-level Metrics ===")
|
|
511
|
-
self.get_metrics(predicted_order_frames, true_order_frames, order_list)
|
|
512
523
|
|
|
513
524
|
def get_metrics(self, predicted_frames, true_frames, labels):
|
|
514
525
|
"""Calculate metrics for object detection predictions"""
|
|
@@ -667,4 +678,4 @@ if __name__ == "__main__":
|
|
|
667
678
|
hierarchical_model_path = "/mnt/nvme0n1p1/mit/two-stage-detection/hierarchical/hierarchical-weights.pth"
|
|
668
679
|
output_directory = "./output"
|
|
669
680
|
|
|
670
|
-
|
|
681
|
+
test(species_names, test_directory, yolo_model_path, hierarchical_model_path, output_directory)
|
bplusplus/tracker.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
import numpy as np
import uuid
from scipy.optimize import linear_sum_assignment
from collections import deque


class BoundingBox:
    """Axis-aligned bounding box with tracking metadata.

    Coordinates are (x, y) for the top-left corner plus width/height
    (presumably pixels — units are whatever the detector emits).
    `track_id` is assigned/maintained by InsectTracker.
    """

    def __init__(self, x, y, width, height, frame_id, track_id=None):
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        # Cached area; used by the area-similarity term of the matching cost.
        self.area = width * height
        self.frame_id = frame_id
        self.track_id = track_id

    def center(self):
        """Return the (cx, cy) center point of the box."""
        return (self.x + self.width / 2, self.y + self.height / 2)

    @classmethod
    def from_xyxy(cls, x1, y1, x2, y2, frame_id, track_id=None):
        """Create a BoundingBox from corner coordinates (x1, y1, x2, y2)."""
        return cls(x1, y1, x2 - x1, y2 - y1, frame_id, track_id)


class InsectTracker:
    """Frame-to-frame multi-object tracker using Hungarian assignment.

    Detections in consecutive frames are matched by minimizing a cost that
    combines the normalized center distance and an area-similarity term
    (equations 1-4 referenced in the method comments). Tracks that vanish
    are kept in a "lost" pool for up to `track_memory_frames` frames so a
    briefly occluded insect can re-acquire its original track ID.

    Fix vs. the previous revision: `frames_lost` was both initialized to 1
    on loss *and* incremented by `_age_lost_tracks()` (twice per empty
    frame), shortening the effective memory window. `_age_lost_tracks()`
    is now the single place that ages lost tracks.
    """

    def __init__(self, image_height, image_width, max_frames=30, w_dist=0.7,
                 w_area=0.3, cost_threshold=0.8, track_memory_frames=None,
                 debug=False):
        """
        Args:
            image_height, image_width: frame dimensions, used to normalize
                center distances by the image diagonal.
            max_frames: length of the per-frame tracking history window.
            w_dist, w_area: weights of the distance and area cost terms.
            cost_threshold: matches with cost >= this are rejected.
            track_memory_frames: frames a lost track is remembered before
                being dropped; defaults to max_frames.
            debug: print verbose matching diagnostics.
        """
        self.image_height = image_height
        self.image_width = image_width
        # Image diagonal: an upper bound on any center distance, so the
        # normalized distance term lies in [0, 1].
        self.max_dist = np.sqrt(image_height**2 + image_width**2)
        self.max_frames = max_frames
        self.w_dist = w_dist
        self.w_area = w_area
        self.cost_threshold = cost_threshold
        self.debug = debug

        # If track_memory_frames not specified, use max_frames (full history window)
        self.track_memory_frames = track_memory_frames if track_memory_frames is not None else max_frames
        if self.debug:
            print(f"DEBUG: Tracker initialized with max_frames={max_frames}, track_memory_frames={self.track_memory_frames}")

        self.tracking_history = deque(maxlen=max_frames)
        self.current_tracks = []
        self.lost_tracks = {}  # track_id -> {'box': BoundingBox, 'frames_lost': int}

    def _generate_track_id(self):
        """Generate a unique UUID string for a new track."""
        return str(uuid.uuid4())

    def calculate_cost(self, box1, box2):
        """Calculate cost between two bounding boxes as per equation (4).

        cost = w_dist * (center_distance / image_diagonal)
             + w_area * (1 - min_area / max_area)

        Lower is better; 0 means identical center and area.
        """
        cx1, cy1 = box1.center()
        cx2, cy2 = box2.center()

        # Euclidean center distance (equation 1), normalized by the image
        # diagonal (equation 2).
        norm_dist = np.sqrt((cx2 - cx1)**2 + (cy2 - cy1)**2) / self.max_dist

        # Area similarity in (0, 1]; 1 when areas are equal (equation 3).
        min_area = min(box1.area, box2.area)
        max_area = max(box1.area, box2.area)
        area_cost = min_area / max_area if max_area > 0 else 1.0

        # Final cost (equation 4).
        return (norm_dist * self.w_dist) + ((1 - area_cost) * self.w_area)

    def build_cost_matrix(self, prev_boxes, curr_boxes):
        """Build a square cost matrix for the Hungarian algorithm.

        The matrix is padded to max(n_prev, n_curr) with a large sentinel
        cost (999.0) so the assignment is always feasible; the caller
        filters out assignments involving dummy rows/columns.

        Returns:
            (cost_matrix, n_prev, n_curr)
        """
        n_prev = len(prev_boxes)
        n_curr = len(curr_boxes)
        n = max(n_prev, n_curr)

        cost_matrix = np.full((n, n), 999.0)
        for i in range(n_prev):
            for j in range(n_curr):
                cost_matrix[i, j] = self.calculate_cost(prev_boxes[i], curr_boxes[j])

        return cost_matrix, n_prev, n_curr

    def _start_new_tracks(self, new_boxes, reason):
        """Assign fresh IDs to every box and make them the current tracks."""
        track_ids = []
        for box in new_boxes:
            box.track_id = self._generate_track_id()
            track_ids.append(box.track_id)
            if self.debug:
                print(f"DEBUG: {reason} - Assigned track ID {box.track_id} to new detection")
        self.current_tracks = new_boxes
        self.tracking_history.append(new_boxes)
        return track_ids

    def update(self, new_detections, frame_id):
        """
        Update tracking with new detections from YOLO.

        Args:
            new_detections: List of YOLO detection boxes (x1, y1, x2, y2 format;
                extra elements beyond the first four are ignored)
            frame_id: Current frame number

        Returns:
            List of track IDs corresponding one-to-one to each detection
        """
        # Handle empty detection list (no detections in this frame).
        if not new_detections:
            if self.debug:
                print(f"DEBUG: Frame {frame_id} has no detections")
            # Move all current tracks to lost tracks. frames_lost starts at 0:
            # _age_lost_tracks() below is the single place that increments it
            # (the previous code also incremented here, double-aging lost
            # tracks on every empty frame).
            for track in self.current_tracks:
                if track.track_id not in self.lost_tracks:
                    self.lost_tracks[track.track_id] = {
                        'box': track,
                        'frames_lost': 0
                    }
                    if self.debug:
                        print(f"DEBUG: Moved track {track.track_id} to lost tracks")

            # Age lost tracks and remove old ones.
            self._age_lost_tracks()

            self.current_tracks = []
            self.tracking_history.append([])
            return []

        # Convert YOLO detections to BoundingBox objects.
        new_boxes = []
        for detection in new_detections:
            x1, y1, x2, y2 = detection[:4]
            new_boxes.append(BoundingBox.from_xyxy(x1, y1, x2, y2, frame_id))

        # First frame / nothing to match against: everything is a new track.
        if not self.current_tracks and not self.lost_tracks:
            return self._start_new_tracks(new_boxes, "FIRST FRAME")

        # Candidate set for matching: currently visible tracks plus the
        # recently lost ones (so occluded insects can be re-acquired).
        all_previous_tracks = self.current_tracks.copy()
        for track_id, lost_info in self.lost_tracks.items():
            lost_info['box'].track_id = track_id  # ensure track_id is preserved
            all_previous_tracks.append(lost_info['box'])

        if not all_previous_tracks:
            return self._start_new_tracks(new_boxes, "No previous tracks")

        # Build cost matrix (including lost tracks) and solve the assignment.
        cost_matrix, n_prev, n_curr = self.build_cost_matrix(all_previous_tracks, new_boxes)
        row_indices, col_indices = linear_sum_assignment(cost_matrix)

        assigned_curr_indices = set()
        track_ids = [None] * len(new_boxes)
        recovered_tracks = set()  # IDs re-acquired from the lost pool this frame

        if self.debug:
            print(f"DEBUG: Hungarian assignment - rows: {row_indices}, cols: {col_indices}")
            print(f"DEBUG: Cost threshold: {self.cost_threshold}")
            print(f"DEBUG: Current tracks: {len(self.current_tracks)}, Lost tracks: {len(self.lost_tracks)}")

        for i, j in zip(row_indices, col_indices):
            # Only consider valid assignments (not dummy rows/columns).
            if i >= n_prev or j >= n_curr:
                continue
            cost = cost_matrix[i, j]
            if self.debug:
                print(f"DEBUG: Checking assignment {i}->{j}, cost: {cost:.3f}")
            if cost >= self.cost_threshold:
                if self.debug:
                    print(f"DEBUG: Cost {cost:.3f} above threshold {self.cost_threshold}, not assigning")
                continue

            # Carry the previous box's track ID over to the matched detection.
            prev_track_id = all_previous_tracks[i].track_id
            new_boxes[j].track_id = prev_track_id
            track_ids[j] = prev_track_id
            assigned_curr_indices.add(j)

            if prev_track_id in self.lost_tracks:
                recovered_tracks.add(prev_track_id)
                if self.debug:
                    print(f"DEBUG: RECOVERED lost track ID {prev_track_id} for detection {j} (was lost for {self.lost_tracks[prev_track_id]['frames_lost']} frames)")
            elif self.debug:
                print(f"DEBUG: Continued track ID {prev_track_id} for detection {j}")

        # Recovered tracks leave the lost pool.
        for track_id in recovered_tracks:
            del self.lost_tracks[track_id]

        # Assign new track IDs to unassigned current boxes (new insects).
        for j in range(n_curr):
            if j not in assigned_curr_indices:
                new_boxes[j].track_id = self._generate_track_id()
                track_ids[j] = new_boxes[j].track_id
                if self.debug:
                    print(f"DEBUG: Assigned NEW track ID {new_boxes[j].track_id} to detection {j}")

        # Tracks visible last frame that matched nothing become lost.
        # frames_lost starts at 0; _age_lost_tracks() increments it (see above).
        matched_track_ids = {track_ids[j] for j in assigned_curr_indices if track_ids[j] is not None}
        for track in self.current_tracks:
            if track.track_id not in matched_track_ids and track.track_id not in recovered_tracks:
                if track.track_id not in self.lost_tracks:
                    self.lost_tracks[track.track_id] = {
                        'box': track,
                        'frames_lost': 0
                    }
                    if self.debug:
                        print(f"DEBUG: Track {track.track_id} disappeared, moved to lost tracks")

        # Age lost tracks and remove old ones.
        self._age_lost_tracks()

        self.current_tracks = new_boxes
        self.tracking_history.append(new_boxes)
        return track_ids

    def _age_lost_tracks(self):
        """Age lost tracks and remove those that have been lost too long.

        This is the only place `frames_lost` is incremented; a track lost
        this frame therefore ends the frame with frames_lost == 1 and is
        dropped once frames_lost exceeds `track_memory_frames`.
        """
        tracks_to_remove = []
        for track_id, lost_info in self.lost_tracks.items():
            lost_info['frames_lost'] += 1
            if lost_info['frames_lost'] > self.track_memory_frames:
                tracks_to_remove.append(track_id)
                if self.debug:
                    print(f"DEBUG: Permanently removing track {track_id} (lost for {lost_info['frames_lost']} frames)")

        for track_id in tracks_to_remove:
            del self.lost_tracks[track_id]

    def get_tracking_stats(self):
        """Get current tracking statistics for debugging/monitoring."""
        return {
            'active_tracks': len(self.current_tracks),
            'lost_tracks': len(self.lost_tracks),
            'active_track_ids': [track.track_id for track in self.current_tracks],
            'lost_track_ids': list(self.lost_tracks.keys()),
            'total_history_frames': len(self.tracking_history)
        }