bplusplus 0.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bplusplus might be problematic. Click here for more details.
- bplusplus/__init__.py +7 -3
- bplusplus/{collect_images.py → collect.py} +71 -7
- bplusplus/hierarchical/test.py +670 -0
- bplusplus/hierarchical/train.py +676 -0
- bplusplus/prepare.py +737 -0
- bplusplus/resnet/test.py +473 -0
- bplusplus/resnet/train.py +329 -0
- bplusplus/train_validate.py +8 -64
- bplusplus-1.2.0.dist-info/METADATA +249 -0
- bplusplus-1.2.0.dist-info/RECORD +12 -0
- bplusplus/build_model.py +0 -38
- bplusplus-0.1.1.dist-info/METADATA +0 -97
- bplusplus-0.1.1.dist-info/RECORD +0 -8
- {bplusplus-0.1.1.dist-info → bplusplus-1.2.0.dist-info}/LICENSE +0 -0
- {bplusplus-0.1.1.dist-info → bplusplus-1.2.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,670 @@
|
|
|
1
|
+
# pip install ultralytics torchvision pillow numpy scikit-learn tabulate tqdm requests
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import cv2
|
|
5
|
+
import torch
|
|
6
|
+
from ultralytics import YOLO
|
|
7
|
+
from torchvision import transforms
|
|
8
|
+
from PIL import Image
|
|
9
|
+
import numpy as np
|
|
10
|
+
from torchvision.models import resnet50
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
from sklearn.metrics import classification_report, accuracy_score
|
|
13
|
+
import time
|
|
14
|
+
import argparse
|
|
15
|
+
from collections import defaultdict
|
|
16
|
+
from tabulate import tabulate
|
|
17
|
+
from tqdm import tqdm
|
|
18
|
+
import csv
|
|
19
|
+
import requests
|
|
20
|
+
import logging
|
|
21
|
+
import sys
|
|
22
|
+
|
|
23
|
+
# Module-wide logging: timestamped INFO-level messages (e.g. "2024-01-01 12:00:00 - INFO - ...").
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
def test_multitask(species_list, test_set, yolo_weights, hierarchical_weights, output_dir="."):
    """Evaluate the two-stage detector + hierarchical classifier on a test set.

    Args:
        species_list (list): Species names the classifier was trained on.
        test_set (str): Path to the test directory.
        yolo_weights (str): Path to the YOLO detector weights file.
        hierarchical_weights (str): Path to the hierarchical classifier weights file.
        output_dir (str): Directory for output CSV files (default: current directory).

    Returns:
        Whatever the pipeline's ``run`` produces.
    """
    pipeline = TestTwoStage(yolo_weights, hierarchical_weights, species_list, output_dir)
    outcome = pipeline.run(test_set)
    print("Testing complete with metrics calculated at all taxonomic levels")
    return outcome
|
|
44
|
+
|
|
45
|
+
def cuda_cleanup():
    """Release cached CUDA memory and reset peak-memory counters (no-op on CPU)."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
|
|
50
|
+
|
|
51
|
+
def setup_gpu():
    """Select a torch device, preferring CUDA and verifying it with a tiny probe.

    Returns:
        torch.device: ``cuda:0`` when CUDA initializes cleanly, else ``cpu``.
    """
    if not torch.cuda.is_available():
        logger.warning("CUDA is not available on this system")
        return torch.device("cpu")

    try:
        gpu_count = torch.cuda.device_count()
        logger.info(f"Found {gpu_count} CUDA device(s)")

        # Report every visible GPU's name and total memory.
        for idx in range(gpu_count):
            props = torch.cuda.get_device_properties(idx)
            logger.info(f"GPU {idx}: {props.name} with {props.total_memory / 1e9:.2f} GB memory")

        # Smoke-test the first device with a trivial computation before committing.
        device = torch.device("cuda:0")
        probe = torch.ones(1, device=device)
        doubled = probe * 2
        del probe, doubled

        logger.info("CUDA initialization successful")
        return device
    except Exception as e:
        logger.error(f"CUDA initialization error: {str(e)}")
        logger.warning("Falling back to CPU")
        return torch.device("cpu")
|
|
76
|
+
|
|
77
|
+
class HierarchicalInsectClassifier(nn.Module):
    """ResNet-50 backbone with one independent classification head per taxonomic level."""

    def __init__(self, num_classes_per_level):
        """
        Args:
            num_classes_per_level (list): Number of classes for each taxonomic level
        """
        super(HierarchicalInsectClassifier, self).__init__()

        # Shared feature extractor; the final FC layer is replaced with Identity
        # so the backbone emits raw pooled features.
        self.backbone = resnet50(pretrained=True)
        feature_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # One small MLP head per taxonomic level (order / family / species, etc.).
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.Linear(feature_dim, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, n_cls),
            )
            for n_cls in num_classes_per_level
        )

        self.num_levels = len(num_classes_per_level)

        # Bookkeeping buffers/state; registered so they travel with state_dict.
        # NOTE(review): these are never read in forward() within this file —
        # presumably used by training/analysis code elsewhere; verify before removing.
        total_classes = sum(num_classes_per_level)
        self.register_buffer('class_means', torch.zeros(total_classes))
        self.register_buffer('class_stds', torch.ones(total_classes))
        self.class_counts = [0] * total_classes
        self.output_history = defaultdict(list)

    def forward(self, x):
        """Return a list of logits tensors, one per taxonomic level."""
        features = self.backbone(x)
        return [head(features) for head in self.branches]
|
|
114
|
+
|
|
115
|
+
def get_taxonomy(species_list):
    """
    Retrieves taxonomic information for a list of species from GBIF API.
    Creates a hierarchical taxonomy dictionary with order, family, and species relationships.

    Args:
        species_list (list): Species names to resolve against GBIF.

    Returns:
        tuple: (taxonomy, species_to_family, family_to_order) where taxonomy maps
        level 1 -> sorted list of orders, level 2 -> {family: order},
        level 3 -> {species: family}.

    Note:
        Exits the process (sys.exit(1)) when any species cannot be resolved.
    """
    taxonomy = {1: [], 2: {}, 3: {}}
    species_to_family = {}
    family_to_order = {}

    logger.info(f"Building taxonomy from GBIF for {len(species_list)} species")

    print("\nTaxonomy Results:")
    print("-" * 80)
    print(f"{'Species':<30} {'Order':<20} {'Family':<20} {'Status'}")
    print("-" * 80)

    for species_name in species_list:
        url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
        try:
            # FIX: a timeout prevents the whole run from hanging indefinitely
            # when GBIF is slow or unreachable (requests has no default timeout).
            response = requests.get(url, timeout=30)
            data = response.json()

            if data.get('status') == 'ACCEPTED' or data.get('status') == 'SYNONYM':
                family = data.get('family')
                order = data.get('order')

                if family and order:
                    status = "OK"

                    print(f"{species_name:<30} {order:<20} {family:<20} {status}")

                    species_to_family[species_name] = family
                    family_to_order[family] = order

                    if order not in taxonomy[1]:
                        taxonomy[1].append(order)

                    taxonomy[2][family] = order
                    taxonomy[3][species_name] = family
                else:
                    error_msg = f"Species '{species_name}' found in GBIF but family and order not found, could be spelling error in species, check GBIF"
                    logger.error(error_msg)
                    print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                    print(f"Error: {error_msg}")
                    sys.exit(1)  # Stop the script
            else:
                error_msg = f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF"
                logger.error(error_msg)
                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                print(f"Error: {error_msg}")
                sys.exit(1)  # Stop the script

        except Exception as e:
            error_msg = f"Error retrieving data for species '{species_name}': {str(e)}"
            logger.error(error_msg)
            print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
            print(f"Error: {error_msg}")
            sys.exit(1)  # Stop the script

    # Deduplicate and sort orders so indices are deterministic across runs.
    taxonomy[1] = sorted(list(set(taxonomy[1])))
    print("-" * 80)

    num_orders = len(taxonomy[1])
    num_families = len(taxonomy[2])
    num_species = len(taxonomy[3])

    print("\nOrder indices:")
    for i, order in enumerate(taxonomy[1]):
        print(f" {i}: {order}")

    print("\nFamily indices:")
    for i, family in enumerate(taxonomy[2].keys()):
        print(f" {i}: {family}")

    print("\nSpecies indices:")
    for i, species in enumerate(species_list):
        print(f" {i}: {species}")

    logger.info(f"Taxonomy built: {num_orders} orders, {num_families} families, {num_species} species")
    return taxonomy, species_to_family, family_to_order
|
|
195
|
+
|
|
196
|
+
def create_mappings(taxonomy):
    """Build label<->index lookup tables for every taxonomic level.

    Args:
        taxonomy (dict): level -> list of labels, or level -> dict keyed by label.

    Returns:
        tuple: (level_to_idx, idx_to_level) dictionaries of per-level mappings.
    """
    level_to_idx = {}
    idx_to_level = {}

    for level, labels in taxonomy.items():
        # Lists enumerate directly; dict levels enumerate their keys.
        names = labels if isinstance(labels, list) else list(labels.keys())
        level_to_idx[level] = {name: i for i, name in enumerate(names)}
        idx_to_level[level] = dict(enumerate(names))

    return level_to_idx, idx_to_level
|
|
210
|
+
|
|
211
|
+
class TestTwoStage:
|
|
212
|
+
def __init__(self, yolo_model_path, hierarchical_model_path, species_names, output_dir="."):
    """Load both stages of the pipeline: a YOLO detector and a hierarchical classifier.

    Args:
        yolo_model_path (str): Path to the YOLO detector weights.
        hierarchical_model_path (str): Path to the hierarchical classifier checkpoint.
        species_names (list): Species names, in the index order used by the labels.
        output_dir (str): Directory where result CSVs will be written.
    """
    cuda_cleanup()

    self.device = setup_gpu()
    logger.info(f"Using device: {self.device}")

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    self.output_dir = output_dir
    logger.info(f"Results will be saved to: {self.output_dir}")

    print(f"Using device: {self.device}")

    # Stage 1: object detector.
    self.yolo_model = YOLO(yolo_model_path)

    self.species_names = species_names

    # Stage 2: hierarchical classifier checkpoint, loaded to CPU first so the
    # state dict can be inspected before any device transfer.
    logger.info(f"Loading model from {hierarchical_model_path}")
    try:
        checkpoint = torch.load(hierarchical_model_path, map_location='cpu')
        logger.info("Model loaded to CPU successfully")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise

    # Checkpoints may be a dict with metadata or a bare state dict.
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]

        if "taxonomy" in checkpoint:
            print("Using taxonomy from saved model")
            taxonomy = checkpoint["taxonomy"]
            if "species_list" in checkpoint:
                saved_species = checkpoint["species_list"]
                print(f"Saved model was trained on: {', '.join(saved_species)}")

            # NOTE(review): this call overwrites the checkpoint's taxonomy that
            # was just loaded above (the mapping dicts are needed either way, but
            # `taxonomy` is clobbered despite the "Using taxonomy from saved
            # model" message). Confirm whether this is intentional.
            taxonomy, species_to_family, family_to_order = get_taxonomy(species_names)
        else:
            taxonomy, species_to_family, family_to_order = get_taxonomy(species_names)
    else:
        state_dict = checkpoint
        taxonomy, species_to_family, family_to_order = get_taxonomy(species_names)

    level_to_idx, idx_to_level = create_mappings(taxonomy)

    self.level_to_idx = level_to_idx
    self.idx_to_level = idx_to_level

    # Head sizes must match the checkpoint's branch dimensions.
    if hasattr(taxonomy, "items"):
        num_classes_per_level = [len(classes) if isinstance(classes, list) else len(classes.keys())
                                 for level, classes in taxonomy.items()]
    else:
        num_classes_per_level = [4, 5, 9]  # Example values, adjust as needed

    print(f"Using model with class counts: {num_classes_per_level}")

    self.classification_model = HierarchicalInsectClassifier(num_classes_per_level)

    # Strict load first; fall back to strict=False so partially matching
    # checkpoints still load (missing/unexpected keys are tolerated).
    try:
        self.classification_model.load_state_dict(state_dict)
        print("Model weights loaded successfully")
    except Exception as e:
        print(f"Error loading model weights: {e}")
        print("Attempting to load with strict=False...")
        self.classification_model.load_state_dict(state_dict, strict=False)
        print("Model weights loaded with strict=False")

    try:
        self.classification_model.to(self.device)
        print(f"Model successfully transferred to {self.device}")
    except RuntimeError as e:
        logger.error(f"Error transferring model to {self.device}: {e}")
        print(f"Error transferring model to {self.device}, falling back to CPU")
        self.device = torch.device("cpu")
        # No need to move to CPU since it's already there

    self.classification_model.eval()

    # Preprocessing applied to each detector crop before classification.
    self.classification_transform = transforms.Compose([
        transforms.Resize((768, 768)),  # Fixed size for all validation images
        transforms.CenterCrop(640),
        transforms.ToTensor(),
        # ImageNet normalization constants (matches the pretrained backbone).
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print("Model successfully loaded")
    print(f"Using species: {', '.join(species_names)}")

    self.species_to_family = species_to_family
    self.family_to_order = family_to_order
|
|
301
|
+
|
|
302
|
+
def get_frames(self, test_dir):
    """Run detection + classification over every image and collect per-frame boxes.

    Args:
        test_dir (str): Directory containing 'images' and 'labels' subfolders
            with YOLO-format label files.

    Returns:
        tuple: (predicted_frames, true_species_frames, elapsed_seconds,
        predicted_family_frames, predicted_order_frames, true_family_frames,
        true_order_frames). Each *_frames entry is a list (one per image) of
        [class_idx, x_center, y_center, width, height] rows (normalized coords).

    Side effects:
        Writes a timestamped CSV of per-image detections into self.output_dir.
    """
    image_dir = os.path.join(test_dir, "images")
    label_dir = os.path.join(test_dir, "labels")

    predicted_frames = []
    predicted_family_frames = []
    predicted_order_frames = []
    true_species_frames = []
    true_family_frames = []
    true_order_frames = []
    image_names = []

    start_time = time.time()  # Start timing

    for image_name in tqdm(os.listdir(image_dir), desc="Processing Images", unit="image"):
        image_names.append(image_name)
        image_path = os.path.join(image_dir, image_name)
        # NOTE(review): assumes labels share the image basename with a .txt
        # extension and images are .jpg — .png images would look up the wrong
        # label file. Confirm the dataset convention.
        label_path = os.path.join(label_dir, image_name.replace('.jpg', '.txt'))

        frame = cv2.imread(image_path)
        # Suppress print statements from YOLO model
        with torch.no_grad():
            results = self.yolo_model(frame, conf=0.3, iou=0.5, verbose=False)

        detections = results[0].boxes
        predicted_frame = []
        predicted_family_frame = []
        predicted_order_frame = []

        if detections:
            for box in detections:
                # Absolute pixel corners from the detector.
                xyxy = box.xyxy.cpu().numpy().flatten()
                x1, y1, x2, y2 = xyxy[:4]
                width = x2 - x1
                height = y2 - y1
                x_center = x1 + width / 2
                y_center = y1 + height / 2

                # Crop the detection and classify it (BGR -> RGB for PIL).
                insect_crop = frame[int(y1):int(y2), int(x1):int(x2)]
                insect_crop_rgb = cv2.cvtColor(insect_crop, cv2.COLOR_BGR2RGB)
                pil_img = Image.fromarray(insect_crop_rgb)
                input_tensor = self.classification_transform(pil_img).unsqueeze(0).to(self.device)

                with torch.no_grad():
                    outputs = self.classification_model(input_tensor)

                # Get all taxonomic level predictions
                order_output = outputs[0]  # First output is order (level 1)
                family_output = outputs[1]  # Second output is family (level 2)
                species_output = outputs[2]  # Third output is species (level 3)

                # Get prediction indices
                order_idx = order_output.argmax(dim=1).item()
                family_idx = family_output.argmax(dim=1).item()
                species_idx = species_output.argmax(dim=1).item()

                # Normalize box geometry to [0, 1] relative to the image size.
                img_height, img_width, _ = frame.shape
                x_center_norm = x_center / img_width
                y_center_norm = y_center / img_height
                width_norm = width / img_width
                height_norm = height / img_height

                # Create box coordinates in YOLO format
                box_coords = [x_center_norm, y_center_norm, width_norm, height_norm]

                # Add predictions for each taxonomic level
                predicted_frame.append([species_idx] + box_coords)
                predicted_family_frame.append([family_idx] + box_coords)
                predicted_order_frame.append([order_idx] + box_coords)

        predicted_frames.append(predicted_frame if predicted_frame else [])
        predicted_family_frames.append(predicted_family_frame if predicted_family_frame else [])
        predicted_order_frames.append(predicted_order_frame if predicted_order_frame else [])

        true_species_frame = []
        true_family_frame = []
        true_order_frame = []

        # Parse ground-truth labels; species index maps up to family and order
        # via the GBIF-derived dictionaries.
        if os.path.exists(label_path) and os.path.getsize(label_path) > 0:
            with open(label_path, 'r') as f:
                for line in f:
                    label_line = line.strip().split()
                    species_idx = int(label_line[0])
                    box_coords = list(map(np.float32, label_line[1:]))

                    true_species_frame.append([species_idx] + box_coords)

                    if species_idx < len(self.species_names):
                        species_name = self.species_names[species_idx]

                        if species_name in self.species_to_family:
                            family_name = self.species_to_family[species_name]
                            # Get the index of the family in the level_to_idx mapping
                            if 2 in self.level_to_idx and family_name in self.level_to_idx[2]:
                                family_idx = self.level_to_idx[2][family_name]
                                true_family_frame.append([family_idx] + box_coords)

                            if family_name in self.family_to_order:
                                order_name = self.family_to_order[family_name]
                                if 1 in self.level_to_idx and order_name in self.level_to_idx[1]:
                                    order_idx = self.level_to_idx[1][order_name]
                                    true_order_frame.append([order_idx] + box_coords)

        true_species_frames.append(true_species_frame if true_species_frame else [])
        true_family_frames.append(true_family_frame if true_family_frame else [])
        true_order_frames.append(true_order_frame if true_order_frame else [])

    end_time = time.time()  # End timing

    # Create a more descriptive filename with timestamp
    output_file = os.path.join(self.output_dir, f"results_hierarchical_{time.strftime('%Y%m%d_%H%M%S')}.csv")

    with open(output_file, "w", newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            "Image Name",
            "True Species Detections",
            "True Family Detections",
            "True Order Detections",
            "Species Detections",
            "Family Detections",
            "Order Detections"
        ])

        for image_name, true_species, true_family, true_order, species_pred, family_pred, order_pred in zip(
            image_names,
            true_species_frames,
            true_family_frames,
            true_order_frames,
            predicted_frames,
            predicted_family_frames,
            predicted_order_frames
        ):
            writer.writerow([
                image_name,
                true_species,
                true_family,
                true_order,
                species_pred,
                family_pred,
                order_pred
            ])

    print(f"Results saved to {output_file}")
    return predicted_frames, true_species_frames, end_time - start_time, predicted_family_frames, predicted_order_frames, true_family_frames, true_order_frames
|
|
447
|
+
|
|
448
|
+
def run(self, test_dir):
    """Run the pipeline over a test directory, print timing, and compute metrics.

    Args:
        test_dir (str): Directory containing 'images' and 'labels' subfolders.

    Returns:
        tuple: (predicted_frames, true_species_frames, total_time,
        predicted_family_frames, predicted_order_frames, true_family_frames,
        true_order_frames) — the raw per-frame data the metrics were computed
        from. (FIX: previously returned None even though the caller
        `test_multitask` documents returning results from the classifier.)
    """
    results = self.get_frames(test_dir)
    (predicted_frames, true_species_frames, total_time,
     predicted_family_frames, predicted_order_frames,
     true_family_frames, true_order_frames) = results

    num_frames = len(os.listdir(os.path.join(test_dir, 'images')))
    # FIX: guard against an empty image directory (was a ZeroDivisionError).
    avg_time_per_frame = total_time / num_frames if num_frames else 0.0

    print(f"\nTotal time: {total_time:.2f} seconds")
    print(f"Average time per frame: {avg_time_per_frame:.4f} seconds")

    self.calculate_metrics(
        predicted_frames, true_species_frames,
        predicted_family_frames, true_family_frames,
        predicted_order_frames, true_order_frames
    )

    return results
|
|
467
|
+
|
|
468
|
+
def calculate_metrics(self, predicted_species_frames, true_species_frames,
                      predicted_family_frames, true_family_frames,
                      predicted_order_frames, true_order_frames):
    """Calculate metrics at all taxonomic levels"""
    # Get list of species, families and orders
    species_list = self.species_names
    # NOTE(review): these lists are sorted, while self.level_to_idx (used to
    # build the ground-truth indices in get_frames) follows taxonomy insertion
    # order — confirm both orderings agree, otherwise label names printed for
    # family/order metrics could be misaligned with the indices being scored.
    family_list = sorted(list(set(self.species_to_family.values())))
    order_list = sorted(list(set(self.family_to_order.values())))

    # Print the index mappings we're using for evaluation
    print("\nUsing the following index mappings for evaluation:")
    print("\nOrder indices:")
    for i, order in enumerate(order_list):
        print(f" {i}: {order}")

    print("\nFamily indices:")
    for i, family in enumerate(family_list):
        print(f" {i}: {family}")

    print("\nSpecies indices:")
    for i, species in enumerate(species_list):
        print(f" {i}: {species}")

    # Dictionary to track prediction category counts for debugging
    prediction_counts = {
        "true_species_boxes": sum(len(frame) for frame in true_species_frames),
        "true_family_boxes": sum(len(frame) for frame in true_family_frames),
        "true_order_boxes": sum(len(frame) for frame in true_order_frames),
        "predicted_species": sum(len(frame) for frame in predicted_species_frames),
        "predicted_family": sum(len(frame) for frame in predicted_family_frames),
        "predicted_order": sum(len(frame) for frame in predicted_order_frames)
    }

    print(f"Prediction counts: {prediction_counts}")

    # Calculate metrics for all three levels
    print("\n=== Species-level Metrics ===")
    self.get_metrics(predicted_species_frames, true_species_frames, species_list)

    print("\n=== Family-level Metrics ===")
    self.get_metrics(predicted_family_frames, true_family_frames, family_list)

    print("\n=== Order-level Metrics ===")
    self.get_metrics(predicted_order_frames, true_order_frames, order_list)
|
|
512
|
+
|
|
513
|
+
def get_metrics(self, predicted_frames, true_frames, labels):
    """Calculate and print detection metrics (per-label P/R/F1 plus totals).

    Args:
        predicted_frames (list): Per-image lists of predicted
            [label_idx, x_center, y_center, width, height] boxes (normalized).
        true_frames (list): Per-image lists of ground-truth boxes, same format.
        labels (list): Label names indexed by label_idx for display.

    Prints a tabulated per-label report, a background row for empty frames,
    a class-aware total, and a class-agnostic ("Generic") total.
    """
    def calculate_iou(box1, box2):
        # Boxes are [label, x_center, y_center, width, height]; convert to corners.
        x1_min, y1_min = box1[1] - box1[3] / 2, box1[2] - box1[4] / 2
        x1_max, y1_max = box1[1] + box1[3] / 2, box1[2] + box1[4] / 2
        x2_min, y2_min = box2[1] - box2[3] / 2, box2[2] - box2[4] / 2
        x2_max, y2_max = box2[1] + box2[3] / 2, box2[2] + box2[4] / 2

        inter_x_min = max(x1_min, x2_min)
        inter_y_min = max(y1_min, y2_min)
        inter_x_max = min(x1_max, x2_max)
        inter_y_max = min(y1_max, y2_max)

        inter_area = max(0, inter_x_max - inter_x_min) * max(0, inter_y_max - inter_y_min)
        box1_area = (x1_max - x1_min) * (y1_max - y1_min)
        box2_area = (x2_max - x2_min) * (y2_max - y2_min)

        # FIX: guard against division by zero when both boxes are degenerate
        # (zero width/height), which previously raised ZeroDivisionError.
        union_area = box1_area + box2_area - inter_area
        if union_area <= 0:
            return 0.0
        return inter_area / union_area

    def calculate_precision_recall(pred_boxes, true_boxes, iou_threshold=0.5):
        # Greedy one-to-one matching: each true box may match at most one
        # prediction; a spatial match with the wrong label counts as FP for
        # the predicted label and FN for the true label.
        label_results = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
        generic_tp = 0
        generic_fp = 0

        matched_true_boxes = set()

        for pred_box in pred_boxes:
            label_idx = pred_box[0]
            matched = False

            best_iou = 0
            best_match_idx = -1

            for i, true_box in enumerate(true_boxes):
                if i in matched_true_boxes:
                    continue

                iou = calculate_iou(pred_box, true_box)
                if iou >= iou_threshold and iou > best_iou:
                    best_iou = iou
                    best_match_idx = i

            if best_match_idx >= 0:
                matched = True
                true_box = true_boxes[best_match_idx]
                matched_true_boxes.add(best_match_idx)
                generic_tp += 1

                if pred_box[0] == true_box[0]:
                    label_results[label_idx]['tp'] += 1
                else:
                    label_results[label_idx]['fp'] += 1
                    true_label_idx = true_box[0]
                    label_results[true_label_idx]['fn'] += 1

            if not matched:
                label_results[label_idx]['fp'] += 1
                generic_fp += 1

        # Unmatched ground truths are false negatives for their label.
        for i, true_box in enumerate(true_boxes):
            if i not in matched_true_boxes:
                label_idx = true_box[0]
                label_results[label_idx]['fn'] += 1

        generic_fn = len(true_boxes) - len(matched_true_boxes)

        return label_results, generic_tp, generic_fp, generic_fn

    label_metrics = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0, 'support': 0})
    background_metrics = {'tp': 0, 'fp': 0, 'fn': 0, 'support': 0}
    generic_metrics = {'tp': 0, 'fp': 0, 'fn': 0}

    # Support: number of ground-truth boxes per label; empty frames count as
    # "background" support.
    for true_frame in true_frames:
        if not true_frame:  # Empty frame (background only)
            background_metrics['support'] += 1
        else:
            for true_box in true_frame:
                label_idx = true_box[0]
                label_metrics[label_idx]['support'] += 1  # Count each detection, not just unique labels

    for pred_frame, true_frame in zip(predicted_frames, true_frames):
        if not pred_frame and not true_frame:
            background_metrics['tp'] += 1
        elif not pred_frame:
            background_metrics['fn'] += 1
        elif not true_frame:
            background_metrics['fp'] += 1
        else:
            frame_results, g_tp, g_fp, g_fn = calculate_precision_recall(pred_frame, true_frame)

            for label_idx, metrics in frame_results.items():
                label_metrics[label_idx]['tp'] += metrics['tp']
                label_metrics[label_idx]['fp'] += metrics['fp']
                label_metrics[label_idx]['fn'] += metrics['fn']

            generic_metrics['tp'] += g_tp
            generic_metrics['fp'] += g_fp
            generic_metrics['fn'] += g_fn

    table_data = []

    for label_idx, metrics in label_metrics.items():
        tp = metrics['tp']
        fp = metrics['fp']
        fn = metrics['fn']
        support = metrics['support']
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        label_name = labels[label_idx] if label_idx < len(labels) else f"Label {label_idx}"
        table_data.append([label_name, f"{precision:.2f}", f"{recall:.2f}", f"{f1_score:.2f}", f"{support}"])

    # Background row: frames with no boxes at all.
    tp = background_metrics['tp']
    fp = background_metrics['fp']
    fn = background_metrics['fn']
    support = background_metrics['support']
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    table_data.append(["Background", f"{precision:.2f}", f"{recall:.2f}", f"{f1_score:.2f}", f"{support}"])

    headers = ["Label", "Precision", "Recall", "F1 Score", "Support"]
    total_tp = sum(metrics['tp'] for metrics in label_metrics.values())
    total_fp = sum(metrics['fp'] for metrics in label_metrics.values())
    total_fn = sum(metrics['fn'] for metrics in label_metrics.values())
    total_support = sum(metrics['support'] for metrics in label_metrics.values())

    total_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    total_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    total_f1_score = 2 * (total_precision * total_recall) / (total_precision + total_recall) if (total_precision + total_recall) > 0 else 0

    table_data.append(["\nTotal (excluding background)", f"{total_precision:.2f}", f"{total_recall:.2f}", f"{total_f1_score:.2f}", f"{total_support}"])
    print(tabulate(table_data, headers=headers, tablefmt="grid"))

    # Class-agnostic localization quality (label identity ignored).
    generic_tp = generic_metrics['tp']
    generic_fp = generic_metrics['fp']
    generic_fn = generic_metrics['fn']

    generic_precision = generic_tp / (generic_tp + generic_fp) if (generic_tp + generic_fp) > 0 else 0
    generic_recall = generic_tp / (generic_tp + generic_fn) if (generic_tp + generic_fn) > 0 else 0
    generic_f1_score = 2 * (generic_precision * generic_recall) / (generic_precision + generic_recall) if (generic_precision + generic_recall) > 0 else 0

    print("\nGeneric Total", f"{generic_precision:.2f}", f"{generic_recall:.2f}", f"{generic_f1_score:.2f}")
|
|
657
|
+
|
|
658
|
+
if __name__ == "__main__":
    # Species in training index order (index i here must match label index i
    # in the YOLO-format label files).
    species_names = [
        "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
        "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
        "Eristalis tenax"
    ]

    # NOTE(review): machine-specific absolute paths — these only work on the
    # original development host; parameterize (argparse is already imported)
    # before shipping.
    test_directory = "/mnt/nvme0n1p1/mit/two-stage-detection/bjerge-test"
    yolo_model_path = "/mnt/nvme0n1p1/mit/two-stage-detection/small-generic.pt"
    hierarchical_model_path = "/mnt/nvme0n1p1/mit/two-stage-detection/hierarchical/hierarchical-weights.pth"
    output_directory = "./output"

    test_multitask(species_names, test_directory, yolo_model_path, hierarchical_model_path, output_directory)
|