bplusplus 1.2.2-py3-none-any.whl → 1.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of bplusplus might be problematic.

@@ -23,7 +23,7 @@ import sys
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-def test_multitask(species_list, test_set, yolo_weights, hierarchical_weights, output_dir="."):
+def test(species_list, test_set, yolo_weights, hierarchical_weights, output_dir="."):
     """
     Run the two-stage classifier on a test set.
 
@@ -74,6 +74,15 @@ def setup_gpu():
        logger.warning("Falling back to CPU")
        return torch.device("cpu")
 
+# Add this check for backwards compatibility
+if hasattr(torch.serialization, 'add_safe_globals'):
+    torch.serialization.add_safe_globals([
+        'torch.LongTensor',
+        'torch.cuda.LongTensor',
+        'torch.FloatStorage',
+        'torch.cuda.FloatStorage',
+    ])
+
 class HierarchicalInsectClassifier(nn.Module):
     def __init__(self, num_classes_per_level):
         """
@@ -243,8 +253,15 @@ class TestTwoStage:
            if "species_list" in checkpoint:
                saved_species = checkpoint["species_list"]
                print(f"Saved model was trained on: {', '.join(saved_species)}")
-
-                taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
+
+                # Use saved taxonomy mappings if available
+                if "species_to_genus" in checkpoint and "genus_to_family" in checkpoint:
+                    species_to_genus = checkpoint["species_to_genus"]
+                    genus_to_family = checkpoint["genus_to_family"]
+                else:
+                    # Fallback: fetch from GBIF but this may cause index mismatches
+                    print("Warning: No taxonomy mappings in checkpoint, fetching from GBIF")
+                    _, species_to_genus, genus_to_family = get_taxonomy(species_names)
            else:
                taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
        else:
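
The new branch assumes training embedded its taxonomy mappings in the checkpoint. A runnable sketch of that round trip, using toy stand-ins for the model and mappings (the save-side code is not part of this diff):

import torch
import torch.nn as nn

# Toy stand-ins; the real model and mappings come from training.
model = nn.Linear(4, 2)
species_list = ["Apis mellifera", "Eristalis tenax"]
species_to_genus = {"Apis mellifera": "Apis", "Eristalis tenax": "Eristalis"}
genus_to_family = {"Apis": "Apidae", "Eristalis": "Syrphidae"}

# Persist the exact mappings used in training so evaluation reuses the label order.
torch.save({
    "model_state_dict": model.state_dict(),
    "species_list": species_list,
    "species_to_genus": species_to_genus,
    "genus_to_family": genus_to_family,
}, "checkpoint.pth")

# Load side, mirroring the patched logic above.
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
if "species_to_genus" in checkpoint and "genus_to_family" in checkpoint:
    species_to_genus = checkpoint["species_to_genus"]
    genus_to_family = checkpoint["genus_to_family"]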
@@ -285,8 +302,6 @@ class TestTwoStage:
        self.classification_model.eval()
 
        self.classification_transform = transforms.Compose([
-            transforms.Resize((768, 768)),  # Fixed size for all validation images
-            transforms.CenterCrop(640),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
@@ -467,10 +482,18 @@ class TestTwoStage:
                              predicted_genus_frames, true_genus_frames,
                              predicted_family_frames, true_family_frames):
        """Calculate metrics at all taxonomic levels"""
-        # Get list of species, families and genera
+        # Get list of species, families and genera using the same order as model training
        species_list = self.species_names
-        genus_list = sorted(list(set(self.species_to_genus.values())))
-        family_list = sorted(list(set(self.genus_to_family.values())))
+
+        # Use the index mappings from the model to ensure consistency
+        if 1 in self.idx_to_level and 2 in self.idx_to_level:
+            family_list = [self.idx_to_level[1][i] for i in sorted(self.idx_to_level[1].keys())]
+            genus_list = [self.idx_to_level[2][i] for i in sorted(self.idx_to_level[2].keys())]
+        else:
+            # Fallback to sorted lists (may cause issues)
+            print("Warning: Using fallback sorted lists for taxonomy - this may cause index mismatches")
+            genus_list = sorted(list(set(self.species_to_genus.values())))
+            family_list = sorted(list(set(self.genus_to_family.values())))
 
        # Print the index mappings we're using for evaluation
        print("\nUsing the following index mappings for evaluation:")
@@ -665,4 +688,4 @@ if __name__ == "__main__":
    hierarchical_model_path = "/mnt/nvme0n1p1/mit/two-stage-detection/hierarchical/hierarchical-weights.pth"
    output_directory = "./output"
 
-    test_multitask(species_names, test_directory, yolo_model_path, hierarchical_model_path, output_directory)
+    test(species_names, test_directory, yolo_model_path, hierarchical_model_path, output_directory)
bplusplus/tracker.py ADDED
@@ -0,0 +1,261 @@
+import numpy as np
+import uuid
+from scipy.optimize import linear_sum_assignment
+from collections import deque
+
+class BoundingBox:
+    def __init__(self, x, y, width, height, frame_id, track_id=None):
+        self.x = x
+        self.y = y
+        self.width = width
+        self.height = height
+        self.area = width * height
+        self.frame_id = frame_id
+        self.track_id = track_id
+
+    def center(self):
+        return (self.x + self.width/2, self.y + self.height/2)
+
+    @classmethod
+    def from_xyxy(cls, x1, y1, x2, y2, frame_id, track_id=None):
+        """Create BoundingBox from x1,y1,x2,y2 coordinates"""
+        width = x2 - x1
+        height = y2 - y1
+        return cls(x1, y1, width, height, frame_id, track_id)
+
+class InsectTracker:
+    def __init__(self, image_height, image_width, max_frames=30, w_dist=0.7, w_area=0.3, cost_threshold=0.8, track_memory_frames=None, debug=False):
+        self.image_height = image_height
+        self.image_width = image_width
+        self.max_dist = np.sqrt(image_height**2 + image_width**2)
+        self.max_frames = max_frames
+        self.w_dist = w_dist
+        self.w_area = w_area
+        self.cost_threshold = cost_threshold
+        self.debug = debug
+
+        # If track_memory_frames not specified, use max_frames (full history window)
+        self.track_memory_frames = track_memory_frames if track_memory_frames is not None else max_frames
+        if self.debug:
+            print(f"DEBUG: Tracker initialized with max_frames={max_frames}, track_memory_frames={self.track_memory_frames}")
+
+        self.tracking_history = deque(maxlen=max_frames)
+        self.current_tracks = []
+        self.lost_tracks = {}  # track_id -> {box: BoundingBox, frames_lost: int}
+
+    def _generate_track_id(self):
+        """Generate a unique UUID for a new track"""
+        return str(uuid.uuid4())
+
+    def calculate_cost(self, box1, box2):
+        """Calculate cost between two bounding boxes as per equation (4)"""
+        # Calculate center points
+        cx1, cy1 = box1.center()
+        cx2, cy2 = box2.center()
+
+        # Euclidean distance (equation 1)
+        dist = np.sqrt((cx2 - cx1)**2 + (cy2 - cy1)**2)
+
+        # Normalized distance (equation 2 used for normalization)
+        norm_dist = dist / self.max_dist
+
+        # Area cost (equation 3)
+        min_area = min(box1.area, box2.area)
+        max_area = max(box1.area, box2.area)
+        area_cost = min_area / max_area if max_area > 0 else 1.0
+
+        # Final cost (equation 4)
+        cost = (norm_dist * self.w_dist) + ((1 - area_cost) * self.w_area)
+
+        return cost
+
+    def build_cost_matrix(self, prev_boxes, curr_boxes):
+        """Build cost matrix for Hungarian algorithm"""
+        n_prev = len(prev_boxes)
+        n_curr = len(curr_boxes)
+        n = max(n_prev, n_curr)
+
+        # Initialize cost matrix with high values
+        cost_matrix = np.ones((n, n)) * 999.0
+
+        # Fill in actual costs
+        for i in range(n_prev):
+            for j in range(n_curr):
+                cost_matrix[i, j] = self.calculate_cost(prev_boxes[i], curr_boxes[j])
+
+        return cost_matrix, n_prev, n_curr
+
+    def update(self, new_detections, frame_id):
+        """
+        Update tracking with new detections from YOLO
+
+        Args:
+            new_detections: List of YOLO detection boxes (x1, y1, x2, y2 format)
+            frame_id: Current frame number
+
+        Returns:
+            List of track IDs corresponding to each detection
+        """
+        # Handle empty detection list (no detections in this frame)
+        if not new_detections:
+            if self.debug:
+                print(f"DEBUG: Frame {frame_id} has no detections")
+            # Move all current tracks to lost tracks
+            for track in self.current_tracks:
+                if track.track_id not in self.lost_tracks:
+                    self.lost_tracks[track.track_id] = {
+                        'box': track,
+                        'frames_lost': 1
+                    }
+                    if self.debug:
+                        print(f"DEBUG: Moved track {track.track_id} to lost tracks")
+                else:
+                    self.lost_tracks[track.track_id]['frames_lost'] += 1
+
+            # Age lost tracks and remove old ones
+            self._age_lost_tracks()
+
+            self.current_tracks = []
+            self.tracking_history.append([])
+            return []
+
+        # Convert YOLO detections to BoundingBox objects
+        new_boxes = []
+        for i, detection in enumerate(new_detections):
+            x1, y1, x2, y2 = detection[:4]
+            bbox = BoundingBox.from_xyxy(x1, y1, x2, y2, frame_id)
+            new_boxes.append(bbox)
+
+        # If this is the first frame or no existing tracks, assign new track IDs to all boxes
+        if not self.current_tracks and not self.lost_tracks:
+            track_ids = []
+            for box in new_boxes:
+                box.track_id = self._generate_track_id()
+                track_ids.append(box.track_id)
+                if self.debug:
+                    print(f"DEBUG: FIRST FRAME - Assigned track ID {box.track_id} to new detection")
+            self.current_tracks = new_boxes
+            self.tracking_history.append(new_boxes)
+            return track_ids
+
+        # Combine current tracks and lost tracks for matching
+        all_previous_tracks = self.current_tracks.copy()
+        lost_track_list = []
+
+        for track_id, lost_info in self.lost_tracks.items():
+            lost_track_list.append(lost_info['box'])
+            lost_track_list[-1].track_id = track_id  # Ensure track_id is preserved
+
+        all_previous_tracks.extend(lost_track_list)
+
+        if not all_previous_tracks:
+            # No previous tracks at all, assign new IDs
+            track_ids = []
+            for box in new_boxes:
+                box.track_id = self._generate_track_id()
+                track_ids.append(box.track_id)
+                if self.debug:
+                    print(f"DEBUG: No previous tracks - Assigned track ID {box.track_id} to new detection")
+            self.current_tracks = new_boxes
+            self.tracking_history.append(new_boxes)
+            return track_ids
+
+        # Build cost matrix including lost tracks
+        cost_matrix, n_prev, n_curr = self.build_cost_matrix(all_previous_tracks, new_boxes)
+
+        # Apply Hungarian algorithm
+        row_indices, col_indices = linear_sum_assignment(cost_matrix)
+
+        # Assign track IDs based on the matching
+        assigned_curr_indices = set()
+        track_ids = [None] * len(new_boxes)
+        recovered_tracks = set()  # Track IDs that were recovered from lost tracks
+
+        if self.debug:
+            print(f"DEBUG: Hungarian assignment - rows: {row_indices}, cols: {col_indices}")
+            print(f"DEBUG: Cost threshold: {self.cost_threshold}")
+            print(f"DEBUG: Current tracks: {len(self.current_tracks)}, Lost tracks: {len(self.lost_tracks)}")
+
+        for i, j in zip(row_indices, col_indices):
+            # Only consider valid assignments (not dummy rows/columns)
+            if i < n_prev and j < n_curr:
+                cost = cost_matrix[i, j]
+                if self.debug:
+                    print(f"DEBUG: Checking assignment {i}->{j}, cost: {cost:.3f}")
+                # Check if cost is below threshold
+                if cost < self.cost_threshold:
+                    # Assign the track ID from previous box to current box
+                    prev_track_id = all_previous_tracks[i].track_id
+                    new_boxes[j].track_id = prev_track_id
+                    track_ids[j] = prev_track_id
+                    assigned_curr_indices.add(j)
+
+                    # Check if this was a lost track being recovered
+                    if prev_track_id in self.lost_tracks:
+                        recovered_tracks.add(prev_track_id)
+                        if self.debug:
+                            print(f"DEBUG: RECOVERED lost track ID {prev_track_id} for detection {j} (was lost for {self.lost_tracks[prev_track_id]['frames_lost']} frames)")
+                    else:
+                        if self.debug:
+                            print(f"DEBUG: Continued track ID {prev_track_id} for detection {j}")
+                else:
+                    if self.debug:
+                        print(f"DEBUG: Cost {cost:.3f} above threshold {self.cost_threshold}, not assigning")
+
+        # Remove recovered tracks from lost tracks
+        for track_id in recovered_tracks:
+            del self.lost_tracks[track_id]
+
+        # Assign new track IDs to unassigned current boxes (new insects)
+        for j in range(n_curr):
+            if j not in assigned_curr_indices:
+                new_boxes[j].track_id = self._generate_track_id()
+                track_ids[j] = new_boxes[j].track_id
+                if self.debug:
+                    print(f"DEBUG: Assigned NEW track ID {new_boxes[j].track_id} to detection {j}")
+
+        # Move unmatched current tracks to lost tracks (tracks that disappeared this frame)
+        matched_track_ids = {track_ids[j] for j in assigned_curr_indices if track_ids[j] is not None}
+        for track in self.current_tracks:
+            if track.track_id not in matched_track_ids and track.track_id not in recovered_tracks:
+                if track.track_id not in self.lost_tracks:
+                    self.lost_tracks[track.track_id] = {
+                        'box': track,
+                        'frames_lost': 1
+                    }
+                    if self.debug:
+                        print(f"DEBUG: Track {track.track_id} disappeared, moved to lost tracks")
+
+        # Age lost tracks and remove old ones
+        self._age_lost_tracks()
+
+        # Update current tracks
+        self.current_tracks = new_boxes
+
+        # Add to tracking history
+        self.tracking_history.append(new_boxes)
+
+        return track_ids
+
+    def _age_lost_tracks(self):
+        """Age lost tracks and remove those that have been lost too long"""
+        tracks_to_remove = []
+        for track_id, lost_info in self.lost_tracks.items():
+            lost_info['frames_lost'] += 1
+            if lost_info['frames_lost'] > self.track_memory_frames:
+                tracks_to_remove.append(track_id)
+                if self.debug:
+                    print(f"DEBUG: Permanently removing track {track_id} (lost for {lost_info['frames_lost']} frames)")
+
+        for track_id in tracks_to_remove:
+            del self.lost_tracks[track_id]
+
+    def get_tracking_stats(self):
+        """Get current tracking statistics for debugging/monitoring"""
+        return {
+            'active_tracks': len(self.current_tracks),
+            'lost_tracks': len(self.lost_tracks),
+            'active_track_ids': [track.track_id for track in self.current_tracks],
+            'lost_track_ids': list(self.lost_tracks.keys()),
+            'total_history_frames': len(self.tracking_history)
+        }
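
A short usage sketch of the new tracker; the frame size and detections are invented for illustration. Two slightly shifted boxes produce a near-zero cost, so the Hungarian assignment keeps the same track ID across frames:

from bplusplus.tracker import InsectTracker

# Invented detections: one box drifting slightly between two frames.
tracker = InsectTracker(image_height=720, image_width=1280)
frame0 = [(100, 100, 140, 140)]  # x1, y1, x2, y2
frame1 = [(104, 102, 144, 142)]  # same insect, small motion

ids0 = tracker.update(frame0, frame_id=0)
ids1 = tracker.update(frame1, frame_id=1)
assert ids0 == ids1  # the low-cost match keeps the UUID stable

# Cost arithmetic for this pair, per the equations cited in calculate_cost:
# dist      = sqrt(4**2 + 2**2)             ~ 4.47 px   (equation 1)
# norm_dist = 4.47 / sqrt(720**2 + 1280**2) ~ 0.003     (equation 2)
# area_cost = 1600 / 1600                   = 1.0       (equation 3)
# cost      = 0.003*0.7 + (1 - 1.0)*0.3     ~ 0.002 < cost_threshold 0.8  (equation 4)
print(tracker.get_tracking_stats())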
@@ -14,18 +14,28 @@ import logging
 from tqdm import tqdm
 import sys
 
-def train_multitask(batch_size=4, epochs=30, patience=3, img_size=640, data_dir='/mnt/nvme0n1p1/datasets/insect/bjerge-train2', output_dir='./output', species_list=None):
+def train(batch_size=4, epochs=30, patience=3, img_size=640, data_dir='input', output_dir='./output', species_list=None, num_workers=4):
     """
     Main function to run the entire training pipeline.
     Sets up datasets, model, training process and handles errors.
+
+    Args:
+        batch_size (int): Number of samples per batch. Default: 4
+        epochs (int): Maximum number of training epochs. Default: 30
+        patience (int): Early stopping patience (epochs without improvement). Default: 3
+        img_size (int): Target image size for training. Default: 640
+        data_dir (str): Directory containing train/valid subdirectories. Default: 'input'
+        output_dir (str): Directory to save trained model and logs. Default: './output'
+        species_list (list): List of species names for training. Required.
+        num_workers (int): Number of DataLoader worker processes.
+            Set to 0 to disable multiprocessing (most stable). Default: 4
     """
     global logger, device
 
     logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
     logger = logging.getLogger(__name__)
 
-    logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {epochs}, Patience: {patience}, Image size: {img_size}, Data directory: {data_dir}, Output directory: {output_dir}")
-
+    logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {epochs}, Patience: {patience}, Image size: {img_size}, Data directory: {data_dir}, Output directory: {output_dir}, Num workers: {num_workers}")
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -52,7 +62,7 @@ def train_multitask(batch_size=4, epochs=30, patience=3, img_size=640, data_dir=
 
     taxonomy = get_taxonomy(species_list)
 
-    level_to_idx, parent_child_relationship = create_mappings(taxonomy)
+    level_to_idx, parent_child_relationship = create_mappings(taxonomy, species_list)
 
     num_classes_per_level = [len(taxonomy[level]) if isinstance(taxonomy[level], list)
                              else len(taxonomy[level].keys()) for level in sorted(taxonomy.keys())]
@@ -75,14 +85,14 @@ def train_multitask(batch_size=4, epochs=30, patience=3, img_size=640, data_dir=
         train_dataset,
         batch_size=batch_size,
         shuffle=True,
-        num_workers=4
+        num_workers=num_workers
     )
 
     val_loader = DataLoader(
         val_dataset,
         batch_size=batch_size,
         shuffle=False,
-        num_workers=4
+        num_workers=num_workers
    )
 
    try:
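
Since num_workers is now threaded into both loaders above, a caller can disable DataLoader multiprocessing entirely. A hypothetical invocation, assuming train is imported from the package's training module (which this diff does not name):

# Hypothetical call; num_workers=0 keeps data loading in the main process,
# the option the docstring describes as most stable.
train(
    species_list=["Apis mellifera", "Bombus terrestris"],
    data_dir="input",
    output_dir="./output",
    num_workers=0,
)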
@@ -150,14 +160,17 @@ def get_taxonomy(species_list):
     species_to_genus = {}
     genus_to_family = {}
 
-    logger.info(f"Building taxonomy from GBIF for {len(species_list)} species")
+    species_list_for_gbif = [s for s in species_list if s.lower() != 'unknown']
+    has_unknown = len(species_list_for_gbif) != len(species_list)
+
+    logger.info(f"Building taxonomy from GBIF for {len(species_list_for_gbif)} species")
 
     print("\nTaxonomy Results:")
     print("-" * 80)
     print(f"{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
     print("-" * 80)
 
-    for species_name in species_list:
+    for species_name in species_list_for_gbif:
         url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
         try:
             response = requests.get(url)
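
For reference, the GBIF match endpoint used in this loop returns a flat JSON record. A minimal standalone lookup, kept simpler than the package's error handling; family, genus, and matchType are standard fields of the match response:

import requests

# One GBIF lookup, mirroring the URL built in the loop above.
resp = requests.get(
    "https://api.gbif.org/v1/species/match",
    params={"name": "Apis mellifera", "verbose": "true"},
    timeout=10,
)
resp.raise_for_status()
data = resp.json()
print(data.get("family"), data.get("genus"), data.get("matchType"))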
@@ -199,6 +212,19 @@ def get_taxonomy(species_list):
            print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
            print(f"Error: {error_msg}")
            sys.exit(1) # Stop the script
+
+    if has_unknown:
+        unknown_family = "Unknown"
+        unknown_genus = "Unknown"
+        unknown_species = "unknown"
+
+        if unknown_family not in taxonomy[1]:
+            taxonomy[1].append(unknown_family)
+
+        taxonomy[2][unknown_genus] = unknown_family
+        taxonomy[3][unknown_species] = unknown_genus
+
+        print(f"{unknown_species:<30} {unknown_family:<20} {unknown_genus:<20} {'OK'}")
 
     taxonomy[1] = sorted(list(set(taxonomy[1])))
     print("-" * 80)
@@ -212,7 +238,7 @@ def get_taxonomy(species_list):
         print(f" {i}: {family}")
 
     print("\nGenus indices:")
-    for i, genus in enumerate(taxonomy[2].keys()):
+    for i, genus in enumerate(sorted(taxonomy[2].keys())):
         print(f" {i}: {genus}")
 
     print("\nSpecies indices:")
@@ -244,7 +270,7 @@ def get_species_from_directory(train_dir):
     logger.info(f"Found {len(species_list)} species in {train_dir}")
     return species_list
 
-def create_mappings(taxonomy):
+def create_mappings(taxonomy, species_list=None):
     """
     Creates mapping dictionaries from taxonomy data.
     Returns level-to-index mapping and parent-child relationships between taxonomic levels.
@@ -254,9 +280,17 @@ def create_mappings(taxonomy):
 
     for level, labels in taxonomy.items():
         if isinstance(labels, list):
+            # Level 1: Family (already sorted)
             level_to_idx[level] = {label: idx for idx, label in enumerate(labels)}
-        else:
-            level_to_idx[level] = {label: idx for idx, label in enumerate(labels.keys())}
+        else:  # dict for levels 2 and 3
+            if level == 3 and species_list is not None:
+                # For species, the order is determined by species_list
+                level_to_idx[level] = {label: idx for idx, label in enumerate(species_list)}
+            else:
+                # For genus (and as a fallback for species), sort alphabetically
+                sorted_keys = sorted(labels.keys())
+                level_to_idx[level] = {label: idx for idx, label in enumerate(sorted_keys)}
+
            for child, parent in labels.items():
                if (level, parent) not in parent_child_relationship:
                    parent_child_relationship[(level, parent)] = []
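
A toy reproduction of the new ordering rule (names invented; shapes as in the loop above, with level 1 a family list and levels 2 and 3 child-to-parent dicts). Passing species_list pins species indices to training order instead of dict order:

# Toy reproduction of the patched index logic (not the package's full function).
def sketch_level_to_idx(taxonomy, species_list=None):
    level_to_idx = {}
    for level, labels in taxonomy.items():
        if isinstance(labels, list):
            level_to_idx[level] = {label: i for i, label in enumerate(labels)}
        elif level == 3 and species_list is not None:
            level_to_idx[level] = {label: i for i, label in enumerate(species_list)}
        else:
            level_to_idx[level] = {label: i for i, label in enumerate(sorted(labels))}
    return level_to_idx

taxonomy = {
    1: ["Apidae", "Syrphidae"],
    2: {"Apis": "Apidae", "Eristalis": "Syrphidae"},
    3: {"Apis mellifera": "Apis", "Eristalis tenax": "Eristalis"},
}
species_list = ["Eristalis tenax", "Apis mellifera"]
print(sketch_level_to_idx(taxonomy, species_list)[3])
# {'Eristalis tenax': 0, 'Apis mellifera': 1}  (matches training label order)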
@@ -670,7 +704,7 @@ if __name__ == '__main__':
     species_list = [
         "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
         "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
-        "Eristalis tenax"
+        "Eristalis tenax", "unknown"
     ]
-    train_multitask(species_list=species_list, epochs=2)
+    train(species_list=species_list, epochs=2)
 