bplusplus 0.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic.


This version of bplusplus might be problematic; see the advisory details for this release.

@@ -0,0 +1,670 @@
1
+ # pip install ultralytics torchvision pillow numpy scikit-learn tabulate tqdm requests
2
+
3
+ import os
4
+ import cv2
5
+ import torch
6
+ from ultralytics import YOLO
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import numpy as np
10
+ from torchvision.models import resnet50
11
+ import torch.nn as nn
12
+ from sklearn.metrics import classification_report, accuracy_score
13
+ import time
14
+ import argparse
15
+ from collections import defaultdict
16
+ from tabulate import tabulate
17
+ from tqdm import tqdm
18
+ import csv
19
+ import requests
20
+ import logging
21
+ import sys
22
+
23
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
+ logger = logging.getLogger(__name__)
25
+
26
def test_multitask(species_list, test_set, yolo_weights, hierarchical_weights, output_dir="."):
    """
    Evaluate the two-stage pipeline (YOLO detection + hierarchical
    classification) on a labelled test set.

    Args:
        species_list (list): Species names the classifier was trained on.
        test_set (str): Path to the test directory (expects images/ and labels/).
        yolo_weights (str): Path to the YOLO detector weights file.
        hierarchical_weights (str): Path to the hierarchical classifier weights file.
        output_dir (str): Directory where result CSVs are written (default: ".").

    Returns:
        Whatever TestTwoStage.run returns for the given test set.
    """
    tester = TestTwoStage(yolo_weights, hierarchical_weights, species_list, output_dir)
    outcome = tester.run(test_set)
    print("Testing complete with metrics calculated at all taxonomic levels")
    return outcome
44
+
45
def cuda_cleanup():
    """Release cached CUDA memory and reset peak-memory stats (no-op on CPU-only hosts)."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
50
+
51
def setup_gpu():
    """
    Pick a compute device, preferring CUDA, with diagnostics and a CPU fallback.

    Returns:
        torch.device: "cuda:0" if CUDA is present and passes a smoke test,
        otherwise "cpu".
    """
    if not torch.cuda.is_available():
        logger.warning("CUDA is not available on this system")
        return torch.device("cpu")

    try:
        n_gpus = torch.cuda.device_count()
        logger.info(f"Found {n_gpus} CUDA device(s)")

        for idx in range(n_gpus):
            props = torch.cuda.get_device_properties(idx)
            logger.info(f"GPU {idx}: {props.name} with {props.total_memory / 1e9:.2f} GB memory")

        device = torch.device("cuda:0")
        # Smoke-test the device with a trivial computation before committing to it.
        probe = torch.ones(1, device=device)
        doubled = probe * 2
        del probe, doubled

        logger.info("CUDA initialization successful")
        return device
    except Exception as exc:
        logger.error(f"CUDA initialization error: {str(exc)}")
        logger.warning("Falling back to CPU")
        return torch.device("cpu")
76
+
77
class HierarchicalInsectClassifier(nn.Module):
    """Multi-head classifier: a shared ResNet-50 backbone feeding one MLP branch per taxonomic level."""

    def __init__(self, num_classes_per_level):
        """
        Args:
            num_classes_per_level (list): Number of classes for each taxonomic level.
        """
        super(HierarchicalInsectClassifier, self).__init__()

        # Shared feature extractor; its classification head is stripped off so
        # the branches below receive raw backbone features.
        self.backbone = resnet50(pretrained=True)
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # One small classification head per taxonomic level.
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.Linear(feature_dim, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, n_classes),
            )
            for n_classes in num_classes_per_level
        )

        self.num_levels = len(num_classes_per_level)

        # Running per-class statistics buffers sized over all levels combined.
        total_classes = sum(num_classes_per_level)
        self.register_buffer('class_means', torch.zeros(total_classes))
        self.register_buffer('class_stds', torch.ones(total_classes))
        self.class_counts = [0] * total_classes
        self.output_history = defaultdict(list)

    def forward(self, x):
        """Return a list with one logits tensor per taxonomic level for batch x."""
        features = self.backbone(x)
        return [branch(features) for branch in self.branches]
114
+
115
def get_taxonomy(species_list):
    """
    Retrieve taxonomic information for a list of species from the GBIF API.

    Builds a hierarchical taxonomy dictionary with order, family, and species
    relationships, printing a per-species lookup table as it goes.

    Args:
        species_list (list): Species (binomial) names to resolve.

    Returns:
        tuple: (taxonomy, species_to_family, family_to_order) where taxonomy is
            {1: [orders], 2: {family: order}, 3: {species: family}}.

    Note:
        Exits the process via sys.exit(1) on any unresolved species or request
        failure, because a partial taxonomy would silently misalign the index
        mappings used downstream.
    """
    taxonomy = {1: [], 2: {}, 3: {}}
    species_to_family = {}
    family_to_order = {}

    logger.info(f"Building taxonomy from GBIF for {len(species_list)} species")

    print("\nTaxonomy Results:")
    print("-" * 80)
    print(f"{'Species':<30} {'Order':<20} {'Family':<20} {'Status'}")
    print("-" * 80)

    for species_name in species_list:
        try:
            # Pass the species name as a query parameter so it is URL-encoded,
            # bound the call with a timeout so a stalled endpoint cannot hang
            # the run, and raise on HTTP errors instead of parsing an error
            # page as JSON.
            response = requests.get(
                "https://api.gbif.org/v1/species/match",
                params={"name": species_name, "verbose": "true"},
                timeout=30,
            )
            response.raise_for_status()
            data = response.json()

            if data.get('status') == 'ACCEPTED' or data.get('status') == 'SYNONYM':
                family = data.get('family')
                order = data.get('order')

                if family and order:
                    status = "OK"

                    print(f"{species_name:<30} {order:<20} {family:<20} {status}")

                    species_to_family[species_name] = family
                    family_to_order[family] = order

                    if order not in taxonomy[1]:
                        taxonomy[1].append(order)

                    taxonomy[2][family] = order
                    taxonomy[3][species_name] = family
                else:
                    error_msg = f"Species '{species_name}' found in GBIF but family and order not found, could be spelling error in species, check GBIF"
                    logger.error(error_msg)
                    print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                    print(f"Error: {error_msg}")
                    sys.exit(1)  # Stop the script
            else:
                error_msg = f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF"
                logger.error(error_msg)
                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                print(f"Error: {error_msg}")
                sys.exit(1)  # Stop the script

        except Exception as e:
            # sys.exit raises SystemExit (a BaseException), so it is NOT
            # swallowed here; this branch only covers genuine request errors.
            error_msg = f"Error retrieving data for species '{species_name}': {str(e)}"
            logger.error(error_msg)
            print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
            print(f"Error: {error_msg}")
            sys.exit(1)  # Stop the script

    # De-duplicate and fix a deterministic order for the order-level labels.
    taxonomy[1] = sorted(list(set(taxonomy[1])))
    print("-" * 80)

    num_orders = len(taxonomy[1])
    num_families = len(taxonomy[2])
    num_species = len(taxonomy[3])

    print("\nOrder indices:")
    for i, order in enumerate(taxonomy[1]):
        print(f" {i}: {order}")

    print("\nFamily indices:")
    for i, family in enumerate(taxonomy[2].keys()):
        print(f" {i}: {family}")

    print("\nSpecies indices:")
    for i, species in enumerate(species_list):
        print(f" {i}: {species}")

    logger.info(f"Taxonomy built: {num_orders} orders, {num_families} families, {num_species} species")
    return taxonomy, species_to_family, family_to_order
195
+
196
def create_mappings(taxonomy):
    """
    Build label<->index mappings for every taxonomic level.

    Args:
        taxonomy (dict): Maps level -> list of labels, or dict keyed by label.

    Returns:
        tuple: (level_to_idx, idx_to_level) dictionaries keyed by level.
    """
    level_to_idx = {}
    idx_to_level = {}

    for level, labels in taxonomy.items():
        # Dict-valued levels enumerate their keys; list-valued levels enumerate directly.
        names = labels if isinstance(labels, list) else list(labels.keys())
        level_to_idx[level] = {name: i for i, name in enumerate(names)}
        idx_to_level[level] = dict(enumerate(names))

    return level_to_idx, idx_to_level
210
+
211
+ class TestTwoStage:
212
    def __init__(self, yolo_model_path, hierarchical_model_path, species_names, output_dir="."):
        """
        Load both stages of the pipeline and prepare preprocessing.

        Args:
            yolo_model_path (str): Path to the YOLO detector weights.
            hierarchical_model_path (str): Path to the hierarchical classifier
                checkpoint (either a raw state dict or a dict containing
                "model_state_dict" and optionally "taxonomy"/"species_list").
            species_names (list): Species names defining the species-level labels.
            output_dir (str): Directory where result CSVs are written.

        Raises:
            Exception: re-raised if the hierarchical checkpoint cannot be loaded.
        """
        cuda_cleanup()

        self.device = setup_gpu()
        logger.info(f"Using device: {self.device}")

        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir
        logger.info(f"Results will be saved to: {self.output_dir}")

        print(f"Using device: {self.device}")

        self.yolo_model = YOLO(yolo_model_path)

        self.species_names = species_names

        # Load the classifier checkpoint onto the CPU first; it is moved to the
        # selected device only after the weights are in place.
        logger.info(f"Loading model from {hierarchical_model_path}")
        try:
            checkpoint = torch.load(hierarchical_model_path, map_location='cpu')
            logger.info("Model loaded to CPU successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

        # Support both checkpoint layouts: a wrapper dict with metadata, or a
        # bare state dict.
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]

            if "taxonomy" in checkpoint:
                print("Using taxonomy from saved model")
                taxonomy = checkpoint["taxonomy"]
                if "species_list" in checkpoint:
                    saved_species = checkpoint["species_list"]
                    print(f"Saved model was trained on: {', '.join(saved_species)}")

                # NOTE(review): the saved taxonomy assigned above is immediately
                # overwritten by a fresh GBIF lookup here, so the checkpoint's
                # taxonomy is effectively discarded — confirm this is intended.
                taxonomy, species_to_family, family_to_order = get_taxonomy(species_names)
            else:
                taxonomy, species_to_family, family_to_order = get_taxonomy(species_names)
        else:
            state_dict = checkpoint
            taxonomy, species_to_family, family_to_order = get_taxonomy(species_names)

        level_to_idx, idx_to_level = create_mappings(taxonomy)

        self.level_to_idx = level_to_idx
        self.idx_to_level = idx_to_level

        # Derive the per-level class counts from the taxonomy so the model
        # architecture matches the checkpoint being loaded.
        if hasattr(taxonomy, "items"):
            num_classes_per_level = [len(classes) if isinstance(classes, list) else len(classes.keys())
                                     for level, classes in taxonomy.items()]
        else:
            num_classes_per_level = [4, 5, 9]  # Example values, adjust as needed

        print(f"Using model with class counts: {num_classes_per_level}")

        self.classification_model = HierarchicalInsectClassifier(num_classes_per_level)

        # Fall back to strict=False so a partially matching checkpoint still
        # loads; mismatched keys are then silently skipped.
        try:
            self.classification_model.load_state_dict(state_dict)
            print("Model weights loaded successfully")
        except Exception as e:
            print(f"Error loading model weights: {e}")
            print("Attempting to load with strict=False...")
            self.classification_model.load_state_dict(state_dict, strict=False)
            print("Model weights loaded with strict=False")

        try:
            self.classification_model.to(self.device)
            print(f"Model successfully transferred to {self.device}")
        except RuntimeError as e:
            logger.error(f"Error transferring model to {self.device}: {e}")
            print(f"Error transferring model to {self.device}, falling back to CPU")
            self.device = torch.device("cpu")
            # No need to move to CPU since it's already there

        self.classification_model.eval()

        # Preprocessing applied to each detected insect crop before
        # classification; resize then center-crop yields a fixed 640x640 input.
        self.classification_transform = transforms.Compose([
            transforms.Resize((768, 768)),  # Fixed size for all validation images
            transforms.CenterCrop(640),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print("Model successfully loaded")
        print(f"Using species: {', '.join(species_names)}")

        self.species_to_family = species_to_family
        self.family_to_order = family_to_order
301
+
302
    def get_frames(self, test_dir):
        """
        Run detection + classification over every image in test_dir and collect
        predictions and ground truth at species, family, and order level.

        Args:
            test_dir (str): Directory containing images/ and labels/ (YOLO-format
                .txt labels named after their .jpg images).

        Returns:
            tuple: (predicted_frames, true_species_frames, elapsed_seconds,
                predicted_family_frames, predicted_order_frames,
                true_family_frames, true_order_frames) — each *_frames entry is
                a per-image list of [class_idx, x_center, y_center, w, h] boxes
                in normalized YOLO coordinates.

        Side effects:
            Writes a timestamped CSV of all per-image detections to
            self.output_dir.
        """
        image_dir = os.path.join(test_dir, "images")
        label_dir = os.path.join(test_dir, "labels")

        predicted_frames = []
        predicted_family_frames = []
        predicted_order_frames = []
        true_species_frames = []
        true_family_frames = []
        true_order_frames = []
        image_names = []

        start_time = time.time()  # Start timing

        for image_name in tqdm(os.listdir(image_dir), desc="Processing Images", unit="image"):
            image_names.append(image_name)
            image_path = os.path.join(image_dir, image_name)
            # Label files share the image basename with a .txt extension.
            label_path = os.path.join(label_dir, image_name.replace('.jpg', '.txt'))

            frame = cv2.imread(image_path)
            # Suppress print statements from YOLO model
            with torch.no_grad():
                results = self.yolo_model(frame, conf=0.3, iou=0.5, verbose=False)

            detections = results[0].boxes
            predicted_frame = []
            predicted_family_frame = []
            predicted_order_frame = []

            if detections:
                for box in detections:
                    # Corner coordinates in pixels from the detector.
                    xyxy = box.xyxy.cpu().numpy().flatten()
                    x1, y1, x2, y2 = xyxy[:4]
                    width = x2 - x1
                    height = y2 - y1
                    x_center = x1 + width / 2
                    y_center = y1 + height / 2

                    # Crop the detection and classify it at all three levels.
                    insect_crop = frame[int(y1):int(y2), int(x1):int(x2)]
                    insect_crop_rgb = cv2.cvtColor(insect_crop, cv2.COLOR_BGR2RGB)
                    pil_img = Image.fromarray(insect_crop_rgb)
                    input_tensor = self.classification_transform(pil_img).unsqueeze(0).to(self.device)

                    with torch.no_grad():
                        outputs = self.classification_model(input_tensor)

                    # Get all taxonomic level predictions
                    order_output = outputs[0]  # First output is order (level 1)
                    family_output = outputs[1]  # Second output is family (level 2)
                    species_output = outputs[2]  # Third output is species (level 3)

                    # Get prediction indices
                    order_idx = order_output.argmax(dim=1).item()
                    family_idx = family_output.argmax(dim=1).item()
                    species_idx = species_output.argmax(dim=1).item()

                    # Convert pixel coordinates to normalized YOLO format.
                    img_height, img_width, _ = frame.shape
                    x_center_norm = x_center / img_width
                    y_center_norm = y_center / img_height
                    width_norm = width / img_width
                    height_norm = height / img_height

                    # Create box coordinates in YOLO format
                    box_coords = [x_center_norm, y_center_norm, width_norm, height_norm]

                    # Add predictions for each taxonomic level
                    predicted_frame.append([species_idx] + box_coords)
                    predicted_family_frame.append([family_idx] + box_coords)
                    predicted_order_frame.append([order_idx] + box_coords)

            predicted_frames.append(predicted_frame if predicted_frame else [])
            predicted_family_frames.append(predicted_family_frame if predicted_family_frame else [])
            predicted_order_frames.append(predicted_order_frame if predicted_order_frame else [])

            true_species_frame = []
            true_family_frame = []
            true_order_frame = []

            # Parse ground truth; family and order labels are derived from the
            # species index via the taxonomy mappings built at init time.
            if os.path.exists(label_path) and os.path.getsize(label_path) > 0:
                with open(label_path, 'r') as f:
                    for line in f:
                        label_line = line.strip().split()
                        species_idx = int(label_line[0])
                        box_coords = list(map(np.float32, label_line[1:]))

                        true_species_frame.append([species_idx] + box_coords)

                        if species_idx < len(self.species_names):
                            species_name = self.species_names[species_idx]

                            if species_name in self.species_to_family:
                                family_name = self.species_to_family[species_name]
                                # Get the index of the family in the level_to_idx mapping
                                if 2 in self.level_to_idx and family_name in self.level_to_idx[2]:
                                    family_idx = self.level_to_idx[2][family_name]
                                    true_family_frame.append([family_idx] + box_coords)

                                if family_name in self.family_to_order:
                                    order_name = self.family_to_order[family_name]
                                    if 1 in self.level_to_idx and order_name in self.level_to_idx[1]:
                                        order_idx = self.level_to_idx[1][order_name]
                                        true_order_frame.append([order_idx] + box_coords)

            true_species_frames.append(true_species_frame if true_species_frame else [])
            true_family_frames.append(true_family_frame if true_family_frame else [])
            true_order_frames.append(true_order_frame if true_order_frame else [])

        end_time = time.time()  # End timing

        # Create a more descriptive filename with timestamp
        output_file = os.path.join(self.output_dir, f"results_hierarchical_{time.strftime('%Y%m%d_%H%M%S')}.csv")

        with open(output_file, "w", newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                "Image Name",
                "True Species Detections",
                "True Family Detections",
                "True Order Detections",
                "Species Detections",
                "Family Detections",
                "Order Detections"
            ])

            for image_name, true_species, true_family, true_order, species_pred, family_pred, order_pred in zip(
                image_names,
                true_species_frames,
                true_family_frames,
                true_order_frames,
                predicted_frames,
                predicted_family_frames,
                predicted_order_frames
            ):
                writer.writerow([
                    image_name,
                    true_species,
                    true_family,
                    true_order,
                    species_pred,
                    family_pred,
                    order_pred
                ])

        print(f"Results saved to {output_file}")
        return predicted_frames, true_species_frames, end_time - start_time, predicted_family_frames, predicted_order_frames, true_family_frames, true_order_frames
447
+
448
+ def run(self, test_dir):
449
+ results = self.get_frames(test_dir)
450
+ predicted_frames, true_species_frames, total_time = results[0], results[1], results[2]
451
+ predicted_family_frames = results[3]
452
+ predicted_order_frames = results[4]
453
+ true_family_frames = results[5]
454
+ true_order_frames = results[6]
455
+
456
+ num_frames = len(os.listdir(os.path.join(test_dir, 'images')))
457
+ avg_time_per_frame = total_time / num_frames
458
+
459
+ print(f"\nTotal time: {total_time:.2f} seconds")
460
+ print(f"Average time per frame: {avg_time_per_frame:.4f} seconds")
461
+
462
+ self.calculate_metrics(
463
+ predicted_frames, true_species_frames,
464
+ predicted_family_frames, true_family_frames,
465
+ predicted_order_frames, true_order_frames
466
+ )
467
+
468
+ def calculate_metrics(self, predicted_species_frames, true_species_frames,
469
+ predicted_family_frames, true_family_frames,
470
+ predicted_order_frames, true_order_frames):
471
+ """Calculate metrics at all taxonomic levels"""
472
+ # Get list of species, families and orders
473
+ species_list = self.species_names
474
+ family_list = sorted(list(set(self.species_to_family.values())))
475
+ order_list = sorted(list(set(self.family_to_order.values())))
476
+
477
+ # Print the index mappings we're using for evaluation
478
+ print("\nUsing the following index mappings for evaluation:")
479
+ print("\nOrder indices:")
480
+ for i, order in enumerate(order_list):
481
+ print(f" {i}: {order}")
482
+
483
+ print("\nFamily indices:")
484
+ for i, family in enumerate(family_list):
485
+ print(f" {i}: {family}")
486
+
487
+ print("\nSpecies indices:")
488
+ for i, species in enumerate(species_list):
489
+ print(f" {i}: {species}")
490
+
491
+ # Dictionary to track prediction category counts for debugging
492
+ prediction_counts = {
493
+ "true_species_boxes": sum(len(frame) for frame in true_species_frames),
494
+ "true_family_boxes": sum(len(frame) for frame in true_family_frames),
495
+ "true_order_boxes": sum(len(frame) for frame in true_order_frames),
496
+ "predicted_species": sum(len(frame) for frame in predicted_species_frames),
497
+ "predicted_family": sum(len(frame) for frame in predicted_family_frames),
498
+ "predicted_order": sum(len(frame) for frame in predicted_order_frames)
499
+ }
500
+
501
+ print(f"Prediction counts: {prediction_counts}")
502
+
503
+ # Calculate metrics for all three levels
504
+ print("\n=== Species-level Metrics ===")
505
+ self.get_metrics(predicted_species_frames, true_species_frames, species_list)
506
+
507
+ print("\n=== Family-level Metrics ===")
508
+ self.get_metrics(predicted_family_frames, true_family_frames, family_list)
509
+
510
+ print("\n=== Order-level Metrics ===")
511
+ self.get_metrics(predicted_order_frames, true_order_frames, order_list)
512
+
513
    def get_metrics(self, predicted_frames, true_frames, labels):
        """
        Compute and print per-label precision/recall/F1 for detection results.

        Args:
            predicted_frames (list): Per-image lists of [label_idx, x, y, w, h]
                predicted boxes (normalized YOLO coordinates).
            true_frames (list): Per-image lists of ground-truth boxes, same format.
            labels (list): Label names indexed by label_idx.

        Prints a per-label table (plus a "Background" row for empty frames and a
        total row), followed by class-agnostic "Generic Total" metrics that only
        measure localization, ignoring predicted labels.
        """
        def calculate_iou(box1, box2):
            # Boxes are [label, x_center, y_center, w, h]; convert to corners.
            x1_min, y1_min = box1[1] - box1[3] / 2, box1[2] - box1[4] / 2
            x1_max, y1_max = box1[1] + box1[3] / 2, box1[2] + box1[4] / 2
            x2_min, y2_min = box2[1] - box2[3] / 2, box2[2] - box2[4] / 2
            x2_max, y2_max = box2[1] + box2[3] / 2, box2[2] + box2[4] / 2

            inter_x_min = max(x1_min, x2_min)
            inter_y_min = max(y1_min, y2_min)
            inter_x_max = min(x1_max, x2_max)
            inter_y_max = min(y1_max, y2_max)

            inter_area = max(0, inter_x_max - inter_x_min) * max(0, inter_y_max - inter_y_min)
            box1_area = (x1_max - x1_min) * (y1_max - y1_min)
            box2_area = (x2_max - x2_min) * (y2_max - y2_min)

            # NOTE(review): divides by the union area; two zero-area boxes
            # would raise ZeroDivisionError — assumed not to occur in practice.
            iou = inter_area / (box1_area + box2_area - inter_area)
            return iou

        def calculate_precision_recall(pred_boxes, true_boxes, iou_threshold=0.5):
            # Greedy one-to-one matching: each prediction takes the
            # highest-IoU unmatched ground-truth box above the threshold.
            label_results = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
            generic_tp = 0
            generic_fp = 0

            matched_true_boxes = set()

            for pred_box in pred_boxes:
                label_idx = pred_box[0]
                matched = False

                best_iou = 0
                best_match_idx = -1

                for i, true_box in enumerate(true_boxes):
                    if i in matched_true_boxes:
                        continue

                    iou = calculate_iou(pred_box, true_box)
                    if iou >= iou_threshold and iou > best_iou:
                        best_iou = iou
                        best_match_idx = i

                if best_match_idx >= 0:
                    matched = True
                    true_box = true_boxes[best_match_idx]
                    matched_true_boxes.add(best_match_idx)
                    # Generic TP: box localized regardless of label.
                    generic_tp += 1

                    if pred_box[0] == true_box[0]:
                        label_results[label_idx]['tp'] += 1
                    else:
                        # Localized but misclassified: FP for the predicted
                        # label AND FN for the true label.
                        label_results[label_idx]['fp'] += 1
                        true_label_idx = true_box[0]
                        label_results[true_label_idx]['fn'] += 1

                if not matched:
                    label_results[label_idx]['fp'] += 1
                    generic_fp += 1

            # Any unmatched ground-truth box is a false negative.
            for i, true_box in enumerate(true_boxes):
                if i not in matched_true_boxes:
                    label_idx = true_box[0]
                    label_results[label_idx]['fn'] += 1

            generic_fn = len(true_boxes) - len(matched_true_boxes)

            return label_results, generic_tp, generic_fp, generic_fn

        label_metrics = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0, 'support': 0})
        background_metrics = {'tp': 0, 'fp': 0, 'fn': 0, 'support': 0}
        generic_metrics = {'tp': 0, 'fp': 0, 'fn': 0}

        # Support = number of ground-truth boxes per label; empty frames count
        # toward the synthetic "Background" class.
        for true_frame in true_frames:
            if not true_frame:  # Empty frame (background only)
                background_metrics['support'] += 1
            else:
                for true_box in true_frame:
                    label_idx = true_box[0]
                    label_metrics[label_idx]['support'] += 1  # Count each detection, not just unique labels

        # Accumulate per-frame matching results.
        for pred_frame, true_frame in zip(predicted_frames, true_frames):
            if not pred_frame and not true_frame:
                background_metrics['tp'] += 1
            elif not pred_frame:
                background_metrics['fn'] += 1
            elif not true_frame:
                background_metrics['fp'] += 1
            else:
                frame_results, g_tp, g_fp, g_fn = calculate_precision_recall(pred_frame, true_frame)

                for label_idx, metrics in frame_results.items():
                    label_metrics[label_idx]['tp'] += metrics['tp']
                    label_metrics[label_idx]['fp'] += metrics['fp']
                    label_metrics[label_idx]['fn'] += metrics['fn']

                generic_metrics['tp'] += g_tp
                generic_metrics['fp'] += g_fp
                generic_metrics['fn'] += g_fn

        table_data = []

        for label_idx, metrics in label_metrics.items():
            tp = metrics['tp']
            fp = metrics['fp']
            fn = metrics['fn']
            support = metrics['support']
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            label_name = labels[label_idx] if label_idx < len(labels) else f"Label {label_idx}"
            table_data.append([label_name, f"{precision:.2f}", f"{recall:.2f}", f"{f1_score:.2f}", f"{support}"])

        # Background row: frames with no ground truth at all.
        tp = background_metrics['tp']
        fp = background_metrics['fp']
        fn = background_metrics['fn']
        support = background_metrics['support']
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        table_data.append(["Background", f"{precision:.2f}", f"{recall:.2f}", f"{f1_score:.2f}", f"{support}"])

        # Micro-averaged totals over all real labels (background excluded).
        headers = ["Label", "Precision", "Recall", "F1 Score", "Support"]
        total_tp = sum(metrics['tp'] for metrics in label_metrics.values())
        total_fp = sum(metrics['fp'] for metrics in label_metrics.values())
        total_fn = sum(metrics['fn'] for metrics in label_metrics.values())
        total_support = sum(metrics['support'] for metrics in label_metrics.values())

        total_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        total_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
        total_f1_score = 2 * (total_precision * total_recall) / (total_precision + total_recall) if (total_precision + total_recall) > 0 else 0

        table_data.append(["\nTotal (excluding background)", f"{total_precision:.2f}", f"{total_recall:.2f}", f"{total_f1_score:.2f}", f"{total_support}"])
        print(tabulate(table_data, headers=headers, tablefmt="grid"))

        # Class-agnostic (localization-only) totals.
        generic_tp = generic_metrics['tp']
        generic_fp = generic_metrics['fp']
        generic_fn = generic_metrics['fn']

        generic_precision = generic_tp / (generic_tp + generic_fp) if (generic_tp + generic_fp) > 0 else 0
        generic_recall = generic_tp / (generic_tp + generic_fn) if (generic_tp + generic_fn) > 0 else 0
        generic_f1_score = 2 * (generic_precision * generic_recall) / (generic_precision + generic_recall) if (generic_precision + generic_recall) > 0 else 0

        print("\nGeneric Total", f"{generic_precision:.2f}", f"{generic_recall:.2f}", f"{generic_f1_score:.2f}")
657
+
658
if __name__ == "__main__":
    # Example run: nine pollinator species evaluated on a local test set.
    target_species = [
        "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
        "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
        "Eristalis tenax"
    ]

    # Local paths to the test data and the two model weight files.
    dataset_dir = "/mnt/nvme0n1p1/mit/two-stage-detection/bjerge-test"
    detector_weights = "/mnt/nvme0n1p1/mit/two-stage-detection/small-generic.pt"
    classifier_weights = "/mnt/nvme0n1p1/mit/two-stage-detection/hierarchical/hierarchical-weights.pth"
    results_dir = "./output"

    test_multitask(target_species, dataset_dir, detector_weights, classifier_weights, results_dir)