bplusplus 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bplusplus might be problematic.

Files changed (97)
  1. bplusplus/__init__.py +4 -2
  2. bplusplus/collect.py +69 -5
  3. bplusplus/hierarchical/test.py +670 -0
  4. bplusplus/hierarchical/train.py +676 -0
  5. bplusplus/prepare.py +228 -64
  6. bplusplus/resnet/test.py +473 -0
  7. bplusplus/resnet/train.py +329 -0
  8. bplusplus-1.2.0.dist-info/METADATA +249 -0
  9. bplusplus-1.2.0.dist-info/RECORD +12 -0
  10. bplusplus/yolov5detect/__init__.py +0 -1
  11. bplusplus/yolov5detect/detect.py +0 -444
  12. bplusplus/yolov5detect/export.py +0 -1530
  13. bplusplus/yolov5detect/insect.yaml +0 -8
  14. bplusplus/yolov5detect/models/__init__.py +0 -0
  15. bplusplus/yolov5detect/models/common.py +0 -1109
  16. bplusplus/yolov5detect/models/experimental.py +0 -130
  17. bplusplus/yolov5detect/models/hub/anchors.yaml +0 -56
  18. bplusplus/yolov5detect/models/hub/yolov3-spp.yaml +0 -52
  19. bplusplus/yolov5detect/models/hub/yolov3-tiny.yaml +0 -42
  20. bplusplus/yolov5detect/models/hub/yolov3.yaml +0 -52
  21. bplusplus/yolov5detect/models/hub/yolov5-bifpn.yaml +0 -49
  22. bplusplus/yolov5detect/models/hub/yolov5-fpn.yaml +0 -43
  23. bplusplus/yolov5detect/models/hub/yolov5-p2.yaml +0 -55
  24. bplusplus/yolov5detect/models/hub/yolov5-p34.yaml +0 -42
  25. bplusplus/yolov5detect/models/hub/yolov5-p6.yaml +0 -57
  26. bplusplus/yolov5detect/models/hub/yolov5-p7.yaml +0 -68
  27. bplusplus/yolov5detect/models/hub/yolov5-panet.yaml +0 -49
  28. bplusplus/yolov5detect/models/hub/yolov5l6.yaml +0 -61
  29. bplusplus/yolov5detect/models/hub/yolov5m6.yaml +0 -61
  30. bplusplus/yolov5detect/models/hub/yolov5n6.yaml +0 -61
  31. bplusplus/yolov5detect/models/hub/yolov5s-LeakyReLU.yaml +0 -50
  32. bplusplus/yolov5detect/models/hub/yolov5s-ghost.yaml +0 -49
  33. bplusplus/yolov5detect/models/hub/yolov5s-transformer.yaml +0 -49
  34. bplusplus/yolov5detect/models/hub/yolov5s6.yaml +0 -61
  35. bplusplus/yolov5detect/models/hub/yolov5x6.yaml +0 -61
  36. bplusplus/yolov5detect/models/segment/yolov5l-seg.yaml +0 -49
  37. bplusplus/yolov5detect/models/segment/yolov5m-seg.yaml +0 -49
  38. bplusplus/yolov5detect/models/segment/yolov5n-seg.yaml +0 -49
  39. bplusplus/yolov5detect/models/segment/yolov5s-seg.yaml +0 -49
  40. bplusplus/yolov5detect/models/segment/yolov5x-seg.yaml +0 -49
  41. bplusplus/yolov5detect/models/tf.py +0 -797
  42. bplusplus/yolov5detect/models/yolo.py +0 -495
  43. bplusplus/yolov5detect/models/yolov5l.yaml +0 -49
  44. bplusplus/yolov5detect/models/yolov5m.yaml +0 -49
  45. bplusplus/yolov5detect/models/yolov5n.yaml +0 -49
  46. bplusplus/yolov5detect/models/yolov5s.yaml +0 -49
  47. bplusplus/yolov5detect/models/yolov5x.yaml +0 -49
  48. bplusplus/yolov5detect/utils/__init__.py +0 -97
  49. bplusplus/yolov5detect/utils/activations.py +0 -134
  50. bplusplus/yolov5detect/utils/augmentations.py +0 -448
  51. bplusplus/yolov5detect/utils/autoanchor.py +0 -175
  52. bplusplus/yolov5detect/utils/autobatch.py +0 -70
  53. bplusplus/yolov5detect/utils/aws/__init__.py +0 -0
  54. bplusplus/yolov5detect/utils/aws/mime.sh +0 -26
  55. bplusplus/yolov5detect/utils/aws/resume.py +0 -41
  56. bplusplus/yolov5detect/utils/aws/userdata.sh +0 -27
  57. bplusplus/yolov5detect/utils/callbacks.py +0 -72
  58. bplusplus/yolov5detect/utils/dataloaders.py +0 -1385
  59. bplusplus/yolov5detect/utils/docker/Dockerfile +0 -73
  60. bplusplus/yolov5detect/utils/docker/Dockerfile-arm64 +0 -40
  61. bplusplus/yolov5detect/utils/docker/Dockerfile-cpu +0 -42
  62. bplusplus/yolov5detect/utils/downloads.py +0 -136
  63. bplusplus/yolov5detect/utils/flask_rest_api/README.md +0 -70
  64. bplusplus/yolov5detect/utils/flask_rest_api/example_request.py +0 -17
  65. bplusplus/yolov5detect/utils/flask_rest_api/restapi.py +0 -49
  66. bplusplus/yolov5detect/utils/general.py +0 -1294
  67. bplusplus/yolov5detect/utils/google_app_engine/Dockerfile +0 -25
  68. bplusplus/yolov5detect/utils/google_app_engine/additional_requirements.txt +0 -6
  69. bplusplus/yolov5detect/utils/google_app_engine/app.yaml +0 -16
  70. bplusplus/yolov5detect/utils/loggers/__init__.py +0 -476
  71. bplusplus/yolov5detect/utils/loggers/clearml/README.md +0 -222
  72. bplusplus/yolov5detect/utils/loggers/clearml/__init__.py +0 -0
  73. bplusplus/yolov5detect/utils/loggers/clearml/clearml_utils.py +0 -230
  74. bplusplus/yolov5detect/utils/loggers/clearml/hpo.py +0 -90
  75. bplusplus/yolov5detect/utils/loggers/comet/README.md +0 -250
  76. bplusplus/yolov5detect/utils/loggers/comet/__init__.py +0 -551
  77. bplusplus/yolov5detect/utils/loggers/comet/comet_utils.py +0 -151
  78. bplusplus/yolov5detect/utils/loggers/comet/hpo.py +0 -126
  79. bplusplus/yolov5detect/utils/loggers/comet/optimizer_config.json +0 -135
  80. bplusplus/yolov5detect/utils/loggers/wandb/__init__.py +0 -0
  81. bplusplus/yolov5detect/utils/loggers/wandb/wandb_utils.py +0 -210
  82. bplusplus/yolov5detect/utils/loss.py +0 -259
  83. bplusplus/yolov5detect/utils/metrics.py +0 -381
  84. bplusplus/yolov5detect/utils/plots.py +0 -517
  85. bplusplus/yolov5detect/utils/segment/__init__.py +0 -0
  86. bplusplus/yolov5detect/utils/segment/augmentations.py +0 -100
  87. bplusplus/yolov5detect/utils/segment/dataloaders.py +0 -366
  88. bplusplus/yolov5detect/utils/segment/general.py +0 -160
  89. bplusplus/yolov5detect/utils/segment/loss.py +0 -198
  90. bplusplus/yolov5detect/utils/segment/metrics.py +0 -225
  91. bplusplus/yolov5detect/utils/segment/plots.py +0 -152
  92. bplusplus/yolov5detect/utils/torch_utils.py +0 -482
  93. bplusplus/yolov5detect/utils/triton.py +0 -90
  94. bplusplus-1.1.0.dist-info/METADATA +0 -179
  95. bplusplus-1.1.0.dist-info/RECORD +0 -92
  96. {bplusplus-1.1.0.dist-info → bplusplus-1.2.0.dist-info}/LICENSE +0 -0
  97. {bplusplus-1.1.0.dist-info → bplusplus-1.2.0.dist-info}/WHEEL +0 -0
bplusplus/resnet/test.py (new file)
@@ -0,0 +1,473 @@
+ # pip install ultralytics torchvision pillow numpy scikit-learn tabulate tqdm
+ # python3 tests/two-stage(yolo-resnet).py --data ' --yolo_weights --resnet_weights --use_resnet50
+
+ import os
+ import cv2
+ import torch
+ from ultralytics import YOLO
+ from torchvision import transforms
+ from PIL import Image
+ import numpy as np
+ from torchvision.models import resnet152, resnet50
+ import torch.nn as nn
+ from sklearn.metrics import classification_report, accuracy_score
+ import time
+ from collections import defaultdict
+ from tabulate import tabulate
+ from tqdm import tqdm
+ import csv
+ import requests
+ import sys
+
+ def test_resnet(data_path, yolo_weights, resnet_weights, model="resnet152", species_names=None, output_dir="output"):
+     """
+     Run the two-stage detection and classification test
+
+     Args:
+         data_path (str): Path to the test directory
+         yolo_weights (str): Path to the YOLO model file
+         resnet_weights (str): Path to the ResNet model file
+         model (str): Model type, either "resnet50" or "resnet152"
+         species_names (list): List of species names
+         output_dir (str): Directory to save output CSV files
+     """
+     use_resnet50 = model == "resnet50"
+     classifier = TestTwoStage(yolo_weights, resnet_weights, use_resnet50=use_resnet50,
+                               species_names=species_names, output_dir=output_dir)
+     classifier.run(data_path)
+
+ class TestTwoStage:
+     def __init__(self, yolo_model_path, resnet_model_path, num_classes=9, use_resnet50=False, species_names="", output_dir="output"):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"Using device: {self.device}")
+
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+
+         self.yolo_model = YOLO(yolo_model_path)
+         self.classification_model = resnet50(pretrained=False) if use_resnet50 else resnet152(pretrained=False)
+
+         self.classification_model.fc = nn.Sequential(
+             nn.Dropout(0.4),  # Using dropout probability of 0.4 as in training
+             nn.Linear(self.classification_model.fc.in_features, num_classes)
+         )
+
+         state_dict = torch.load(resnet_model_path, map_location=self.device)
+         self.classification_model.load_state_dict(state_dict)
+         self.classification_model.to(self.device)
+         self.classification_model.eval()
+
+         self.classification_transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+         self.species_names = species_names
+
+     def get_frames(self, test_dir):
+         image_dir = os.path.join(test_dir, "images")
+         label_dir = os.path.join(test_dir, "labels")
+
+         predicted_frames = []
+         true_frames = []
+         image_names = []
+
+         start_time = time.time()  # Start timing
+
+         for image_name in tqdm(os.listdir(image_dir), desc="Processing Images", unit="image"):
+             image_names.append(image_name)
+             image_path = os.path.join(image_dir, image_name)
+             label_path = os.path.join(label_dir, image_name.replace('.jpg', '.txt'))
+
+             frame = cv2.imread(image_path)
+             # Suppress print statements from YOLO model
+             with torch.no_grad():
+                 results = self.yolo_model(frame, conf=0.3, iou=0.5, verbose=False)
+
+             detections = results[0].boxes
+             predicted_frame = []
+
+             if detections:
+                 for box in detections:
+                     xyxy = box.xyxy.cpu().numpy().flatten()
+                     x1, y1, x2, y2 = xyxy[:4]
+                     width = x2 - x1
+                     height = y2 - y1
+                     x_center = x1 + width / 2
+                     y_center = y1 + height / 2
+
+                     insect_crop = frame[int(y1):int(y2), int(x1):int(x2)]
+                     insect_crop_rgb = cv2.cvtColor(insect_crop, cv2.COLOR_BGR2RGB)
+                     pil_img = Image.fromarray(insect_crop_rgb)
+                     input_tensor = self.classification_transform(pil_img).unsqueeze(0).to(self.device)
+
+                     with torch.no_grad():
+                         outputs = self.classification_model(input_tensor)
+
+                     # Directly use the model output without any mapping
+                     predicted_class_idx = outputs.argmax(dim=1).item()
+
+                     img_height, img_width, _ = frame.shape
+                     x_center_norm = x_center / img_width
+                     y_center_norm = y_center / img_height
+                     width_norm = width / img_width
+                     height_norm = height / img_height
+                     predicted_frame.append([predicted_class_idx, x_center_norm, y_center_norm, width_norm, height_norm])
+
+             predicted_frames.append(predicted_frame if predicted_frame else [])
+
+             true_frame = []
+             if os.path.exists(label_path) and os.path.getsize(label_path) > 0:
+                 with open(label_path, 'r') as f:
+                     for line in f:
+                         label_line = line.strip().split()
+                         true_frame.append([int(label_line[0]), *map(np.float32, label_line[1:])])
+
+             true_frames.append(true_frame if true_frame else [])
+
+         end_time = time.time()  # End timing
+
+         model_type = "resnet50" if isinstance(self.classification_model, type(resnet50())) else "resnet152"
+         output_file = os.path.join(self.output_dir, f"results_{model_type}_{time.strftime('%Y%m%d_%H%M%S')}.csv")
+
+         with open(output_file, "w") as f:
+             writer = csv.writer(f)
+             writer.writerow(["Image Name", "True", "Predicted"])
+             for image_name, true_frame, predicted_frame in zip(image_names, true_frames, predicted_frames):
+                 writer.writerow([image_name, true_frame, predicted_frame])
+
+         print(f"Results saved to {output_file}")
+         return predicted_frames, true_frames, end_time - start_time
+
+     def get_taxonomic_info(self, species_list):
+         """
+         Retrieves taxonomic information for a list of species from GBIF API.
+         Creates a hierarchical taxonomy dictionary with order, family, and species relationships.
+         """
+         taxonomy = {1: [], 2: {}, 3: {}}
+         species_to_family = {}
+         family_to_order = {}
+
+         print(f"Building taxonomy from GBIF for {len(species_list)} species")
+
+         print("\nTaxonomy Results:")
+         print("-" * 80)
+         print(f"{'Species':<30} {'Order':<20} {'Family':<20} {'Status'}")
+         print("-" * 80)
+
+         for species_name in species_list:
+             url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
+             try:
+                 response = requests.get(url)
+                 data = response.json()
+
+                 if data.get('status') == 'ACCEPTED' or data.get('status') == 'SYNONYM':
+                     family = data.get('family')
+                     order = data.get('order')
+
+                     if family and order:
+                         status = "OK"
+
+                         print(f"{species_name:<30} {order:<20} {family:<20} {status}")
+
+                         species_to_family[species_name] = family
+                         family_to_order[family] = order
+
+                         if order not in taxonomy[1]:
+                             taxonomy[1].append(order)
+
+                         taxonomy[2][family] = order
+                         taxonomy[3][species_name] = family
+                     else:
+                         error_msg = f"Species '{species_name}' found in GBIF but family and order not found, could be spelling error in species, check GBIF"
+                         print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                         print(f"Error: {error_msg}")
+                         sys.exit(1)  # Stop the script
+                 else:
+                     error_msg = f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF"
+                     print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                     print(f"Error: {error_msg}")
+                     sys.exit(1)  # Stop the script
+
+             except Exception as e:
+                 error_msg = f"Error retrieving data for species '{species_name}': {str(e)}"
+                 print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
+                 print(f"Error: {error_msg}")
+                 sys.exit(1)  # Stop the script
+
+         taxonomy[1] = sorted(list(set(taxonomy[1])))
+         print("-" * 80)
+
+         num_orders = len(taxonomy[1])
+         num_families = len(taxonomy[2])
+         num_species = len(taxonomy[3])
+
+         print("\nOrder indices:")
+         for i, order in enumerate(taxonomy[1]):
+             print(f" {i}: {order}")
+
+         print("\nFamily indices:")
+         for i, family in enumerate(taxonomy[2].keys()):
+             print(f" {i}: {family}")
+
+         print("\nSpecies indices:")
+         for i, species in enumerate(species_list):
+             print(f" {i}: {species}")
+
+         print(f"\nTaxonomy built: {num_orders} orders, {num_families} families, {num_species} species")
+
+         return taxonomy, species_to_family, family_to_order
+
+     def get_metrics(self, predicted_frames, true_frames, labels):
+         """
+         Calculate precision, recall, and F1 score for object detection results.
+         """
+         def calculate_iou(box1, box2):
+             x1_min, y1_min = box1[1] - box1[3] / 2, box1[2] - box1[4] / 2
+             x1_max, y1_max = box1[1] + box1[3] / 2, box1[2] + box1[4] / 2
+             x2_min, y2_min = box2[1] - box2[3] / 2, box2[2] - box2[4] / 2
+             x2_max, y2_max = box2[1] + box2[3] / 2, box2[2] + box2[4] / 2
+
+             inter_x_min = max(x1_min, x2_min)
+             inter_y_min = max(y1_min, y2_min)
+             inter_x_max = min(x1_max, x2_max)
+             inter_y_max = min(y1_max, y2_max)
+
+             inter_area = max(0, inter_x_max - inter_x_min) * max(0, inter_y_max - inter_y_min)
+             box1_area = (x1_max - x1_min) * (y1_max - y1_min)
+             box2_area = (x2_max - x2_min) * (y2_max - y2_min)
+
+             iou = inter_area / (box1_area + box2_area - inter_area)
+             return iou
+
+         def calculate_precision_recall(pred_boxes, true_boxes, iou_threshold=0.5):
+             label_results = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
+             generic_tp = 0
+             generic_fp = 0
+
+             matched_true_boxes = set()
+
+             for pred_box in pred_boxes:
+                 label_idx = pred_box[0]
+                 matched = False
+
+                 best_iou = 0
+                 best_match_idx = -1
+
+                 for i, true_box in enumerate(true_boxes):
+                     if i in matched_true_boxes:
+                         continue
+
+                     iou = calculate_iou(pred_box, true_box)
+                     if iou >= iou_threshold and iou > best_iou:
+                         best_iou = iou
+                         best_match_idx = i
+
+                 if best_match_idx >= 0:
+                     matched = True
+                     true_box = true_boxes[best_match_idx]
+                     matched_true_boxes.add(best_match_idx)
+                     generic_tp += 1
+
+                     if pred_box[0] == true_box[0]:
+                         label_results[label_idx]['tp'] += 1
+                     else:
+                         label_results[label_idx]['fp'] += 1
+                         true_label_idx = true_box[0]
+                         label_results[true_label_idx]['fn'] += 1
+
+                 if not matched:
+                     label_results[label_idx]['fp'] += 1
+                     generic_fp += 1
+
+             for i, true_box in enumerate(true_boxes):
+                 if i not in matched_true_boxes:
+                     label_idx = true_box[0]
+                     label_results[label_idx]['fn'] += 1
+
+             generic_fn = len(true_boxes) - len(matched_true_boxes)
+
+             return label_results, generic_tp, generic_fp, generic_fn
+
+         label_metrics = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0, 'support': 0})
+         background_metrics = {'tp': 0, 'fp': 0, 'fn': 0, 'support': 0}
+         generic_metrics = {'tp': 0, 'fp': 0, 'fn': 0}
+
+         for true_frame in true_frames:
+             if not true_frame:  # Empty frame (background only)
+                 background_metrics['support'] += 1
+             else:
+                 for true_box in true_frame:
+                     label_idx = true_box[0]
+                     label_metrics[label_idx]['support'] += 1  # Count each detection, not just unique labels
+
+         for pred_frame, true_frame in zip(predicted_frames, true_frames):
+             if not pred_frame and not true_frame:
+                 background_metrics['tp'] += 1
+             elif not pred_frame:
+                 background_metrics['fn'] += 1
+             elif not true_frame:
+                 background_metrics['fp'] += 1
+             else:
+                 frame_results, g_tp, g_fp, g_fn = calculate_precision_recall(pred_frame, true_frame)
+
+                 for label_idx, metrics in frame_results.items():
+                     label_metrics[label_idx]['tp'] += metrics['tp']
+                     label_metrics[label_idx]['fp'] += metrics['fp']
+                     label_metrics[label_idx]['fn'] += metrics['fn']
+
+                 generic_metrics['tp'] += g_tp
+                 generic_metrics['fp'] += g_fp
+                 generic_metrics['fn'] += g_fn
+
+         table_data = []
+         # Store individual class metrics for macro-averaging
+         class_precisions = []
+         class_recalls = []
+         class_f1s = []
+
+         for label_idx, metrics in label_metrics.items():
+             tp = metrics['tp']
+             fp = metrics['fp']
+             fn = metrics['fn']
+             support = metrics['support']
+             precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+             recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+             f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+             # Store for macro-averaging
+             class_precisions.append(precision)
+             class_recalls.append(recall)
+             class_f1s.append(f1_score)
+
+             label_name = labels[label_idx] if label_idx < len(labels) else f"Label {label_idx}"
+             table_data.append([label_name, f"{precision:.2f}", f"{recall:.2f}", f"{f1_score:.2f}", f"{support}"])
+
+             print(f"Debug {label_name}: TP={tp}, FP={fp}, FN={fn}")
+             print(f" Raw P={tp/(tp+fp) if (tp+fp)>0 else 0}, R={tp/(tp+fn) if (tp+fn)>0 else 0}")
+
+         tp = background_metrics['tp']
+         fp = background_metrics['fp']
+         fn = background_metrics['fn']
+         support = background_metrics['support']
+         precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+         recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+         f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+         table_data.append(["Background", f"{precision:.2f}", f"{recall:.2f}", f"{f1_score:.2f}", f"{support}"])
+
+         headers = ["Label", "Precision", "Recall", "F1 Score", "Support"]
+         total_tp = sum(metrics['tp'] for metrics in label_metrics.values())
+         total_fp = sum(metrics['fp'] for metrics in label_metrics.values())
+         total_fn = sum(metrics['fn'] for metrics in label_metrics.values())
+         total_support = sum(metrics['support'] for metrics in label_metrics.values())
+
+         total_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+         total_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+         total_f1_score = 2 * (total_precision * total_recall) / (total_precision + total_recall) if (total_precision + total_recall) > 0 else 0
+
+         table_data.append(["\nTotal (micro-avg, excl. background)", f"{total_precision:.2f}", f"{total_recall:.2f}", f"{total_f1_score:.2f}", f"{total_support}"])
+
+         # Add macro-average
+         if class_precisions:
+             macro_precision = sum(class_precisions) / len(class_precisions)
+             macro_recall = sum(class_recalls) / len(class_recalls)
+             macro_f1 = sum(class_f1s) / len(class_f1s)
+             table_data.append(["Total (macro-avg, excl. background)", f"{macro_precision:.2f}", f"{macro_recall:.2f}", f"{macro_f1:.2f}", f"{total_support}"])
+
+         print(tabulate(table_data, headers=headers, tablefmt="grid"))
+
+         generic_tp = generic_metrics['tp']
+         generic_fp = generic_metrics['fp']
+         generic_fn = generic_metrics['fn']
+
+         generic_precision = generic_tp / (generic_tp + generic_fp) if (generic_tp + generic_fp) > 0 else 0
+         generic_recall = generic_tp / (generic_tp + generic_fn) if (generic_tp + generic_fn) > 0 else 0
+         generic_f1_score = 2 * (generic_precision * generic_recall) / (generic_precision + generic_recall) if (generic_precision + generic_recall) > 0 else 0
+
+         print("\nGeneric Total", f"{generic_precision:.2f}", f"{generic_recall:.2f}", f"{generic_f1_score:.2f}")
+
+         return total_precision, total_recall, total_f1_score
+
+     def run(self, test_dir):
+         predicted_frames, true_frames, total_time = self.get_frames(test_dir)
+         num_frames = len(os.listdir(os.path.join(test_dir, 'images')))
+         avg_time_per_frame = total_time / num_frames
+
+         print(f"\nTotal time: {total_time:.2f} seconds")
+         print(f"Average time per frame: {avg_time_per_frame:.4f} seconds")
+
+         # Get taxonomy information for hierarchical analysis
+         taxonomy, species_to_family, family_to_order = self.get_taxonomic_info(self.species_names)
+         family_list = list(family_to_order.keys())
+         order_list = list(taxonomy[1])
+
+         # Convert species-level predictions to family and order levels
+         true_family_frames = []
+         true_order_frames = []
+         predicted_family_frames = []
+         predicted_order_frames = []
+
+         for true_frame in true_frames:
+             frame_family_boxes = []
+             frame_order_boxes = []
+
+             if true_frame:
+                 for true_box in true_frame:
+                     species_idx = true_box[0]
+                     species_name = self.species_names[species_idx]
+                     family_name = species_to_family[species_name]
+                     order_name = family_to_order[family_name]
+
+                     family_label = [family_list.index(family_name)] + list(true_box[1:])
+                     order_label = [order_list.index(order_name)] + list(true_box[1:])
+
+                     frame_family_boxes.append(family_label)
+                     frame_order_boxes.append(order_label)
+
+             true_family_frames.append(frame_family_boxes)
+             true_order_frames.append(frame_order_boxes)
+
+         for pred_frame in predicted_frames:
+             frame_family_boxes = []
+             frame_order_boxes = []
+
+             if pred_frame:
+                 for pred_box in pred_frame:
+                     species_idx = pred_box[0]
+                     species_name = self.species_names[species_idx]
+                     family_name = species_to_family[species_name]
+                     order_name = family_to_order[family_name]
+
+                     family_label = [family_list.index(family_name)] + list(map(np.float32, pred_box[1:]))
+                     order_label = [order_list.index(order_name)] + list(map(np.float32, pred_box[1:]))
+
+                     frame_family_boxes.append(family_label)
+                     frame_order_boxes.append(order_label)
+
+             predicted_family_frames.append(frame_family_boxes)
+             predicted_order_frames.append(frame_order_boxes)
+
+         # Display metrics for all taxonomic levels
+         print("\nSpecies Level Metrics")
+         self.get_metrics(predicted_frames, true_frames, self.species_names)
+
+         print("\nFamily Level Metrics")
+         self.get_metrics(predicted_family_frames, true_family_frames, family_list)
+
+         print("\nOrder Level Metrics")
+         self.get_metrics(predicted_order_frames, true_order_frames, order_list)
+
+ if __name__ == "__main__":
+     species_names = [
+         "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
+         "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
+         "Eristalis tenax"
+     ]
+
+     test_resnet(
+         data_path="/mnt/nvme0n1p1/mit/two-stage-detection/bjerge-test",
+         yolo_weights="/mnt/nvme0n1p1/mit/two-stage-detection/small-generic.pt",
+         resnet_weights="/mnt/nvme0n1p1/mit/two-stage-detection/output/best_resnet50.pt",
+         model="resnet50",
+         species_names=species_names,
+         output_dir="/mnt/nvme0n1p1/mit/two-stage-detection/output"
+     )
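
For context, the new test script resolves each species name through the GBIF species-match endpoint (https://api.gbif.org/v1/species/match) and keeps the 'order' and 'family' fields when the match status is ACCEPTED or SYNONYM. The following is a minimal standalone sketch of that lookup, separate from the package code above; the species name is only an example.

import requests

def lookup_order_and_family(species_name):
    # Same GBIF endpoint and fields as TestTwoStage.get_taxonomic_info above.
    url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
    data = requests.get(url, timeout=10).json()
    if data.get("status") not in ("ACCEPTED", "SYNONYM"):
        raise ValueError(f"No accepted GBIF match for {species_name!r}")
    return data.get("order"), data.get("family")

# Example: Apis mellifera belongs to order Hymenoptera, family Apidae.
print(lookup_order_and_family("Apis mellifera"))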