bplusplus 1.2.1-py3-none-any.whl → 1.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bplusplus might be problematic; see the advisory on the registry page for more details.

@@ -23,7 +23,7 @@ import sys
23
23
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
24
  logger = logging.getLogger(__name__)
25
25
 
26
- def test_multitask(species_list, test_set, yolo_weights, hierarchical_weights, output_dir="."):
26
+ def test(species_list, test_set, yolo_weights, hierarchical_weights, output_dir="."):
27
27
  """
28
28
  Run the two-stage classifier on a test set.
29
29
 
@@ -115,17 +115,17 @@ class HierarchicalInsectClassifier(nn.Module):
115
115
  def get_taxonomy(species_list):
116
116
  """
117
117
  Retrieves taxonomic information for a list of species from GBIF API.
118
- Creates a hierarchical taxonomy dictionary with order, family, and species relationships.
118
+ Creates a hierarchical taxonomy dictionary with family, genus, and species relationships.
119
119
  """
120
120
  taxonomy = {1: [], 2: {}, 3: {}}
121
- species_to_family = {}
122
- family_to_order = {}
121
+ species_to_genus = {}
122
+ genus_to_family = {}
123
123
 
124
124
  logger.info(f"Building taxonomy from GBIF for {len(species_list)} species")
125
125
 
126
126
  print("\nTaxonomy Results:")
127
127
  print("-" * 80)
128
- print(f"{'Species':<30} {'Order':<20} {'Family':<20} {'Status'}")
128
+ print(f"{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
129
129
  print("-" * 80)
130
130
 
131
131
  for species_name in species_list:
@@ -136,23 +136,23 @@ def get_taxonomy(species_list):
136
136
 
137
137
  if data.get('status') == 'ACCEPTED' or data.get('status') == 'SYNONYM':
138
138
  family = data.get('family')
139
- order = data.get('order')
139
+ genus = data.get('genus')
140
140
 
141
- if family and order:
141
+ if family and genus:
142
142
  status = "OK"
143
143
 
144
- print(f"{species_name:<30} {order:<20} {family:<20} {status}")
144
+ print(f"{species_name:<30} {family:<20} {genus:<20} {status}")
145
145
 
146
- species_to_family[species_name] = family
147
- family_to_order[family] = order
146
+ species_to_genus[species_name] = genus
147
+ genus_to_family[genus] = family
148
148
 
149
- if order not in taxonomy[1]:
150
- taxonomy[1].append(order)
149
+ if family not in taxonomy[1]:
150
+ taxonomy[1].append(family)
151
151
 
152
- taxonomy[2][family] = order
153
- taxonomy[3][species_name] = family
152
+ taxonomy[2][genus] = family
153
+ taxonomy[3][species_name] = genus
154
154
  else:
155
- error_msg = f"Species '{species_name}' found in GBIF but family and order not found, could be spelling error in species, check GBIF"
155
+ error_msg = f"Species '{species_name}' found in GBIF but family and genus not found, could be spelling error in species, check GBIF"
156
156
  logger.error(error_msg)
157
157
  print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
158
158
  print(f"Error: {error_msg}")
@@ -174,24 +174,24 @@ def get_taxonomy(species_list):
174
174
  taxonomy[1] = sorted(list(set(taxonomy[1])))
175
175
  print("-" * 80)
176
176
 
177
- num_orders = len(taxonomy[1])
178
- num_families = len(taxonomy[2])
177
+ num_families = len(taxonomy[1])
178
+ num_genera = len(taxonomy[2])
179
179
  num_species = len(taxonomy[3])
180
180
 
181
- print("\nOrder indices:")
182
- for i, order in enumerate(taxonomy[1]):
183
- print(f" {i}: {order}")
184
-
185
181
  print("\nFamily indices:")
186
- for i, family in enumerate(taxonomy[2].keys()):
182
+ for i, family in enumerate(taxonomy[1]):
187
183
  print(f" {i}: {family}")
188
184
 
185
+ print("\nGenus indices:")
186
+ for i, genus in enumerate(taxonomy[2].keys()):
187
+ print(f" {i}: {genus}")
188
+
189
189
  print("\nSpecies indices:")
190
190
  for i, species in enumerate(species_list):
191
191
  print(f" {i}: {species}")
192
192
 
193
- logger.info(f"Taxonomy built: {num_orders} orders, {num_families} families, {num_species} species")
194
- return taxonomy, species_to_family, family_to_order
193
+ logger.info(f"Taxonomy built: {num_families} families, {num_genera} genera, {num_species} species")
194
+ return taxonomy, species_to_genus, genus_to_family
195
195
 
196
196
  def create_mappings(taxonomy):
197
197
  """Create index mappings from taxonomy"""
@@ -243,13 +243,20 @@ class TestTwoStage:
243
243
  if "species_list" in checkpoint:
244
244
  saved_species = checkpoint["species_list"]
245
245
  print(f"Saved model was trained on: {', '.join(saved_species)}")
246
-
247
- taxonomy, species_to_family, family_to_order = get_taxonomy(species_names)
246
+
247
+ # Use saved taxonomy mappings if available
248
+ if "species_to_genus" in checkpoint and "genus_to_family" in checkpoint:
249
+ species_to_genus = checkpoint["species_to_genus"]
250
+ genus_to_family = checkpoint["genus_to_family"]
251
+ else:
252
+ # Fallback: fetch from GBIF but this may cause index mismatches
253
+ print("Warning: No taxonomy mappings in checkpoint, fetching from GBIF")
254
+ _, species_to_genus, genus_to_family = get_taxonomy(species_names)
248
255
  else:
249
- taxonomy, species_to_family, family_to_order = get_taxonomy(species_names)
256
+ taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
250
257
  else:
251
258
  state_dict = checkpoint
252
- taxonomy, species_to_family, family_to_order = get_taxonomy(species_names)
259
+ taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
253
260
 
254
261
  level_to_idx, idx_to_level = create_mappings(taxonomy)
255
262
 
@@ -259,8 +266,6 @@ class TestTwoStage:
259
266
  if hasattr(taxonomy, "items"):
260
267
  num_classes_per_level = [len(classes) if isinstance(classes, list) else len(classes.keys())
261
268
  for level, classes in taxonomy.items()]
262
- else:
263
- num_classes_per_level = [4, 5, 9] # Example values, adjust as needed
264
269
 
265
270
  print(f"Using model with class counts: {num_classes_per_level}")
266
271
 
@@ -287,8 +292,6 @@ class TestTwoStage:
287
292
  self.classification_model.eval()
288
293
 
289
294
  self.classification_transform = transforms.Compose([
290
- transforms.Resize((768, 768)), # Fixed size for all validation images
291
- transforms.CenterCrop(640),
292
295
  transforms.ToTensor(),
293
296
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
294
297
  ])
@@ -296,8 +299,8 @@ class TestTwoStage:
296
299
  print("Model successfully loaded")
297
300
  print(f"Using species: {', '.join(species_names)}")
298
301
 
299
- self.species_to_family = species_to_family
300
- self.family_to_order = family_to_order
302
+ self.species_to_genus = species_to_genus
303
+ self.genus_to_family = genus_to_family
301
304
 
302
305
  def get_frames(self, test_dir):
303
306
  image_dir = os.path.join(test_dir, "images")
@@ -305,10 +308,10 @@ class TestTwoStage:
305
308
 
306
309
  predicted_frames = []
307
310
  predicted_family_frames = []
308
- predicted_order_frames = []
311
+ predicted_genus_frames = []
309
312
  true_species_frames = []
310
313
  true_family_frames = []
311
- true_order_frames = []
314
+ true_genus_frames = []
312
315
  image_names = []
313
316
 
314
317
  start_time = time.time() # Start timing
@@ -326,7 +329,7 @@ class TestTwoStage:
326
329
  detections = results[0].boxes
327
330
  predicted_frame = []
328
331
  predicted_family_frame = []
329
- predicted_order_frame = []
332
+ predicted_genus_frame = []
330
333
 
331
334
  if detections:
332
335
  for box in detections:
@@ -346,13 +349,13 @@ class TestTwoStage:
346
349
  outputs = self.classification_model(input_tensor)
347
350
 
348
351
  # Get all taxonomic level predictions
349
- order_output = outputs[0] # First output is order (level 1)
350
- family_output = outputs[1] # Second output is family (level 2)
352
+ family_output = outputs[0] # First output is family (level 1)
353
+ genus_output = outputs[1] # Second output is genus (level 2)
351
354
  species_output = outputs[2] # Third output is species (level 3)
352
355
 
353
356
  # Get prediction indices
354
- order_idx = order_output.argmax(dim=1).item()
355
357
  family_idx = family_output.argmax(dim=1).item()
358
+ genus_idx = genus_output.argmax(dim=1).item()
356
359
  species_idx = species_output.argmax(dim=1).item()
357
360
 
358
361
  img_height, img_width, _ = frame.shape
@@ -367,15 +370,15 @@ class TestTwoStage:
367
370
  # Add predictions for each taxonomic level
368
371
  predicted_frame.append([species_idx] + box_coords)
369
372
  predicted_family_frame.append([family_idx] + box_coords)
370
- predicted_order_frame.append([order_idx] + box_coords)
373
+ predicted_genus_frame.append([genus_idx] + box_coords)
371
374
 
372
375
  predicted_frames.append(predicted_frame if predicted_frame else [])
373
376
  predicted_family_frames.append(predicted_family_frame if predicted_family_frame else [])
374
- predicted_order_frames.append(predicted_order_frame if predicted_order_frame else [])
377
+ predicted_genus_frames.append(predicted_genus_frame if predicted_genus_frame else [])
375
378
 
376
379
  true_species_frame = []
377
380
  true_family_frame = []
378
- true_order_frame = []
381
+ true_genus_frame = []
379
382
 
380
383
  if os.path.exists(label_path) and os.path.getsize(label_path) > 0:
381
384
  with open(label_path, 'r') as f:
@@ -389,22 +392,22 @@ class TestTwoStage:
389
392
  if species_idx < len(self.species_names):
390
393
  species_name = self.species_names[species_idx]
391
394
 
392
- if species_name in self.species_to_family:
393
- family_name = self.species_to_family[species_name]
394
- # Get the index of the family in the level_to_idx mapping
395
- if 2 in self.level_to_idx and family_name in self.level_to_idx[2]:
396
- family_idx = self.level_to_idx[2][family_name]
397
- true_family_frame.append([family_idx] + box_coords)
395
+ if species_name in self.species_to_genus:
396
+ genus_name = self.species_to_genus[species_name]
397
+ # Get the index of the genus in the level_to_idx mapping
398
+ if 2 in self.level_to_idx and genus_name in self.level_to_idx[2]:
399
+ genus_idx = self.level_to_idx[2][genus_name]
400
+ true_genus_frame.append([genus_idx] + box_coords)
398
401
 
399
- if family_name in self.family_to_order:
400
- order_name = self.family_to_order[family_name]
401
- if 1 in self.level_to_idx and order_name in self.level_to_idx[1]:
402
- order_idx = self.level_to_idx[1][order_name]
403
- true_order_frame.append([order_idx] + box_coords)
402
+ if genus_name in self.genus_to_family:
403
+ family_name = self.genus_to_family[genus_name]
404
+ if 1 in self.level_to_idx and family_name in self.level_to_idx[1]:
405
+ family_idx = self.level_to_idx[1][family_name]
406
+ true_family_frame.append([family_idx] + box_coords)
404
407
 
405
408
  true_species_frames.append(true_species_frame if true_species_frame else [])
406
409
  true_family_frames.append(true_family_frame if true_family_frame else [])
407
- true_order_frames.append(true_order_frame if true_order_frame else [])
410
+ true_genus_frames.append(true_genus_frame if true_genus_frame else [])
408
411
 
409
412
  end_time = time.time() # End timing
410
413
 
@@ -416,42 +419,42 @@ class TestTwoStage:
416
419
  writer.writerow([
417
420
  "Image Name",
418
421
  "True Species Detections",
422
+ "True Genus Detections",
419
423
  "True Family Detections",
420
- "True Order Detections",
421
424
  "Species Detections",
422
- "Family Detections",
423
- "Order Detections"
425
+ "Genus Detections",
426
+ "Family Detections"
424
427
  ])
425
428
 
426
- for image_name, true_species, true_family, true_order, species_pred, family_pred, order_pred in zip(
429
+ for image_name, true_species, true_genus, true_family, species_pred, genus_pred, family_pred in zip(
427
430
  image_names,
428
431
  true_species_frames,
432
+ true_genus_frames,
429
433
  true_family_frames,
430
- true_order_frames,
431
434
  predicted_frames,
432
- predicted_family_frames,
433
- predicted_order_frames
435
+ predicted_genus_frames,
436
+ predicted_family_frames
434
437
  ):
435
438
  writer.writerow([
436
439
  image_name,
437
440
  true_species,
441
+ true_genus,
438
442
  true_family,
439
- true_order,
440
443
  species_pred,
441
- family_pred,
442
- order_pred
444
+ genus_pred,
445
+ family_pred
443
446
  ])
444
447
 
445
448
  print(f"Results saved to {output_file}")
446
- return predicted_frames, true_species_frames, end_time - start_time, predicted_family_frames, predicted_order_frames, true_family_frames, true_order_frames
449
+ return predicted_frames, true_species_frames, end_time - start_time, predicted_genus_frames, predicted_family_frames, true_genus_frames, true_family_frames
447
450
 
448
451
  def run(self, test_dir):
449
452
  results = self.get_frames(test_dir)
450
453
  predicted_frames, true_species_frames, total_time = results[0], results[1], results[2]
451
- predicted_family_frames = results[3]
452
- predicted_order_frames = results[4]
453
- true_family_frames = results[5]
454
- true_order_frames = results[6]
454
+ predicted_genus_frames = results[3]
455
+ predicted_family_frames = results[4]
456
+ true_genus_frames = results[5]
457
+ true_family_frames = results[6]
455
458
 
456
459
  num_frames = len(os.listdir(os.path.join(test_dir, 'images')))
457
460
  avg_time_per_frame = total_time / num_frames
@@ -461,29 +464,37 @@ class TestTwoStage:
461
464
 
462
465
  self.calculate_metrics(
463
466
  predicted_frames, true_species_frames,
464
- predicted_family_frames, true_family_frames,
465
- predicted_order_frames, true_order_frames
467
+ predicted_genus_frames, true_genus_frames,
468
+ predicted_family_frames, true_family_frames
466
469
  )
467
470
 
468
471
  def calculate_metrics(self, predicted_species_frames, true_species_frames,
469
- predicted_family_frames, true_family_frames,
470
- predicted_order_frames, true_order_frames):
472
+ predicted_genus_frames, true_genus_frames,
473
+ predicted_family_frames, true_family_frames):
471
474
  """Calculate metrics at all taxonomic levels"""
472
- # Get list of species, families and orders
475
+ # Get list of species, families and genera using the same order as model training
473
476
  species_list = self.species_names
474
- family_list = sorted(list(set(self.species_to_family.values())))
475
- order_list = sorted(list(set(self.family_to_order.values())))
477
+
478
+ # Use the index mappings from the model to ensure consistency
479
+ if 1 in self.idx_to_level and 2 in self.idx_to_level:
480
+ family_list = [self.idx_to_level[1][i] for i in sorted(self.idx_to_level[1].keys())]
481
+ genus_list = [self.idx_to_level[2][i] for i in sorted(self.idx_to_level[2].keys())]
482
+ else:
483
+ # Fallback to sorted lists (may cause issues)
484
+ print("Warning: Using fallback sorted lists for taxonomy - this may cause index mismatches")
485
+ genus_list = sorted(list(set(self.species_to_genus.values())))
486
+ family_list = sorted(list(set(self.genus_to_family.values())))
476
487
 
477
488
  # Print the index mappings we're using for evaluation
478
489
  print("\nUsing the following index mappings for evaluation:")
479
- print("\nOrder indices:")
480
- for i, order in enumerate(order_list):
481
- print(f" {i}: {order}")
482
-
483
490
  print("\nFamily indices:")
484
491
  for i, family in enumerate(family_list):
485
492
  print(f" {i}: {family}")
486
493
 
494
+ print("\nGenus indices:")
495
+ for i, genus in enumerate(genus_list):
496
+ print(f" {i}: {genus}")
497
+
487
498
  print("\nSpecies indices:")
488
499
  for i, species in enumerate(species_list):
489
500
  print(f" {i}: {species}")
@@ -491,11 +502,11 @@ class TestTwoStage:
491
502
  # Dictionary to track prediction category counts for debugging
492
503
  prediction_counts = {
493
504
  "true_species_boxes": sum(len(frame) for frame in true_species_frames),
505
+ "true_genus_boxes": sum(len(frame) for frame in true_genus_frames),
494
506
  "true_family_boxes": sum(len(frame) for frame in true_family_frames),
495
- "true_order_boxes": sum(len(frame) for frame in true_order_frames),
496
507
  "predicted_species": sum(len(frame) for frame in predicted_species_frames),
497
- "predicted_family": sum(len(frame) for frame in predicted_family_frames),
498
- "predicted_order": sum(len(frame) for frame in predicted_order_frames)
508
+ "predicted_genus": sum(len(frame) for frame in predicted_genus_frames),
509
+ "predicted_family": sum(len(frame) for frame in predicted_family_frames)
499
510
  }
500
511
 
501
512
  print(f"Prediction counts: {prediction_counts}")
@@ -504,11 +515,11 @@ class TestTwoStage:
504
515
  print("\n=== Species-level Metrics ===")
505
516
  self.get_metrics(predicted_species_frames, true_species_frames, species_list)
506
517
 
518
+ print("\n=== Genus-level Metrics ===")
519
+ self.get_metrics(predicted_genus_frames, true_genus_frames, genus_list)
520
+
507
521
  print("\n=== Family-level Metrics ===")
508
522
  self.get_metrics(predicted_family_frames, true_family_frames, family_list)
509
-
510
- print("\n=== Order-level Metrics ===")
511
- self.get_metrics(predicted_order_frames, true_order_frames, order_list)
512
523
 
513
524
  def get_metrics(self, predicted_frames, true_frames, labels):
514
525
  """Calculate metrics for object detection predictions"""
@@ -667,4 +678,4 @@ if __name__ == "__main__":
667
678
  hierarchical_model_path = "/mnt/nvme0n1p1/mit/two-stage-detection/hierarchical/hierarchical-weights.pth"
668
679
  output_directory = "./output"
669
680
 
670
- test_multitask(species_names, test_directory, yolo_model_path, hierarchical_model_path, output_directory)
681
+ test(species_names, test_directory, yolo_model_path, hierarchical_model_path, output_directory)
bplusplus/tracker.py ADDED
@@ -0,0 +1,261 @@
1
+ import numpy as np
2
+ import uuid
3
+ from scipy.optimize import linear_sum_assignment
4
+ from collections import deque
5
+
6
+ class BoundingBox:
7
+ def __init__(self, x, y, width, height, frame_id, track_id=None):
8
+ self.x = x
9
+ self.y = y
10
+ self.width = width
11
+ self.height = height
12
+ self.area = width * height
13
+ self.frame_id = frame_id
14
+ self.track_id = track_id
15
+
16
+ def center(self):
17
+ return (self.x + self.width/2, self.y + self.height/2)
18
+
19
+ @classmethod
20
+ def from_xyxy(cls, x1, y1, x2, y2, frame_id, track_id=None):
21
+ """Create BoundingBox from x1,y1,x2,y2 coordinates"""
22
+ width = x2 - x1
23
+ height = y2 - y1
24
+ return cls(x1, y1, width, height, frame_id, track_id)
25
+
26
class InsectTracker:
    """Frame-to-frame tracker that assigns stable IDs to insect detections.

    Detections from consecutive frames are matched by solving a linear
    assignment problem (Hungarian algorithm via scipy) over a pairwise cost
    that combines normalized center distance and box-area dissimilarity.
    Tracks that vanish are parked in a lost-track pool and can be
    re-acquired for up to ``track_memory_frames`` missed frames.

    Fix vs. previous revision: a lost track's ``frames_lost`` counter was
    incremented both where the track was registered as lost AND again in
    ``_age_lost_tracks`` within the same ``update`` call, so lost tracks
    expired in roughly half the configured memory window. Aging now happens
    in exactly one place (``_age_lost_tracks``); newly lost tracks start at
    0 and reach 1 after their first missed frame.
    """

    def __init__(self, image_height, image_width, max_frames=30, w_dist=0.7, w_area=0.3, cost_threshold=0.8, track_memory_frames=None, debug=False):
        """
        Args:
            image_height, image_width: frame dimensions in pixels; used to
                normalize center-to-center distances.
            max_frames: length of the per-frame tracking history window.
            w_dist: weight of the normalized-distance term in the cost.
            w_area: weight of the area-dissimilarity term in the cost.
            cost_threshold: assignments whose cost is not below this value
                are rejected and the detection starts a new track.
            track_memory_frames: consecutive missed frames a lost track
                survives before being dropped (defaults to max_frames).
            debug: print verbose matching decisions.
        """
        self.image_height = image_height
        self.image_width = image_width
        # Image diagonal: the largest possible center-to-center distance.
        self.max_dist = np.sqrt(image_height**2 + image_width**2)
        self.max_frames = max_frames
        self.w_dist = w_dist
        self.w_area = w_area
        self.cost_threshold = cost_threshold
        self.debug = debug

        # If track_memory_frames not specified, use max_frames (full history window)
        self.track_memory_frames = track_memory_frames if track_memory_frames is not None else max_frames
        if self.debug:
            print(f"DEBUG: Tracker initialized with max_frames={max_frames}, track_memory_frames={self.track_memory_frames}")

        self.tracking_history = deque(maxlen=max_frames)
        self.current_tracks = []
        self.lost_tracks = {}  # track_id -> {box: BoundingBox, frames_lost: int}

    def _generate_track_id(self):
        """Generate a unique UUID string for a new track."""
        return str(uuid.uuid4())

    def calculate_cost(self, box1, box2):
        """Calculate matching cost between two bounding boxes (equation 4).

        Lower is better: 0 means identical center and area, values grow with
        center distance (normalized by the image diagonal) and with area
        mismatch.
        """
        # Calculate center points
        cx1, cy1 = box1.center()
        cx2, cy2 = box2.center()

        # Euclidean distance (equation 1)
        dist = np.sqrt((cx2 - cx1)**2 + (cy2 - cy1)**2)

        # Normalized distance (equation 2 used for normalization)
        norm_dist = dist / self.max_dist

        # Area cost (equation 3): ratio of smaller to larger area, in (0, 1].
        min_area = min(box1.area, box2.area)
        max_area = max(box1.area, box2.area)
        area_cost = min_area / max_area if max_area > 0 else 1.0

        # Final cost (equation 4)
        cost = (norm_dist * self.w_dist) + ((1 - area_cost) * self.w_area)

        return cost

    def build_cost_matrix(self, prev_boxes, curr_boxes):
        """Build a square cost matrix for the Hungarian algorithm.

        The matrix is padded to max(n_prev, n_curr) with a prohibitively
        high cost (999.0) so dummy rows/columns are never preferred.

        Returns:
            (cost_matrix, n_prev, n_curr) — callers must ignore assignments
            involving padded (dummy) indices.
        """
        n_prev = len(prev_boxes)
        n_curr = len(curr_boxes)
        n = max(n_prev, n_curr)

        # Initialize cost matrix with high values
        cost_matrix = np.ones((n, n)) * 999.0

        # Fill in actual costs
        for i in range(n_prev):
            for j in range(n_curr):
                cost_matrix[i, j] = self.calculate_cost(prev_boxes[i], curr_boxes[j])

        return cost_matrix, n_prev, n_curr

    def update(self, new_detections, frame_id):
        """
        Update tracking with new detections from YOLO

        Args:
            new_detections: List of YOLO detection boxes (x1, y1, x2, y2 format)
            frame_id: Current frame number

        Returns:
            List of track IDs corresponding to each detection
        """
        # Handle empty detection list (no detections in this frame)
        if not new_detections:
            if self.debug:
                print(f"DEBUG: Frame {frame_id} has no detections")
            # Move all current tracks to lost tracks. frames_lost starts at 0:
            # _age_lost_tracks below is the single place the counter is
            # incremented, so after this call the track is "lost for 1 frame".
            for track in self.current_tracks:
                if track.track_id not in self.lost_tracks:
                    self.lost_tracks[track.track_id] = {
                        'box': track,
                        'frames_lost': 0
                    }
                    if self.debug:
                        print(f"DEBUG: Moved track {track.track_id} to lost tracks")

            # Age lost tracks and remove old ones
            self._age_lost_tracks()

            self.current_tracks = []
            self.tracking_history.append([])
            return []

        # Convert YOLO detections to BoundingBox objects
        new_boxes = []
        for detection in new_detections:
            x1, y1, x2, y2 = detection[:4]
            new_boxes.append(BoundingBox.from_xyxy(x1, y1, x2, y2, frame_id))

        # If this is the first frame or no existing tracks, assign new track IDs to all boxes
        if not self.current_tracks and not self.lost_tracks:
            track_ids = []
            for box in new_boxes:
                box.track_id = self._generate_track_id()
                track_ids.append(box.track_id)
                if self.debug:
                    print(f"DEBUG: FIRST FRAME - Assigned track ID {box.track_id} to new detection")
            self.current_tracks = new_boxes
            self.tracking_history.append(new_boxes)
            return track_ids

        # Combine current tracks and lost tracks for matching, so a track
        # that briefly disappeared can be re-identified.
        all_previous_tracks = self.current_tracks.copy()
        for track_id, lost_info in self.lost_tracks.items():
            lost_box = lost_info['box']
            lost_box.track_id = track_id  # Ensure track_id is preserved
            all_previous_tracks.append(lost_box)

        if not all_previous_tracks:
            # No previous tracks at all, assign new IDs
            track_ids = []
            for box in new_boxes:
                box.track_id = self._generate_track_id()
                track_ids.append(box.track_id)
                if self.debug:
                    print(f"DEBUG: No previous tracks - Assigned track ID {box.track_id} to new detection")
            self.current_tracks = new_boxes
            self.tracking_history.append(new_boxes)
            return track_ids

        # Build cost matrix including lost tracks
        cost_matrix, n_prev, n_curr = self.build_cost_matrix(all_previous_tracks, new_boxes)

        # Apply Hungarian algorithm
        row_indices, col_indices = linear_sum_assignment(cost_matrix)

        # Assign track IDs based on the matching
        assigned_curr_indices = set()
        track_ids = [None] * len(new_boxes)
        recovered_tracks = set()  # Track IDs that were recovered from lost tracks

        if self.debug:
            print(f"DEBUG: Hungarian assignment - rows: {row_indices}, cols: {col_indices}")
            print(f"DEBUG: Cost threshold: {self.cost_threshold}")
            print(f"DEBUG: Current tracks: {len(self.current_tracks)}, Lost tracks: {len(self.lost_tracks)}")

        for i, j in zip(row_indices, col_indices):
            # Only consider valid assignments (not dummy rows/columns)
            if i < n_prev and j < n_curr:
                cost = cost_matrix[i, j]
                if self.debug:
                    print(f"DEBUG: Checking assignment {i}->{j}, cost: {cost:.3f}")
                # Check if cost is below threshold
                if cost < self.cost_threshold:
                    # Assign the track ID from previous box to current box
                    prev_track_id = all_previous_tracks[i].track_id
                    new_boxes[j].track_id = prev_track_id
                    track_ids[j] = prev_track_id
                    assigned_curr_indices.add(j)

                    # Check if this was a lost track being recovered
                    if prev_track_id in self.lost_tracks:
                        recovered_tracks.add(prev_track_id)
                        if self.debug:
                            print(f"DEBUG: RECOVERED lost track ID {prev_track_id} for detection {j} (was lost for {self.lost_tracks[prev_track_id]['frames_lost']} frames)")
                    else:
                        if self.debug:
                            print(f"DEBUG: Continued track ID {prev_track_id} for detection {j}")
                else:
                    if self.debug:
                        print(f"DEBUG: Cost {cost:.3f} above threshold {self.cost_threshold}, not assigning")

        # Remove recovered tracks from lost tracks
        for track_id in recovered_tracks:
            del self.lost_tracks[track_id]

        # Assign new track IDs to unassigned current boxes (new insects)
        for j in range(n_curr):
            if j not in assigned_curr_indices:
                new_boxes[j].track_id = self._generate_track_id()
                track_ids[j] = new_boxes[j].track_id
                if self.debug:
                    print(f"DEBUG: Assigned NEW track ID {new_boxes[j].track_id} to detection {j}")

        # Move unmatched current tracks to lost tracks (tracks that disappeared
        # this frame). frames_lost starts at 0; _age_lost_tracks below bumps it
        # to 1 (previously it was initialized to 1 and then aged again, which
        # double-counted every missed frame and halved the memory window).
        matched_track_ids = {track_ids[j] for j in assigned_curr_indices if track_ids[j] is not None}
        for track in self.current_tracks:
            if track.track_id not in matched_track_ids and track.track_id not in recovered_tracks:
                if track.track_id not in self.lost_tracks:
                    self.lost_tracks[track.track_id] = {
                        'box': track,
                        'frames_lost': 0
                    }
                    if self.debug:
                        print(f"DEBUG: Track {track.track_id} disappeared, moved to lost tracks")

        # Age lost tracks and remove old ones
        self._age_lost_tracks()

        # Update current tracks
        self.current_tracks = new_boxes

        # Add to tracking history
        self.tracking_history.append(new_boxes)

        return track_ids

    def _age_lost_tracks(self):
        """Age lost tracks and remove those that have been lost too long.

        This is the ONLY place frames_lost is incremented; update() must not
        age counters itself, otherwise missed frames would be double-counted.
        """
        tracks_to_remove = []
        for track_id, lost_info in self.lost_tracks.items():
            lost_info['frames_lost'] += 1
            if lost_info['frames_lost'] > self.track_memory_frames:
                tracks_to_remove.append(track_id)
                if self.debug:
                    print(f"DEBUG: Permanently removing track {track_id} (lost for {lost_info['frames_lost']} frames)")

        for track_id in tracks_to_remove:
            del self.lost_tracks[track_id]

    def get_tracking_stats(self):
        """Get current tracking statistics for debugging/monitoring."""
        return {
            'active_tracks': len(self.current_tracks),
            'lost_tracks': len(self.lost_tracks),
            'active_track_ids': [track.track_id for track in self.current_tracks],
            'lost_track_ids': list(self.lost_tracks.keys()),
            'total_history_frames': len(self.tracking_history)
        }