bplusplus 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bplusplus might be problematic. Click here for more details.
- bplusplus/__init__.py +3 -5
- bplusplus/collect.py +2 -0
- bplusplus/inference.py +891 -0
- bplusplus/prepare.py +429 -540
- bplusplus/{hierarchical/test.py → test.py} +99 -88
- bplusplus/tracker.py +261 -0
- bplusplus/{hierarchical/train.py → train.py} +29 -29
- bplusplus-1.2.3.dist-info/METADATA +101 -0
- bplusplus-1.2.3.dist-info/RECORD +11 -0
- {bplusplus-1.2.1.dist-info → bplusplus-1.2.3.dist-info}/WHEEL +1 -1
- bplusplus/resnet/test.py +0 -473
- bplusplus/resnet/train.py +0 -329
- bplusplus/train_validate.py +0 -11
- bplusplus-1.2.1.dist-info/METADATA +0 -252
- bplusplus-1.2.1.dist-info/RECORD +0 -12
- {bplusplus-1.2.1.dist-info → bplusplus-1.2.3.dist-info}/LICENSE +0 -0
bplusplus/resnet/test.py
DELETED
|
@@ -1,473 +0,0 @@
|
|
|
1
|
-
# pip install ultralytics torchvision pillow numpy scikit-learn tabulate tqdm
|
|
2
|
-
# python3 tests/two-stage(yolo-resnet).py --data ' --yolo_weights --resnet_weights --use_resnet50
|
|
3
|
-
|
|
4
|
-
import os
|
|
5
|
-
import cv2
|
|
6
|
-
import torch
|
|
7
|
-
from ultralytics import YOLO
|
|
8
|
-
from torchvision import transforms
|
|
9
|
-
from PIL import Image
|
|
10
|
-
import numpy as np
|
|
11
|
-
from torchvision.models import resnet152, resnet50
|
|
12
|
-
import torch.nn as nn
|
|
13
|
-
from sklearn.metrics import classification_report, accuracy_score
|
|
14
|
-
import time
|
|
15
|
-
from collections import defaultdict
|
|
16
|
-
from tabulate import tabulate
|
|
17
|
-
from tqdm import tqdm
|
|
18
|
-
import csv
|
|
19
|
-
import requests
|
|
20
|
-
import sys
|
|
21
|
-
|
|
22
|
-
def test_resnet(data_path, yolo_weights, resnet_weights, model="resnet152", species_names=None, output_dir="output"):
|
|
23
|
-
"""
|
|
24
|
-
Run the two-stage detection and classification test
|
|
25
|
-
|
|
26
|
-
Args:
|
|
27
|
-
data_path (str): Path to the test directory
|
|
28
|
-
yolo_weights (str): Path to the YOLO model file
|
|
29
|
-
resnet_weights (str): Path to the ResNet model file
|
|
30
|
-
model (str): Model type, either "resnet50" or "resnet152"
|
|
31
|
-
species_names (list): List of species names
|
|
32
|
-
output_dir (str): Directory to save output CSV files
|
|
33
|
-
"""
|
|
34
|
-
use_resnet50 = model == "resnet50"
|
|
35
|
-
classifier = TestTwoStage(yolo_weights, resnet_weights, use_resnet50=use_resnet50,
|
|
36
|
-
species_names=species_names, output_dir=output_dir)
|
|
37
|
-
classifier.run(data_path)
|
|
38
|
-
|
|
39
|
-
class TestTwoStage:
    """Evaluate a two-stage insect pipeline: YOLO detection, then ResNet classification.

    YOLO proposes boxes in each test image, every crop is classified by a
    ResNet whose output index is looked up in ``species_names``, and the
    predictions are scored against YOLO-format ground-truth label files at
    species, family and order level. Family and order are resolved through
    the GBIF species-match API.
    """

    # Seconds to wait for each GBIF HTTP request before treating it as failed.
    GBIF_TIMEOUT = 30

    def __init__(self, yolo_model_path, resnet_model_path, num_classes=9, use_resnet50=False, species_names="",
                 output_dir="output", detection_conf=0.3, detection_iou=0.5):
        """Load both models and prepare the output directory.

        Args:
            yolo_model_path (str): Path to the YOLO detector weights.
            resnet_model_path (str): Path to a ResNet ``state_dict`` whose final
                layer is ``Sequential(Dropout, Linear(in_features, num_classes))``.
            num_classes (int): Number of classifier outputs; must match the checkpoint.
            use_resnet50 (bool): Build a ResNet-50 backbone instead of ResNet-152.
            species_names (list[str]): Species names indexed by classifier output.
                ``run`` needs a real list here; the ``""`` default cannot be evaluated.
            output_dir (str): Directory where the per-image results CSV is written.
            detection_conf (float): Forwarded to the YOLO call as ``conf``.
            detection_iou (float): Forwarded to the YOLO call as ``iou``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        # Remember the backbone name explicitly: resnet50 and resnet152 are both
        # instances of torchvision's ``ResNet`` class, so isinstance() cannot
        # tell them apart later.
        self.model_type = "resnet50" if use_resnet50 else "resnet152"
        self.detection_conf = detection_conf
        self.detection_iou = detection_iou

        self.yolo_model = YOLO(yolo_model_path)
        self.classification_model = resnet50(pretrained=False) if use_resnet50 else resnet152(pretrained=False)

        # The head must mirror the one used when the checkpoint was saved,
        # otherwise load_state_dict fails on missing/unexpected keys.
        self.classification_model.fc = nn.Sequential(
            nn.Dropout(0.4),  # Using dropout probability of 0.4 as in training
            nn.Linear(self.classification_model.fc.in_features, num_classes),
        )

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(resnet_model_path, map_location=self.device)
        self.classification_model.load_state_dict(state_dict)
        self.classification_model.to(self.device)
        self.classification_model.eval()

        # Crops are not resized, only tensorized and ImageNet-normalized.
        # NOTE(review): presumably this matches the training pipeline — confirm.
        self.classification_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.species_names = species_names

    # ------------------------------------------------------------------ inference

    @staticmethod
    def _label_path(label_dir, image_name):
        """Return the label file for *image_name*: same stem, ``.txt`` suffix, in *label_dir*."""
        stem, _ = os.path.splitext(image_name)
        return os.path.join(label_dir, stem + ".txt")

    @staticmethod
    def _read_labels(label_path):
        """Parse a label file into ``[[class_idx, x_center, y_center, width, height], ...]``.

        A missing or empty file yields ``[]`` (an image with no insects).
        Blank lines are skipped.
        """
        boxes = []
        if os.path.exists(label_path) and os.path.getsize(label_path) > 0:
            with open(label_path, "r") as f:
                for line in f:
                    parts = line.split()
                    if not parts:
                        continue
                    boxes.append([int(parts[0]), *map(np.float32, parts[1:])])
        return boxes

    def _predict_frame(self, frame):
        """Detect and classify insects in one BGR image as returned by ``cv2.imread``.

        Returns:
            list: One ``[class_idx, x_center, y_center, width, height]`` entry per
            detection, with coordinates normalized by the image width/height.
        """
        with torch.no_grad():
            results = self.yolo_model(frame, conf=self.detection_conf, iou=self.detection_iou, verbose=False)
        detections = results[0].boxes
        if not detections:
            return []

        img_height, img_width = frame.shape[:2]
        predicted = []
        for box in detections:
            x1, y1, x2, y2 = box.xyxy.cpu().numpy().flatten()[:4]
            width = x2 - x1
            height = y2 - y1

            crop = frame[int(y1):int(y2), int(x1):int(x2)]
            if crop.size == 0:
                # Box narrower/shorter than one pixel after truncation: nothing to
                # classify, and cv2.cvtColor raises on an empty array.
                continue
            rgb_crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            input_tensor = self.classification_transform(Image.fromarray(rgb_crop)).unsqueeze(0).to(self.device)
            with torch.no_grad():
                class_idx = self.classification_model(input_tensor).argmax(dim=1).item()

            predicted.append([
                class_idx,
                (x1 + width / 2) / img_width,
                (y1 + height / 2) / img_height,
                width / img_width,
                height / img_height,
            ])
        return predicted

    def get_frames(self, test_dir):
        """Run the pipeline on every image in ``test_dir/images``.

        Ground truth is read from ``test_dir/labels/<image stem>.txt``. Files that
        OpenCV cannot read are skipped with a warning. A CSV with one row per
        image (name, true boxes, predicted boxes) is written to ``output_dir``.

        Returns:
            tuple: ``(predicted_frames, true_frames, elapsed_seconds)``; the two
            lists are parallel, one entry (list of boxes) per processed image.
        """
        image_dir = os.path.join(test_dir, "images")
        label_dir = os.path.join(test_dir, "labels")

        predicted_frames = []
        true_frames = []
        image_names = []

        start_time = time.time()
        # Sorted so the CSV and metrics are reproducible across filesystems.
        for image_name in tqdm(sorted(os.listdir(image_dir)), desc="Processing Images", unit="image"):
            frame = cv2.imread(os.path.join(image_dir, image_name))
            if frame is None:
                print(f"Warning: skipping unreadable image {image_name}")
                continue
            image_names.append(image_name)
            predicted_frames.append(self._predict_frame(frame))
            true_frames.append(self._read_labels(self._label_path(label_dir, image_name)))
        end_time = time.time()

        output_file = os.path.join(self.output_dir, f"results_{self.model_type}_{time.strftime('%Y%m%d_%H%M%S')}.csv")
        # newline="" is required by the csv module to avoid blank rows on Windows.
        with open(output_file, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["Image Name", "True", "Predicted"])
            writer.writerows(zip(image_names, true_frames, predicted_frames))

        print(f"Results saved to {output_file}")
        return predicted_frames, true_frames, end_time - start_time

    # ------------------------------------------------------------------ taxonomy

    def get_taxonomic_info(self, species_list):
        """
        Retrieve the family and order of each species from the GBIF species-match API.

        Prints a summary table and the index assigned to every order, family and
        species. Calls ``sys.exit(1)`` if a request fails or a species cannot be
        resolved to both a family and an order.

        Returns:
            tuple: ``(taxonomy, species_to_family, family_to_order)`` where
            ``taxonomy[1]`` is the sorted list of orders, ``taxonomy[2]`` maps
            family -> order and ``taxonomy[3]`` maps species -> family.
        """
        taxonomy = {1: [], 2: {}, 3: {}}
        species_to_family = {}
        family_to_order = {}

        print(f"Building taxonomy from GBIF for {len(species_list)} species")

        print("\nTaxonomy Results:")
        print("-" * 80)
        print(f"{'Species':<30} {'Order':<20} {'Family':<20} {'Status'}")
        print("-" * 80)

        for species_name in species_list:
            try:
                # params= URL-encodes the name; the timeout keeps a stalled
                # connection from hanging the evaluation forever.
                response = requests.get(
                    "https://api.gbif.org/v1/species/match",
                    params={"name": species_name, "verbose": "true"},
                    timeout=self.GBIF_TIMEOUT,
                )
                response.raise_for_status()
                data = response.json()
            except Exception as e:
                error_msg = f"Error retrieving data for species '{species_name}': {str(e)}"
                print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
                print(f"Error: {error_msg}")
                sys.exit(1)  # Stop the script

            if data.get('status') not in ('ACCEPTED', 'SYNONYM'):
                error_msg = f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF"
                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                print(f"Error: {error_msg}")
                sys.exit(1)  # Stop the script

            family = data.get('family')
            order = data.get('order')
            if not (family and order):
                error_msg = f"Species '{species_name}' found in GBIF but family and order not found, could be spelling error in species, check GBIF"
                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                print(f"Error: {error_msg}")
                sys.exit(1)  # Stop the script

            status = "OK"
            print(f"{species_name:<30} {order:<20} {family:<20} {status}")

            species_to_family[species_name] = family
            family_to_order[family] = order
            if order not in taxonomy[1]:
                taxonomy[1].append(order)
            taxonomy[2][family] = order
            taxonomy[3][species_name] = family

        taxonomy[1] = sorted(set(taxonomy[1]))
        print("-" * 80)

        num_orders = len(taxonomy[1])
        num_families = len(taxonomy[2])
        num_species = len(taxonomy[3])

        print("\nOrder indices:")
        for i, order in enumerate(taxonomy[1]):
            print(f" {i}: {order}")

        print("\nFamily indices:")
        for i, family in enumerate(taxonomy[2].keys()):
            print(f" {i}: {family}")

        print("\nSpecies indices:")
        for i, species in enumerate(species_list):
            print(f" {i}: {species}")

        print(f"\nTaxonomy built: {num_orders} orders, {num_families} families, {num_species} species")

        return taxonomy, species_to_family, family_to_order

    # ------------------------------------------------------------------ metrics

    @staticmethod
    def _calculate_iou(box1, box2):
        """IoU of two ``[class_idx, x_center, y_center, width, height]`` boxes.

        Returns 0.0 when the union area is not positive (e.g. two zero-area
        boxes), instead of dividing by zero.
        """
        x1_min, y1_min = box1[1] - box1[3] / 2, box1[2] - box1[4] / 2
        x1_max, y1_max = box1[1] + box1[3] / 2, box1[2] + box1[4] / 2
        x2_min, y2_min = box2[1] - box2[3] / 2, box2[2] - box2[4] / 2
        x2_max, y2_max = box2[1] + box2[3] / 2, box2[2] + box2[4] / 2

        inter_w = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
        inter_h = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
        inter_area = inter_w * inter_h

        box1_area = (x1_max - x1_min) * (y1_max - y1_min)
        box2_area = (x2_max - x2_min) * (y2_max - y2_min)
        union = box1_area + box2_area - inter_area
        return inter_area / union if union > 0 else 0.0

    @classmethod
    def _match_frame(cls, pred_boxes, true_boxes, iou_threshold=0.5):
        """Greedily match one image's predictions to its ground-truth boxes.

        Predictions are processed in the order given; each is matched to the
        still-unmatched true box with the highest IoU >= *iou_threshold*.
        A match with the right class is a TP for that class. A match with the
        wrong class is an FP for the predicted class and an FN for the true
        class. Unmatched predictions are FPs, and unmatched true boxes are FNs.

        Returns:
            tuple: ``(per_label_counts, generic_tp, generic_fp, generic_fn)``; the
            generic counts ignore class and only score localization.
        """
        label_results = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
        matched_true = set()
        generic_tp = 0
        generic_fp = 0

        for pred_box in pred_boxes:
            pred_label = pred_box[0]
            best_iou, best_idx = 0, -1
            for i, true_box in enumerate(true_boxes):
                if i in matched_true:
                    continue
                iou = cls._calculate_iou(pred_box, true_box)
                if iou >= iou_threshold and iou > best_iou:
                    best_iou, best_idx = iou, i

            if best_idx < 0:
                label_results[pred_label]['fp'] += 1
                generic_fp += 1
                continue

            matched_true.add(best_idx)
            generic_tp += 1
            true_label = true_boxes[best_idx][0]
            if pred_label == true_label:
                label_results[pred_label]['tp'] += 1
            else:
                label_results[pred_label]['fp'] += 1
                label_results[true_label]['fn'] += 1

        for i, true_box in enumerate(true_boxes):
            if i not in matched_true:
                label_results[true_box[0]]['fn'] += 1

        generic_fn = len(true_boxes) - len(matched_true)
        return label_results, generic_tp, generic_fp, generic_fn

    @classmethod
    def _accumulate_counts(cls, predicted_frames, true_frames, iou_threshold=0.5):
        """Aggregate TP/FP/FN counts over parallel per-image box lists.

        Returns:
            tuple: ``(label_metrics, background_metrics, generic_metrics)``.

            - ``label_metrics`` maps class_idx to ``{'tp', 'fp', 'fn', 'support'}``,
              where support is the number of true boxes of that class.
            - ``background_metrics`` scores whole images, with "positive" meaning
              "image has no insects"; its support is the number of images with
              empty ground truth.
            - ``generic_metrics`` counts class-agnostic box matches.
        """
        label_metrics = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0, 'support': 0})
        background_metrics = {'tp': 0, 'fp': 0, 'fn': 0, 'support': 0}
        generic_metrics = {'tp': 0, 'fp': 0, 'fn': 0}

        for true_frame in true_frames:
            if not true_frame:
                background_metrics['support'] += 1
            for true_box in true_frame:
                label_metrics[true_box[0]]['support'] += 1  # every box, not unique labels

        for pred_frame, true_frame in zip(predicted_frames, true_frames):
            if not pred_frame and not true_frame:
                background_metrics['tp'] += 1
            elif not pred_frame:
                # Predicted "empty", but the image contains insects.
                background_metrics['fp'] += 1
            elif not true_frame:
                # Truly empty image on which something was detected.
                background_metrics['fn'] += 1

            # Score boxes on every image, including those where one side is
            # empty. Missed insects then count as FNs and detections on empty
            # images as FPs, so per-class tp + fn always equals support.
            frame_results, g_tp, g_fp, g_fn = cls._match_frame(pred_frame, true_frame, iou_threshold)
            for label_idx, counts in frame_results.items():
                for key in ('tp', 'fp', 'fn'):
                    label_metrics[label_idx][key] += counts[key]
            generic_metrics['tp'] += g_tp
            generic_metrics['fp'] += g_fp
            generic_metrics['fn'] += g_fn

        return label_metrics, background_metrics, generic_metrics

    @staticmethod
    def _prf(tp, fp, fn):
        """Return ``(precision, recall, f1)``, using 0 wherever a denominator is 0."""
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        return precision, recall, f1

    def get_metrics(self, predicted_frames, true_frames, labels, iou_threshold=0.5):
        """
        Print precision, recall and F1 for object detection results.

        Prints one row per class (sorted by class index), a frame-level
        "Background" row, micro- and macro-averaged totals (excluding
        background), and a class-agnostic "Generic Total" line.

        Args:
            predicted_frames (list): Per-image lists of predicted
                ``[class_idx, x_center, y_center, width, height]`` boxes.
            true_frames (list): Ground-truth boxes, parallel to *predicted_frames*.
            labels (list[str]): Display names indexed by class_idx.
            iou_threshold (float): Minimum IoU for a prediction to match a true box.

        Returns:
            tuple: Micro-averaged ``(precision, recall, f1)`` over all classes,
            excluding background.
        """
        label_metrics, background_metrics, generic_metrics = self._accumulate_counts(
            predicted_frames, true_frames, iou_threshold)

        table_data = []
        # Per-class values kept for macro-averaging.
        class_precisions = []
        class_recalls = []
        class_f1s = []

        for label_idx in sorted(label_metrics):
            metrics = label_metrics[label_idx]
            precision, recall, f1_score = self._prf(metrics['tp'], metrics['fp'], metrics['fn'])
            class_precisions.append(precision)
            class_recalls.append(recall)
            class_f1s.append(f1_score)

            label_name = labels[label_idx] if label_idx < len(labels) else f"Label {label_idx}"
            table_data.append([label_name, f"{precision:.2f}", f"{recall:.2f}", f"{f1_score:.2f}", f"{metrics['support']}"])

        precision, recall, f1_score = self._prf(
            background_metrics['tp'], background_metrics['fp'], background_metrics['fn'])
        table_data.append(["Background", f"{precision:.2f}", f"{recall:.2f}", f"{f1_score:.2f}", f"{background_metrics['support']}"])

        total_tp = sum(m['tp'] for m in label_metrics.values())
        total_fp = sum(m['fp'] for m in label_metrics.values())
        total_fn = sum(m['fn'] for m in label_metrics.values())
        total_support = sum(m['support'] for m in label_metrics.values())
        total_precision, total_recall, total_f1_score = self._prf(total_tp, total_fp, total_fn)

        table_data.append(["\nTotal (micro-avg, excl. background)", f"{total_precision:.2f}", f"{total_recall:.2f}", f"{total_f1_score:.2f}", f"{total_support}"])

        if class_precisions:
            n_classes = len(class_precisions)
            macro_precision = sum(class_precisions) / n_classes
            macro_recall = sum(class_recalls) / n_classes
            macro_f1 = sum(class_f1s) / n_classes
            table_data.append(["Total (macro-avg, excl. background)", f"{macro_precision:.2f}", f"{macro_recall:.2f}", f"{macro_f1:.2f}", f"{total_support}"])

        headers = ["Label", "Precision", "Recall", "F1 Score", "Support"]
        print(tabulate(table_data, headers=headers, tablefmt="grid"))

        generic_precision, generic_recall, generic_f1_score = self._prf(
            generic_metrics['tp'], generic_metrics['fp'], generic_metrics['fn'])
        print("\nGeneric Total", f"{generic_precision:.2f}", f"{generic_recall:.2f}", f"{generic_f1_score:.2f}")

        return total_precision, total_recall, total_f1_score

    # ------------------------------------------------------------------ driver

    def _to_higher_ranks(self, frames, species_to_family, family_to_order, family_list, order_list):
        """Relabel species-level boxes with family and order indices.

        Coordinates are kept (converted to ``np.float32``). Only the class index
        changes, to the position of the box's family in *family_list* and of
        its order in *order_list*.

        Returns:
            tuple: ``(family_frames, order_frames)``, each parallel to *frames*.
        """
        family_index = {name: i for i, name in enumerate(family_list)}
        order_index = {name: i for i, name in enumerate(order_list)}

        family_frames, order_frames = [], []
        for frame in frames:
            family_boxes, order_boxes = [], []
            for box in frame:
                family_name = species_to_family[self.species_names[box[0]]]
                order_name = family_to_order[family_name]
                coords = [np.float32(v) for v in box[1:]]
                family_boxes.append([family_index[family_name], *coords])
                order_boxes.append([order_index[order_name], *coords])
            family_frames.append(family_boxes)
            order_frames.append(order_boxes)
        return family_frames, order_frames

    def run(self, test_dir):
        """Evaluate on *test_dir* and print metrics at species, family and order level."""
        predicted_frames, true_frames, total_time = self.get_frames(test_dir)
        # Count processed images (unreadable files are skipped) and avoid /0.
        num_frames = len(predicted_frames)
        avg_time_per_frame = total_time / num_frames if num_frames else 0.0

        print(f"\nTotal time: {total_time:.2f} seconds")
        print(f"Average time per frame: {avg_time_per_frame:.4f} seconds")

        # Get taxonomy information for hierarchical analysis
        taxonomy, species_to_family, family_to_order = self.get_taxonomic_info(self.species_names)
        family_list = list(family_to_order.keys())
        order_list = list(taxonomy[1])

        true_family_frames, true_order_frames = self._to_higher_ranks(
            true_frames, species_to_family, family_to_order, family_list, order_list)
        predicted_family_frames, predicted_order_frames = self._to_higher_ranks(
            predicted_frames, species_to_family, family_to_order, family_list, order_list)

        print("\nSpecies Level Metrics")
        self.get_metrics(predicted_frames, true_frames, self.species_names)

        print("\nFamily Level Metrics")
        self.get_metrics(predicted_family_frames, true_family_frames, family_list)

        print("\nOrder Level Metrics")
        self.get_metrics(predicted_order_frames, true_order_frames, order_list)
|
|
458
|
-
|
|
459
|
-
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded experiment, so running the
    # script with no arguments behaves exactly as before.
    species_names = [
        "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
        "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
        "Eristalis tenax"
    ]

    parser = argparse.ArgumentParser(description="Evaluate the two-stage YOLO + ResNet insect pipeline.")
    parser.add_argument("--data", default="/mnt/nvme0n1p1/mit/two-stage-detection/bjerge-test",
                        help="Test directory containing images/ and labels/ subdirectories.")
    parser.add_argument("--yolo_weights", default="/mnt/nvme0n1p1/mit/two-stage-detection/small-generic.pt",
                        help="Path to the YOLO detector weights.")
    parser.add_argument("--resnet_weights", default="/mnt/nvme0n1p1/mit/two-stage-detection/output/best_resnet50.pt",
                        help="Path to the ResNet classifier state_dict.")
    parser.add_argument("--model", choices=["resnet50", "resnet152"], default="resnet50",
                        help="Classifier backbone the weights belong to.")
    parser.add_argument("--species", nargs="+", default=species_names,
                        help="Species names in classifier output order.")
    parser.add_argument("--output_dir", default="/mnt/nvme0n1p1/mit/two-stage-detection/output",
                        help="Directory for the results CSV.")
    args = parser.parse_args()

    test_resnet(
        data_path=args.data,
        yolo_weights=args.yolo_weights,
        resnet_weights=args.resnet_weights,
        model=args.model,
        species_names=args.species,
        output_dir=args.output_dir,
    )
|