bplusplus 2.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bplusplus/__init__.py +15 -0
- bplusplus/collect.py +523 -0
- bplusplus/detector.py +376 -0
- bplusplus/inference.py +1337 -0
- bplusplus/prepare.py +706 -0
- bplusplus/tracker.py +261 -0
- bplusplus/train.py +913 -0
- bplusplus/validation.py +580 -0
- bplusplus-2.0.4.dist-info/LICENSE +21 -0
- bplusplus-2.0.4.dist-info/METADATA +259 -0
- bplusplus-2.0.4.dist-info/RECORD +12 -0
- bplusplus-2.0.4.dist-info/WHEEL +4 -0
bplusplus/validation.py
ADDED
|
@@ -0,0 +1,580 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from torchvision import transforms, models
|
|
5
|
+
from PIL import Image
|
|
6
|
+
import numpy as np
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from tabulate import tabulate
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
import requests
|
|
11
|
+
import logging
|
|
12
|
+
import sys
|
|
13
|
+
import argparse
|
|
14
|
+
|
|
15
|
+
# Module-wide logging: INFO and above, each record prefixed with a timestamp
# and level name.
# NOTE(review): basicConfig configures the *root* logger as an import-time side
# effect; consider moving it to the CLI entry point if this module is imported
# as a library.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
def validate(species_list, validation_dir, hierarchical_weights, img_size=640, batch_size=32, backbone: str = "resnet50"):
    """
    Validate the hierarchical classifier on a directory of images organized by species.

    Args:
        species_list (list): List of species names used for training
        validation_dir (str): Path to directory containing subdirectories for each species
        hierarchical_weights (str): Path to the hierarchical classifier model file
        img_size (int): Image size for validation (should match training, default: 640)
        batch_size (int): Batch size for processing images (default: 32)
        backbone (str): ResNet backbone to use ('resnet18', 'resnet50', 'resnet101'). Default: 'resnet50'.
            A backbone name recorded in the checkpoint takes precedence over this value.

    Returns:
        dict with ``'predictions'``, ``'labels'`` and ``'image_paths'`` lists, or
        None when no images could be loaded from ``validation_dir``.
    """
    validator = HierarchicalValidator(hierarchical_weights, species_list, img_size, batch_size, backbone)
    results = validator.run(validation_dir)
    if results is None:
        # run() has already logged the reason; do not claim metrics were produced.
        print("\nValidation aborted: no images were loaded, so no metrics were calculated")
        return None
    print("\nValidation complete with metrics calculated at all taxonomic levels")
    return results
|
|
37
|
+
|
|
38
|
+
def cuda_cleanup():
    """Free cached CUDA memory and reset the peak-memory counters.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    # Return cached allocator blocks to the driver, then restart peak tracking.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
|
|
43
|
+
|
|
44
|
+
def setup_gpu():
    """Choose the torch device to run on, preferring the first CUDA GPU.

    Every visible CUDA device is logged with its total memory. A tiny tensor
    operation is then run on ``cuda:0`` as a smoke test. Any exception along the
    way is logged and the CPU is used instead.

    Returns:
        torch.device: ``cuda:0`` when the smoke test succeeds, otherwise ``cpu``.
    """
    cpu = torch.device("cpu")
    if not torch.cuda.is_available():
        logger.warning("CUDA is not available on this system")
        return cpu

    try:
        n_devices = torch.cuda.device_count()
        logger.info(f"Found {n_devices} CUDA device(s)")

        for idx in range(n_devices):
            props = torch.cuda.get_device_properties(idx)
            logger.info(f"GPU {idx}: {props.name} with {props.total_memory / 1e9:.2f} GB memory")

        gpu = torch.device("cuda:0")
        # Smoke test: allocate and compute on the GPU so driver/runtime problems
        # surface here instead of later inside the model code.
        probe = torch.ones(1, device=gpu)
        doubled = probe * 2
        del probe, doubled

        logger.info("CUDA initialization successful")
        return gpu
    except Exception as err:
        logger.error(f"CUDA initialization error: {str(err)}")
        logger.warning("Falling back to CPU")
        return cpu
|
|
69
|
+
|
|
70
|
+
def _resolve_torch_objects(dotted_names):
    """Resolve ``'torch.x.y'``-style names to objects, skipping missing ones and duplicates.

    Args:
        dotted_names (iterable of str): Names rooted at ``torch``, e.g. ``'torch.cuda.FloatStorage'``.

    Returns:
        list: The resolved objects in first-seen order; names this torch build
        does not provide are silently omitted.
    """
    resolved = []
    for dotted in dotted_names:
        obj = torch
        for attr in dotted.split('.')[1:]:
            obj = getattr(obj, attr, None)
            if obj is None:
                break
        if obj is not None and obj not in resolved:
            resolved.append(obj)
    return resolved


# Allow-list legacy tensor/storage classes for restricted (weights-only)
# checkpoint unpickling. ``add_safe_globals`` only exists in newer PyTorch
# releases, hence the ``hasattr`` guard for backwards compatibility.
# The API takes the class/callable objects themselves rather than dotted-name
# strings, so the names are resolved first (the original list also repeated
# 'torch.FloatStorage'; duplicates are dropped by the resolver).
if hasattr(torch.serialization, 'add_safe_globals'):
    torch.serialization.add_safe_globals(_resolve_torch_objects((
        'torch.LongTensor',
        'torch.cuda.LongTensor',
        'torch.FloatStorage',
        'torch.cuda.FloatStorage',
    )))
|
|
79
|
+
|
|
80
|
+
class HierarchicalInsectClassifier(nn.Module):
    """ResNet backbone with one classification head per taxonomic level.

    The backbone's final ``fc`` layer is replaced by ``nn.Identity``. Each head
    maps the backbone features through ``Linear -> ReLU -> Dropout(0.5) -> Linear``
    to that level's class logits.
    """

    def __init__(self, num_classes_per_level, level_to_idx=None, parent_child_relationship=None,
                 backbone: str = "resnet50", pretrained: bool = True):
        """
        Args:
            num_classes_per_level (list): Number of classes for each taxonomic level
            level_to_idx (dict): Mapping from level to label indices (stored on the
                module; not used by ``forward``)
            parent_child_relationship (dict): Parent-child relationships in taxonomy
                (stored on the module; not used by ``forward``)
            backbone (str): ResNet backbone to use ('resnet18', 'resnet50', 'resnet101')
            pretrained (bool): Initialise the backbone with torchvision's default
                ImageNet weights (the default, as before). Pass False when the
                weights will be overwritten by ``load_state_dict`` anyway, to avoid
                an unnecessary weight download.
        """
        super().__init__()

        self.backbone = self._build_backbone(backbone, pretrained=pretrained)
        self.backbone_name = backbone
        backbone_output_features = self.backbone.fc.in_features
        # Drop the ImageNet classifier; the per-level heads consume the features.
        self.backbone.fc = nn.Identity()

        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.Linear(backbone_output_features, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, num_classes),
            )
            for num_classes in num_classes_per_level
        )
        self.num_levels = len(num_classes_per_level)

        # Store the taxonomy mappings
        self.level_to_idx = level_to_idx
        self.parent_child_relationship = parent_child_relationship

        total_classes = sum(num_classes_per_level)
        # Buffers are part of the state_dict, so they are saved/restored with the
        # weights (and must exist for a strict load of such a checkpoint).
        self.register_buffer('class_means', torch.zeros(total_classes))
        self.register_buffer('class_stds', torch.ones(total_classes))
        self.class_counts = [0] * total_classes
        self.output_history = defaultdict(list)

    @staticmethod
    def _build_backbone(backbone: str, pretrained: bool = True):
        """Instantiate a torchvision ResNet by (case-insensitive) name.

        Raises:
            ValueError: If ``backbone`` is not one of the supported names.
        """
        name = backbone.lower()
        if name == "resnet18":
            builder, weights = models.resnet18, models.ResNet18_Weights.DEFAULT
        elif name == "resnet50":
            builder, weights = models.resnet50, models.ResNet50_Weights.DEFAULT
        elif name == "resnet101":
            builder, weights = models.resnet101, models.ResNet101_Weights.DEFAULT
        else:
            raise ValueError(f"Unsupported backbone '{backbone}'. Choose from 'resnet18', 'resnet50', 'resnet101'.")
        return builder(weights=weights if pretrained else None)

    def forward(self, x):
        """Return a list with one logits tensor per taxonomic level, in level order."""
        features = self.backbone(x)
        return [branch(features) for branch in self.branches]
|
|
136
|
+
|
|
137
|
+
def _exit_taxonomy_error(species_name, column_text, status, error_msg):
    """Log and print a failed taxonomy row for ``species_name``, then exit with status 1."""
    logger.error(error_msg)
    print(f"{species_name:<30} {column_text:<20} {column_text:<20} {status}")
    print(f"Error: {error_msg}")
    sys.exit(1)


def get_taxonomy(species_list, timeout=30):
    """
    Retrieves taxonomic information for a list of species from GBIF API.
    Creates a hierarchical taxonomy dictionary with family, genus, and species relationships.

    Names equal to 'unknown' (case-insensitive) are not sent to GBIF. Each is
    mapped to the placeholder genus 'Unknown' in the placeholder family 'Unknown'.
    Any failed lookup prints an error row and terminates the process via
    ``sys.exit(1)``.

    Args:
        species_list (list): Species names, in training order.
        timeout (float): Per-request timeout in seconds for the GBIF call, so an
            unresponsive server aborts the run instead of hanging it (default: 30).

    Returns:
        tuple: ``(taxonomy, species_to_genus, genus_to_family)`` where ``taxonomy``
        is ``{1: sorted list of families, 2: {genus: family}, 3: {species: genus}}``.
    """
    gbif_match_url = "https://api.gbif.org/v1/species/match"
    taxonomy = {1: [], 2: {}, 3: {}}
    species_to_genus = {}
    genus_to_family = {}

    species_list_for_gbif = [s for s in species_list if s.lower() != 'unknown']
    unknown_names = [s for s in species_list if s.lower() == 'unknown']

    logger.info(f"Building taxonomy from GBIF for {len(species_list_for_gbif)} species")

    print("\nTaxonomy Results:")
    print("-" * 80)
    print(f"{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
    print("-" * 80)

    for species_name in species_list_for_gbif:
        try:
            # ``params`` lets requests URL-encode the name (spaces, '&', ...).
            response = requests.get(
                gbif_match_url,
                params={'name': species_name, 'verbose': 'true'},
                timeout=timeout,
            )
            response.raise_for_status()
            data = response.json()
            matched = data.get('status') in ('ACCEPTED', 'SYNONYM')
            family = data.get('family')
            genus = data.get('genus')
        except Exception as e:
            _exit_taxonomy_error(species_name, 'Error', 'FAILED',
                                 f"Error retrieving data for species '{species_name}': {str(e)}")

        if not matched:
            _exit_taxonomy_error(species_name, 'Not found', 'ERROR',
                                 f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF")
        if not (family and genus):
            _exit_taxonomy_error(species_name, 'Not found', 'ERROR',
                                 f"Species '{species_name}' found in GBIF but family and genus not found, could be spelling error in species, check GBIF")

        print(f"{species_name:<30} {family:<20} {genus:<20} OK")

        species_to_genus[species_name] = genus
        genus_to_family[genus] = family
        if family not in taxonomy[1]:
            taxonomy[1].append(family)
        taxonomy[2][genus] = family
        taxonomy[3][species_name] = genus

    if unknown_names:
        unknown_family = "Unknown"
        unknown_genus = "Unknown"

        if unknown_family not in taxonomy[1]:
            taxonomy[1].append(unknown_family)
        taxonomy[2][unknown_genus] = unknown_family
        # Also register the placeholder in the lookup maps, so callers that
        # resolve species via species_to_genus/genus_to_family find it too.
        genus_to_family[unknown_genus] = unknown_family

        for unknown_species in unknown_names:
            taxonomy[3][unknown_species] = unknown_genus
            species_to_genus[unknown_species] = unknown_genus
            print(f"{unknown_species:<30} {unknown_family:<20} {unknown_genus:<20} {'OK'}")

    taxonomy[1] = sorted(set(taxonomy[1]))
    print("-" * 80)

    num_families = len(taxonomy[1])
    num_genera = len(taxonomy[2])
    num_species = len(taxonomy[3])

    print("\nFamily indices:")
    for i, family in enumerate(taxonomy[1]):
        print(f" {i}: {family}")

    print("\nGenus indices:")
    for i, genus in enumerate(sorted(taxonomy[2])):
        print(f" {i}: {genus}")

    print("\nSpecies indices:")
    for i, species in enumerate(species_list):
        print(f" {i}: {species}")

    logger.info(f"Taxonomy built: {num_families} families, {num_genera} genera, {num_species} species")
    return taxonomy, species_to_genus, genus_to_family
|
|
233
|
+
|
|
234
|
+
def create_mappings(taxonomy, species_list=None):
    """Build label<->index lookup tables for every taxonomic level.

    Index order per level:
      * a level stored as a list (the family level) keeps the list's order;
      * the species level (key 3) follows ``species_list`` when one is given;
      * any other dict level is ordered alphabetically by its keys.

    Args:
        taxonomy (dict): ``{level: labels}``, each ``labels`` a list or a dict keyed by label.
        species_list (list, optional): Order to use for the species level.

    Returns:
        tuple: ``(level_to_idx, idx_to_level)``, each shaped ``{level: {...}}``.
    """
    def _index(ordered_labels):
        ordered_labels = list(ordered_labels)
        forward = {name: position for position, name in enumerate(ordered_labels)}
        backward = dict(enumerate(ordered_labels))
        return forward, backward

    level_to_idx, idx_to_level = {}, {}
    for level, labels in taxonomy.items():
        if isinstance(labels, list):
            ordering = labels
        elif level == 3 and species_list is not None:
            ordering = species_list
        else:
            ordering = sorted(labels)
        level_to_idx[level], idx_to_level[level] = _index(ordering)

    return level_to_idx, idx_to_level
|
|
256
|
+
|
|
257
|
+
class HierarchicalValidator:
    """Load a hierarchical checkpoint and report family/genus/species metrics on a labelled image folder."""

    def __init__(self, hierarchical_model_path, species_names, img_size=640, batch_size=32, backbone: str = "resnet50"):
        """
        Args:
            hierarchical_model_path (str): Path to the saved checkpoint.
            species_names (list): Species names; one sub-directory per name is
                expected under the validation directory.
            img_size (int): Images are resized to ``img_size x img_size``.
            batch_size (int): Number of images per forward pass.
            backbone (str): Backbone name, used unless a wrapped checkpoint records its own.

        Raises:
            TypeError: If the taxonomy is not a mapping of level -> classes.
        """
        cuda_cleanup()

        self.device = setup_gpu()
        logger.info(f"Using device: {self.device}")
        print(f"Using device: {self.device}")

        self.species_names = species_names
        self.img_size = img_size
        self.batch_size = batch_size
        self.backbone = backbone

        logger.info(f"Loading model from {hierarchical_model_path}")
        try:
            # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
            checkpoint = torch.load(hierarchical_model_path, map_location='cpu')
            logger.info("Model loaded to CPU successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

        # New-format checkpoints wrap the weights together with metadata.
        is_wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint

        if is_wrapped:
            state_dict = checkpoint["model_state_dict"]
            # A backbone recorded in the checkpoint overrides the caller's choice.
            self.backbone = checkpoint.get("backbone", backbone)

            if "taxonomy" in checkpoint:
                print("Using taxonomy from saved model")
                taxonomy = checkpoint["taxonomy"]
                if "species_list" in checkpoint:
                    print(f"Saved model was trained on: {', '.join(checkpoint['species_list'])}")

                if "level_to_idx" in checkpoint:
                    level_to_idx = checkpoint["level_to_idx"]
                    idx_to_level = {
                        level: {idx: label for label, idx in label_dict.items()}
                        for level, label_dict in level_to_idx.items()
                    }
                else:
                    print("Warning: No level_to_idx in checkpoint, creating from taxonomy")
                    level_to_idx, idx_to_level = create_mappings(taxonomy, species_names)
                species_to_genus = dict(taxonomy[3])
                genus_to_family = dict(taxonomy[2])
            else:
                print("No taxonomy in checkpoint, fetching from GBIF")
                taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
                level_to_idx, idx_to_level = create_mappings(taxonomy, species_names)
        else:
            # Old format: the file is the bare state_dict; keep the user-provided backbone.
            state_dict = checkpoint
            taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
            level_to_idx, idx_to_level = create_mappings(taxonomy, species_names)

        self.level_to_idx = level_to_idx
        self.idx_to_level = idx_to_level
        self.taxonomy = taxonomy
        self.species_to_genus = species_to_genus
        self.genus_to_family = genus_to_family

        if not hasattr(taxonomy, "items"):
            raise TypeError(
                f"Unsupported taxonomy format: expected a mapping of level -> classes, got {type(taxonomy).__name__}")
        num_classes_per_level = [len(classes) for classes in taxonomy.values()]
        print(f"Using model with class counts: {num_classes_per_level}")

        # Initialize and load model
        self.model = HierarchicalInsectClassifier(
            num_classes_per_level,
            level_to_idx=level_to_idx,
            parent_child_relationship=checkpoint.get("parent_child_relationship", None) if is_wrapped else None,
            backbone=self.backbone,
        )

        try:
            self.model.load_state_dict(state_dict)
            print("Model weights loaded successfully")
        except Exception as e:
            print(f"Error loading model weights: {e}")
            print("Attempting to load with strict=False...")
            incompatible = self.model.load_state_dict(state_dict, strict=False)
            print("Model weights loaded with strict=False")
            # Make it visible which weights were NOT restored from the checkpoint.
            if incompatible.missing_keys:
                logger.warning(f"Missing keys (left at initial values): {incompatible.missing_keys}")
            if incompatible.unexpected_keys:
                logger.warning(f"Unexpected keys (ignored): {incompatible.unexpected_keys}")

        try:
            self.model.to(self.device)
            print(f"Model successfully transferred to {self.device}")
        except RuntimeError as e:
            logger.error(f"Error transferring model to {self.device}: {e}")
            print(f"Error transferring model to {self.device}, falling back to CPU")
            self.device = torch.device("cpu")
            # A failed transfer can leave some parameters on the GPU; move
            # everything to the fallback device so model and inputs agree.
            self.model.to(self.device)

        self.model.eval()

        # Transform for validation (same as training validation transform)
        self.transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print("Model successfully loaded")
        print(f"Using image size: {self.img_size}x{self.img_size}")
        print(f"Using species: {', '.join(species_names)}")

    def _species_index(self, species_name):
        """Return the model's species-level output index for ``species_name``, or None if unmapped.

        Uses ``level_to_idx[3]`` (the model's own mapping). Only when no species-level
        mapping exists does it fall back to the position in ``species_names``.
        """
        species_map = (self.level_to_idx or {}).get(3)
        if species_map is None:
            return self.species_names.index(species_name) if species_name in self.species_names else None
        return species_map.get(species_name)

    def _label_names(self, level):
        """Return display names ordered by class index for taxonomic ``level`` (1=family, 2=genus, 3=species)."""
        mapping = (self.idx_to_level or {}).get(level)
        if mapping:
            return [mapping[i] for i in sorted(mapping)]
        return list(self.species_names) if level == 3 else []

    def load_images_from_directory(self, validation_dir):
        """Load all images from subdirectories organized by species.

        Only files ending in .jpg/.jpeg/.png (case-insensitive) are read, in sorted
        order, and decoded to RGB in memory. The following are skipped with a warning:
        species without a directory, species without taxonomy, species absent from
        the model's species mapping, and unreadable images.

        Returns:
            tuple: ``(images, labels, image_paths)`` as aligned lists, where each
            label is ``[family_idx, genus_idx, species_idx]`` in the model's index space.
        """
        images, labels, image_paths = [], [], []

        print(f"\nLoading images from {validation_dir}")

        for species_name in self.species_names:
            species_dir = os.path.join(validation_dir, species_name)
            if not os.path.exists(species_dir):
                logger.warning(f"Directory not found for species: {species_name}")
                continue

            if species_name not in self.species_to_genus:
                logger.warning(f"Taxonomy not found for species: {species_name}")
                continue
            genus_name = self.species_to_genus[species_name]
            family_name = self.genus_to_family[genus_name]
            genus_idx = self.level_to_idx[2][genus_name]
            family_idx = self.level_to_idx[1][family_name]

            # Use the model's species index (from the checkpoint mapping when present)
            # rather than the position in the user-supplied list, so labels remain
            # correct even if the species were given in a different order.
            species_idx = self._species_index(species_name)
            if species_idx is None:
                logger.warning(f"Species not in model's species mapping: {species_name}")
                continue

            image_files = sorted(
                f for f in os.listdir(species_dir)
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            )

            loaded = 0
            for img_file in image_files:
                img_path = os.path.join(species_dir, img_file)
                try:
                    # The context manager closes the file once the RGB copy exists.
                    with Image.open(img_path) as img:
                        rgb = img.convert('RGB')
                except Exception as e:
                    logger.warning(f"Error loading image {img_path}: {e}")
                    continue
                images.append(rgb)
                labels.append([family_idx, genus_idx, species_idx])
                image_paths.append(img_path)
                loaded += 1

            print(f" {species_name}: {loaded} images")

        print(f"\nTotal images loaded: {len(images)}")
        return images, labels, image_paths

    def predict_images(self, images):
        """Run the model over ``images`` in batches.

        Returns:
            list: One ``[family_idx, genus_idx, species_idx]`` list (argmax per level) per image.
        """
        predictions = []

        for start in tqdm(range(0, len(images), self.batch_size), desc="Processing batches"):
            batch_images = images[start:start + self.batch_size]
            batch_tensor = torch.stack([self.transform(img) for img in batch_images]).to(self.device)

            with torch.no_grad():
                outputs = self.model(batch_tensor)

            # One argmax per level over the whole batch and a single device->host
            # copy, instead of an .item() synchronisation per image per level.
            batch_preds = torch.stack([level_output.argmax(dim=1) for level_output in outputs], dim=1)
            predictions.extend(batch_preds.cpu().tolist())

        return predictions

    def run(self, validation_dir):
        """Load images, predict them and print metrics at every taxonomic level.

        Returns:
            dict with aligned ``'predictions'``, ``'labels'`` and ``'image_paths'``
            lists, or None if no images were found.
        """
        images, labels, image_paths = self.load_images_from_directory(validation_dir)

        if not images:
            logger.error("No images found in validation directory")
            return None

        print("\nRunning predictions...")
        predictions = self.predict_images(images)

        print("\n" + "="*80)
        print("VALIDATION RESULTS")
        print("="*80)
        self.calculate_metrics(predictions, labels)

        return {
            'predictions': predictions,
            'labels': labels,
            'image_paths': image_paths
        }

    def calculate_metrics(self, predictions, labels):
        """Print a classification report for the family, genus and species columns."""
        predictions = np.asarray(predictions)
        labels = np.asarray(labels)

        for column, level_name in enumerate(['Family', 'Genus', 'Species']):
            print(f"\n{'='*80}")
            print(f"{level_name}-level Metrics")
            print(f"{'='*80}")

            # Taxonomy levels are keyed 1..3 while the columns are 0..2.
            self.print_classification_report(
                predictions[:, column], labels[:, column], self._label_names(column + 1))

    @staticmethod
    def _per_class_metrics(predictions, labels):
        """Compute per-class precision, recall, F1 and support.

        Args:
            predictions: 1-D sequence of predicted class indices.
            labels: 1-D sequence of ground-truth class indices, aligned with ``predictions``.

        Returns:
            list of ``(label_idx, precision, recall, f1, support)`` tuples, one per
            class present in ``labels``, sorted by index. A 0/0 ratio is reported as 0.0.
        """
        predictions = np.asarray(predictions)
        labels = np.asarray(labels)
        rows = []
        for label_idx in np.unique(labels):
            is_class = labels == label_idx
            predicted_as_class = predictions == label_idx
            support = int(is_class.sum())
            tp = int((is_class & predicted_as_class).sum())
            fp = int((predicted_as_class & ~is_class).sum())
            fn = support - tp
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
            rows.append((int(label_idx), precision, recall, f1, support))
        return rows

    def print_classification_report(self, predictions, labels, label_names):
        """Print per-class precision/recall/F1/support, macro/weighted averages and overall accuracy.

        Args:
            predictions: Predicted class indices for one taxonomic level.
            labels: Ground-truth class indices, aligned with ``predictions``.
            label_names (list): Display name per class index; indices past the end
                are shown as ``Label <idx>``.
        """
        predictions = np.asarray(predictions)
        labels = np.asarray(labels)
        rows = self._per_class_metrics(predictions, labels)
        if not rows:
            print("No samples to report")
            return

        table_data = [
            [label_names[idx] if idx < len(label_names) else f"Label {idx}",
             f"{precision:.4f}", f"{recall:.4f}", f"{f1:.4f}", support]
            for idx, precision, recall, f1, support in rows
        ]

        # Average the unrounded values (the table strings are rounded to 4 d.p.).
        precisions = np.array([row[1] for row in rows], dtype=float)
        recalls = np.array([row[2] for row in rows], dtype=float)
        f1s = np.array([row[3] for row in rows], dtype=float)
        supports = np.array([row[4] for row in rows], dtype=float)
        total_samples = int(supports.sum())

        # Every sample's label is one of the reported classes, so this equals
        # the sum of true positives divided by the number of samples.
        overall_accuracy = float((predictions == labels).mean())

        macro_precision = precisions.mean()
        macro_recall = recalls.mean()
        macro_f1 = f1s.mean()

        weighted_precision = (precisions * supports).sum() / total_samples
        weighted_recall = (recalls * supports).sum() / total_samples
        weighted_f1 = (f1s * supports).sum() / total_samples

        table_data.append([])
        table_data.append(["Macro avg", f"{macro_precision:.4f}", f"{macro_recall:.4f}", f"{macro_f1:.4f}", total_samples])
        table_data.append(["Weighted avg", f"{weighted_precision:.4f}", f"{weighted_recall:.4f}", f"{weighted_f1:.4f}", total_samples])
        table_data.append([])
        table_data.append(["Overall Accuracy", "", "", f"{overall_accuracy:.4f}", total_samples])

        headers = ["Label", "Precision", "Recall", "F1 Score", "Support"]
        print(tabulate(table_data, headers=headers, tablefmt="grid"))
|
|
554
|
+
|
|
555
|
+
def _parse_args(argv=None):
    """Parse the validation script's command-line options.

    Args:
        argv (list, optional): Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``validation_dir``, ``weights``, ``species``,
        ``img_size``, ``batch_size`` and ``backbone``.
    """
    parser = argparse.ArgumentParser(description='Validate hierarchical insect classifier')
    parser.add_argument('--validation_dir', type=str, required=True,
                        help='Path to validation directory with subdirectories for each species')
    parser.add_argument('--weights', type=str, required=True,
                        help='Path to hierarchical model weights')
    parser.add_argument('--species', type=str, nargs='+', required=True,
                        help='List of species names (must match training order)')
    parser.add_argument('--img_size', type=int, default=640,
                        help='Image size for validation (should match training, default: 640)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size for processing (default: 32)')
    parser.add_argument('--backbone', type=str, default='resnet50', choices=['resnet18', 'resnet50', 'resnet101'],
                        help='ResNet backbone to use (default: resnet50)')
    return parser.parse_args(argv)


def _main(argv=None):
    """Command-line entry point.

    Returns:
        int: Process exit status. It is 0 on success and 1 when ``validate``
        returned None (no images were validated), so shell scripts and CI can
        detect the failure.
    """
    args = _parse_args(argv)
    results = validate(
        species_list=args.species,
        validation_dir=args.validation_dir,
        hierarchical_weights=args.weights,
        img_size=args.img_size,
        batch_size=args.batch_size,
        backbone=args.backbone
    )
    return 0 if results is not None else 1


if __name__ == "__main__":
    sys.exit(_main())
|
|
580
|
+
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Titus Venverloo
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|