bplusplus 2.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bplusplus/__init__.py +15 -0
- bplusplus/collect.py +523 -0
- bplusplus/detector.py +376 -0
- bplusplus/inference.py +1337 -0
- bplusplus/prepare.py +706 -0
- bplusplus/tracker.py +261 -0
- bplusplus/train.py +913 -0
- bplusplus/validation.py +580 -0
- bplusplus-2.0.4.dist-info/LICENSE +21 -0
- bplusplus-2.0.4.dist-info/METADATA +259 -0
- bplusplus-2.0.4.dist-info/RECORD +12 -0
- bplusplus-2.0.4.dist-info/WHEEL +4 -0
bplusplus/train.py
ADDED
|
@@ -0,0 +1,913 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.optim as optim
|
|
4
|
+
import torchvision.models as models
|
|
5
|
+
import torchvision.transforms as transforms
|
|
6
|
+
from torch.utils.data import Dataset, DataLoader
|
|
7
|
+
from PIL import Image
|
|
8
|
+
import os
|
|
9
|
+
import numpy as np
|
|
10
|
+
from collections import defaultdict
|
|
11
|
+
import requests
|
|
12
|
+
import time
|
|
13
|
+
import logging
|
|
14
|
+
from tqdm import tqdm
|
|
15
|
+
import sys
|
|
16
|
+
|
|
17
|
+
def train(batch_size=4, epochs=30, patience=3, img_size=640, data_dir='input', output_dir='./output', species_list=None, num_workers=4, train_transforms=None, backbone: str = "resnet50"):
    """
    Main function to run the entire training pipeline.
    Sets up datasets, model, training process and handles errors.

    Args:
        batch_size (int): Number of samples per batch. Default: 4
        epochs (int): Maximum number of training epochs. Default: 30
        patience (int): Early stopping patience (epochs without improvement). Default: 3
        img_size (int): Target image size for training. Default: 640
        data_dir (str): Directory containing train/valid subdirectories. Default: 'input'
        output_dir (str): Directory to save trained model and logs. Default: './output'
        species_list (list): List of species names for training. Required.
        num_workers (int): Number of DataLoader worker processes.
            Set to 0 to disable multiprocessing (most stable). Default: 4
        train_transforms: Optional custom torchvision transforms for training data.
        backbone (str): ResNet backbone to use ('resnet18', 'resnet50', 'resnet101'). Default: 'resnet50'

    Returns:
        tuple: (trained_model, taxonomy, level_to_idx, parent_child_relationship)

    Raises:
        ValueError: If any species in species_list has no training directory.
        Exception: Re-raised after logging if model setup or training fails.
    """
    # Module-level state shared with the other functions in this file
    # (e.g. get_taxonomy/InsectDataset read `logger`; training code uses `device`).
    global logger, device

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {epochs}, Patience: {patience}, Image size: {img_size}, Data directory: {data_dir}, Output directory: {output_dir}, Num workers: {num_workers}, Backbone: {backbone}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Fixed seeds for reproducible runs.
    torch.manual_seed(42)
    np.random.seed(42)

    learning_rate = 1.0e-4

    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'valid')

    os.makedirs(output_dir, exist_ok=True)

    # Fail fast if any requested species has no training data directory.
    missing_species = []
    for species in species_list:
        species_dir = os.path.join(train_dir, species)
        if not os.path.isdir(species_dir):
            missing_species.append(species)

    if missing_species:
        raise ValueError(f"The following species directories were not found: {missing_species}")

    logger.info(f"Using {len(species_list)} species in the specified order")

    # Resolve the family/genus hierarchy for the requested species via GBIF.
    taxonomy = get_taxonomy(species_list)

    level_to_idx, parent_child_relationship = create_mappings(taxonomy, species_list)

    # Level 1 (family) is a list; levels 2/3 are child->parent dicts.
    num_classes_per_level = [len(taxonomy[level]) if isinstance(taxonomy[level], list)
                             else len(taxonomy[level].keys()) for level in sorted(taxonomy.keys())]

    def _has_images(path):
        # True if any file anywhere under `path` has a known image extension.
        for _, _, files in os.walk(path):
            if any(f.lower().endswith(('.jpg', '.jpeg', '.png')) for f in files):
                return True
        return False

    train_dataset = InsectDataset(
        root_dir=train_dir,
        transform=train_transforms or get_transforms(is_training=True, img_size=img_size),
        taxonomy=taxonomy,
        level_to_idx=level_to_idx
    )

    # Analyze class balance and warn about imbalances
    balance_stats = analyze_class_balance(train_dataset, taxonomy, level_to_idx)

    # Log balance summary
    for level, stats in balance_stats.items():
        level_name = {1: "Family", 2: "Genus", 3: "Species"}[level]
        logger.info(f"{level_name} balance: {stats['severity']} (ratio: {stats['imbalance_ratio']:.1f}x, CV: {stats['cv']:.1f}%)")

    # Validation is optional: training proceeds without it when the 'valid'
    # directory is absent or holds no images.
    validation_available = os.path.isdir(val_dir) and _has_images(val_dir)
    if not validation_available:
        logger.warning("Validation skipped: 'valid' directory missing or contains no images.")
        val_loader = None
    else:
        val_dataset = InsectDataset(
            root_dir=val_dir,
            transform=get_transforms(is_training=False, img_size=img_size),
            taxonomy=taxonomy,
            level_to_idx=level_to_idx
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    try:
        logger.info("Initializing model...")
        model = HierarchicalInsectClassifier(
            num_classes_per_level=num_classes_per_level,
            level_to_idx=level_to_idx,
            parent_child_relationship=parent_child_relationship,
            backbone=backbone
        )
        logger.info(f"Model structure initialized with {sum(p.numel() for p in model.parameters())} parameters")

        logger.info("Setting up loss function and optimizer...")
        criterion = HierarchicalLoss(
            alpha=0.5,
            level_to_idx=level_to_idx,
            parent_child_relationship=parent_child_relationship
        )
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Smoke-test the forward pass before committing to a full training run.
        logger.info("Testing model with a dummy input...")
        model.to(device)
        dummy_input = torch.randn(1, 3, img_size, img_size).to(device)
        with torch.no_grad():
            try:
                test_outputs = model(dummy_input)
                logger.info(f"Forward pass test successful, output shapes: {[out.shape for out in test_outputs]}")
            except Exception as e:
                logger.error(f"Forward pass test failed: {str(e)}")
                raise

        logger.info("Starting training process...")
        best_model_path = os.path.join(output_dir, 'best_multitask.pt')

        # train_model is defined elsewhere in this module (outside this view).
        trained_model = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            level_to_idx=level_to_idx,
            parent_child_relationship=parent_child_relationship,
            taxonomy=taxonomy,
            species_list=species_list,
            num_epochs=epochs,
            patience=patience,
            best_model_path=best_model_path,
            backbone=backbone
        )

        logger.info("Model saved successfully with taxonomy information!")
        print("Model saved successfully with taxonomy information!")

        return trained_model, taxonomy, level_to_idx, parent_child_relationship

    except Exception as e:
        logger.error(f"Critical error during model setup or training: {str(e)}")
        logger.exception("Stack trace:")
        raise
|
|
175
|
+
|
|
176
|
+
def analyze_class_balance(dataset, taxonomy, level_to_idx):
    """
    Analyze class balance across all taxonomic levels and warn about imbalances.

    Prints a per-level report (histogram bars, statistics, warnings) to stdout
    and returns the computed statistics.

    Args:
        dataset: InsectDataset instance (its `samples` list is read directly)
        taxonomy: Taxonomy dictionary (unused here beyond signature symmetry)
        level_to_idx: Level to index mapping (unused here beyond signature symmetry)

    Returns:
        dict: Per-level stats keyed by level (1=family, 2=genus, 3=species),
            each with 'counts', 'min', 'max', 'mean', 'std',
            'imbalance_ratio', 'cv', and 'severity'.
    """
    # Count samples per class at each level
    level_names = {1: "Family", 2: "Genus", 3: "Species"}
    level_counts = {1: defaultdict(int), 2: defaultdict(int), 3: defaultdict(int)}

    # Each sample carries its full label path [family, genus, species].
    for sample in dataset.samples:
        family, genus, species = sample['labels']
        level_counts[1][family] += 1
        level_counts[2][genus] += 1
        level_counts[3][species] += 1

    # Analyze each level
    balance_stats = {}
    has_severe_imbalance = False

    print("\n" + "=" * 80)
    print("CLASS BALANCE ANALYSIS")
    print("=" * 80)

    for level in [1, 2, 3]:
        counts = level_counts[level]
        if not counts:
            continue

        level_name = level_names[level]
        values = list(counts.values())
        classes = list(counts.keys())

        total = sum(values)
        min_count = min(values)
        max_count = max(values)
        mean_count = np.mean(values)
        std_count = np.std(values)

        # Imbalance ratio: how many times larger is the biggest class vs smallest
        imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')

        # Coefficient of variation (CV): std/mean - higher means more imbalanced
        cv = (std_count / mean_count) * 100 if mean_count > 0 else 0

        # Find underrepresented and overrepresented classes
        underrepresented = [(c, n) for c, n in counts.items() if n < mean_count * 0.5]
        overrepresented = [(c, n) for c, n in counts.items() if n > mean_count * 2]

        # Determine severity from the max/min ratio only.
        if imbalance_ratio > 10:
            severity = "SEVERE"
            has_severe_imbalance = True
        elif imbalance_ratio > 5:
            severity = "MODERATE"
        elif imbalance_ratio > 2:
            severity = "MILD"
        else:
            severity = "BALANCED"

        balance_stats[level] = {
            'counts': dict(counts),
            'min': min_count,
            'max': max_count,
            'mean': mean_count,
            'std': std_count,
            'imbalance_ratio': imbalance_ratio,
            'cv': cv,
            'severity': severity,
        }

        # Print level summary
        print(f"\n{level_name.upper()} LEVEL ({len(classes)} classes, {total} total samples)")
        print("-" * 60)

        # Sort by count for display
        sorted_counts = sorted(counts.items(), key=lambda x: x[1], reverse=True)

        # Calculate ideal count for visual bar
        ideal_count = total / len(classes)

        for class_name, count in sorted_counts:
            pct = (count / total) * 100
            # Bar length scaled relative to the largest class (30 chars wide).
            bar_len = int((count / max_count) * 30)
            bar = "█" * bar_len + "░" * (30 - bar_len)

            # Mark under/over represented
            if count < mean_count * 0.5:
                marker = " ⚠️ LOW"
            elif count > mean_count * 2:
                marker = " ⚠️ HIGH"
            else:
                marker = ""

            print(f" {class_name:<30} {bar} {count:>5} ({pct:>5.1f}%){marker}")

        print(f"\n Statistics:")
        print(f" Min: {min_count}, Max: {max_count}, Mean: {mean_count:.1f}, Std: {std_count:.1f}")
        print(f" Imbalance ratio: {imbalance_ratio:.1f}x (max/min)")
        print(f" Coefficient of variation: {cv:.1f}%")
        print(f" Balance status: {severity}")

        # Detailed warnings
        if severity in ["SEVERE", "MODERATE"]:
            print(f"\n ⚠️ WARNING: {severity} class imbalance detected at {level_name} level!")

        if underrepresented:
            print(f"\n Underrepresented classes (< 50% of mean):")
            for c, n in sorted(underrepresented, key=lambda x: x[1]):
                deficit = int(mean_count - n)
                print(f" • {c}: {n} samples (need ~{deficit} more for balance)")

        if overrepresented:
            print(f"\n Overrepresented classes (> 200% of mean):")
            for c, n in sorted(overrepresented, key=lambda x: x[1], reverse=True):
                excess = int(n - mean_count)
                print(f" • {c}: {n} samples ({excess} above mean)")

    print("\n" + "=" * 80)

    # Global verdict: any level with >10x imbalance triggers the full advisory.
    if has_severe_imbalance:
        print("\n⚠️ SEVERE CLASS IMBALANCE DETECTED!")
        print("-" * 60)
        print("This can cause the following problems:")
        print(" 1. BIASED PREDICTIONS: Model will favor majority classes")
        print(" 2. POOR MINORITY RECALL: Rare classes may be ignored entirely")
        print(" 3. MISLEADING ACCURACY: High accuracy doesn't mean good performance")
        print(" 4. UNSTABLE TRAINING: Loss may be dominated by majority classes")
        print("\nRecommended actions:")
        print(" • Collect more images for underrepresented classes")
        print(" • Use data augmentation more aggressively on minority classes")
        print(" • Consider weighted loss functions (not yet implemented)")
        print(" • Consider oversampling minority / undersampling majority classes")
        print(" • Evaluate using per-class metrics (precision, recall, F1)")
        print("-" * 60)
    else:
        print("\n✓ Class balance is acceptable for training.")

    print("=" * 80 + "\n")

    return balance_stats
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def get_taxonomy(species_list):
    """
    Retrieve taxonomic information for a list of species from the GBIF API.

    Builds a hierarchical taxonomy dictionary:
        {1: [family, ...],          # sorted, de-duplicated family names
         2: {genus: family, ...},   # genus  -> parent family
         3: {species: genus, ...}}  # species -> parent genus

    A species named 'unknown' (case-insensitive) is mapped to a synthetic
    'Unknown' family/genus without querying GBIF.

    Args:
        species_list (list[str]): Scientific species names to resolve.

    Returns:
        dict: The taxonomy structure described above.

    Note:
        On any lookup failure the process terminates via sys.exit(1),
        preserving the original behavior callers may rely on.
    """
    taxonomy = {1: [], 2: {}, 3: {}}

    species_list_for_gbif = [s for s in species_list if s.lower() != 'unknown']
    has_unknown = len(species_list_for_gbif) != len(species_list)

    logger.info(f"Building taxonomy from GBIF for {len(species_list_for_gbif)} species")

    print("\nTaxonomy Results:")
    print("-" * 80)
    print(f"{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
    print("-" * 80)

    for species_name in species_list_for_gbif:
        try:
            # Let requests encode the query string (species names contain
            # spaces) and bound the wait time; the previous hand-built URL
            # had no timeout and could hang indefinitely on a stalled
            # connection, and did not percent-encode the name.
            response = requests.get(
                "https://api.gbif.org/v1/species/match",
                params={"name": species_name, "verbose": "true"},
                timeout=30,
            )
            # Surface HTTP errors (5xx/4xx) as lookup failures instead of
            # trying to parse an error body as JSON.
            response.raise_for_status()
            data = response.json()

            if data.get('status') in ('ACCEPTED', 'SYNONYM'):
                family = data.get('family')
                genus = data.get('genus')

                if family and genus:
                    status = "OK"

                    print(f"{species_name:<30} {family:<20} {genus:<20} {status}")

                    if family not in taxonomy[1]:
                        taxonomy[1].append(family)

                    taxonomy[2][genus] = family
                    taxonomy[3][species_name] = genus
                else:
                    error_msg = f"Species '{species_name}' found in GBIF but family and genus not found, could be spelling error in species, check GBIF"
                    logger.error(error_msg)
                    print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                    print(f"Error: {error_msg}")
                    sys.exit(1)  # Stop the script
            else:
                error_msg = f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF"
                logger.error(error_msg)
                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                print(f"Error: {error_msg}")
                sys.exit(1)  # Stop the script

        # SystemExit is a BaseException, so the sys.exit(1) calls above are
        # not swallowed by this handler.
        except Exception as e:
            error_msg = f"Error retrieving data for species '{species_name}': {str(e)}"
            logger.error(error_msg)
            print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
            print(f"Error: {error_msg}")
            sys.exit(1)  # Stop the script

    if has_unknown:
        # Synthetic bucket so an 'unknown' class can participate in training.
        unknown_family = "Unknown"
        unknown_genus = "Unknown"
        unknown_species = "unknown"

        if unknown_family not in taxonomy[1]:
            taxonomy[1].append(unknown_family)

        taxonomy[2][unknown_genus] = unknown_family
        taxonomy[3][unknown_species] = unknown_genus

        print(f"{unknown_species:<30} {unknown_family:<20} {unknown_genus:<20} {'OK'}")

    # Sort families for a stable, deterministic index order.
    taxonomy[1] = sorted(list(set(taxonomy[1])))
    print("-" * 80)

    num_families = len(taxonomy[1])
    num_genera = len(taxonomy[2])
    num_species = len(taxonomy[3])

    print("\nFamily indices:")
    for i, family in enumerate(taxonomy[1]):
        print(f" {i}: {family}")

    print("\nGenus indices:")
    for i, genus in enumerate(sorted(taxonomy[2].keys())):
        print(f" {i}: {genus}")

    print("\nSpecies indices:")
    for i, species in enumerate(species_list):
        print(f" {i}: {species}")

    logger.info(f"Taxonomy built: {num_families} families, {num_genera} genera, {num_species} species")
    return taxonomy
|
|
421
|
+
|
|
422
|
+
def get_species_from_directory(train_dir):
    """
    Extracts a list of species names from subdirectories in the training directory.
    Returns a sorted list of species names found.

    Raises:
        ValueError: If train_dir does not exist or contains no subdirectories.
    """
    if not os.path.exists(train_dir):
        raise ValueError(f"Training directory does not exist: {train_dir}")

    # Every immediate subdirectory is treated as one species class.
    species_list = sorted(
        entry
        for entry in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, entry))
    )

    if not species_list:
        raise ValueError(f"No species subdirectories found in {train_dir}")

    logger.info(f"Found {len(species_list)} species in {train_dir}")
    return species_list
|
|
443
|
+
|
|
444
|
+
def create_mappings(taxonomy, species_list=None):
    """
    Creates mapping dictionaries from taxonomy data.

    Returns:
        tuple: (level_to_idx, parent_child_relationship) where
            level_to_idx maps level -> {label: index}, and
            parent_child_relationship maps (level, parent) -> [children].
    """
    level_to_idx = {}
    parent_child_relationship = {}

    for level, labels in taxonomy.items():
        if isinstance(labels, list):
            # Level 1 (family): list already carries its final order.
            level_to_idx[level] = dict(zip(labels, range(len(labels))))
            continue

        # Levels 2 and 3 are child -> parent dicts.
        if level == 3 and species_list is not None:
            # Species indices follow the caller-supplied ordering.
            ordering = species_list
        else:
            # Genus (and species fallback) are indexed alphabetically.
            ordering = sorted(labels.keys())
        level_to_idx[level] = dict(zip(ordering, range(len(ordering))))

        # Invert the child->parent dict into (level, parent) -> children.
        for child, parent in labels.items():
            parent_child_relationship.setdefault((level, parent), []).append(child)

    return level_to_idx, parent_child_relationship
|
|
471
|
+
|
|
472
|
+
class InsectDataset(Dataset):
    """
    PyTorch dataset for loading and processing insect images.

    Walks `root_dir`, where each subdirectory is named after a species, and
    collects every readable .jpg/.jpeg/.png image together with its full
    taxonomic label path [family, genus, species]. Images that cannot be
    opened/decoded are skipped (with a warning), as are species directories
    not present in the supplied taxonomy.
    """

    def __init__(self, root_dir, transform=None, taxonomy=None, level_to_idx=None):
        """
        Args:
            root_dir (str): Directory whose subdirectories are species names.
            transform: Optional torchvision transform applied in __getitem__.
            taxonomy (dict): {1: [families], 2: {genus: family}, 3: {species: genus}}.
            level_to_idx (dict): Maps level -> {label: index} for label encoding.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.taxonomy = taxonomy
        self.level_to_idx = level_to_idx
        self.samples = []

        # Flatten the taxonomy into direct lookup maps.
        species_to_genus = {species: genus for species, genus in taxonomy[3].items()}
        genus_to_family = {genus: family for genus, family in taxonomy[2].items()}

        for species_name in os.listdir(root_dir):
            species_path = os.path.join(root_dir, species_name)
            if os.path.isdir(species_path):
                if species_name in species_to_genus:
                    genus_name = species_to_genus[species_name]
                    family_name = genus_to_family[genus_name]

                    for img_file in os.listdir(species_path):
                        if img_file.endswith(('.jpg', '.png', '.jpeg')):
                            img_path = os.path.join(species_path, img_file)
                            # Validate the image can be opened
                            try:
                                with Image.open(img_path) as img:
                                    img.convert('RGB')
                                # Only add valid images to samples
                                self.samples.append({
                                    'image_path': img_path,
                                    'labels': [family_name, genus_name, species_name]
                                })

                            except Exception as e:
                                # Corrupt/truncated files are dropped, not fatal.
                                logger.warning(f"Skipping invalid image: {img_path} - Error: {str(e)}")
                else:
                    logger.warning(f"Warning: Species '{species_name}' not found in taxonomy. Skipping.")

        # Log statistics about valid/invalid images
        logger.info(f"Dataset loaded with {len(self.samples)} valid images")

    def __len__(self):
        # Number of valid images discovered at construction time.
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (image, label_indices) where label_indices is a tensor of
        [family_idx, genus_idx, species_idx] for the sample at `idx`."""
        sample = self.samples[idx]
        image = Image.open(sample['image_path']).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # `level` enumerates 0..2; level_to_idx is keyed 1..3, hence level+1.
        label_indices = [self.level_to_idx[level+1][label] for level, label in enumerate(sample['labels'])]

        return image, torch.tensor(label_indices)
|
|
528
|
+
|
|
529
|
+
class HierarchicalInsectClassifier(nn.Module):
    """
    Deep learning model for hierarchical insect classification.

    Uses a ResNet backbone as a shared feature extractor with one
    classification branch ("head") per taxonomic level (family, genus,
    species). Includes a simple logit-statistics-based anomaly detection
    mechanism that can flag low-confidence ("unsure") predictions.
    """

    def __init__(self, num_classes_per_level, level_to_idx=None, parent_child_relationship=None, backbone: str = "resnet50"):
        """
        Args:
            num_classes_per_level (list[int]): Class count per taxonomic level,
                ordered family -> genus -> species.
            level_to_idx (dict): Maps level -> {label: index}; optional.
            parent_child_relationship (dict): Maps (level, parent) -> [children]; optional.
            backbone (str): One of 'resnet18', 'resnet50', 'resnet101'.
        """
        super(HierarchicalInsectClassifier, self).__init__()

        self.backbone = self._build_backbone(backbone)
        self.backbone_name = backbone
        # Turn the ResNet into a pure feature extractor: remember the fc
        # input width, then replace the classification head with a pass-through.
        backbone_output_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # One independent MLP head per taxonomic level.
        self.branches = nn.ModuleList()
        for num_classes in num_classes_per_level:
            branch = nn.Sequential(
                nn.Linear(backbone_output_features, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, num_classes)
            )
            self.branches.append(branch)

        self.num_levels = len(num_classes_per_level)

        # Store the taxonomy mappings
        self.level_to_idx = level_to_idx
        self.parent_child_relationship = parent_child_relationship

        # Per-class logit statistics for anomaly detection; buffers persist
        # in the state dict (unlike class_counts / output_history).
        self.register_buffer('class_means', torch.zeros(sum(num_classes_per_level)))
        self.register_buffer('class_stds', torch.ones(sum(num_classes_per_level)))
        self.class_counts = [0] * sum(num_classes_per_level)
        # NOTE(review): output_history grows without bound while training —
        # consider capping its length; confirm intended.
        self.output_history = defaultdict(list)

    @staticmethod
    def _build_backbone(backbone: str):
        """Return a pretrained torchvision ResNet matching `backbone`.

        Raises:
            ValueError: If the name is not one of the supported ResNets.
        """
        name = backbone.lower()
        if name == "resnet18":
            return models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        if name == "resnet50":
            return models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        if name == "resnet101":
            return models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
        raise ValueError(f"Unsupported backbone '{backbone}'. Choose from 'resnet18', 'resnet50', 'resnet101'.")

    def forward(self, x):
        """Run the backbone once and every head on the shared features.

        Returns:
            list[Tensor]: One logits tensor per taxonomic level.
        """
        R0 = self.backbone(x)

        outputs = []
        for branch in self.branches:
            outputs.append(branch(R0))

        return outputs

    def predict_with_hierarchy(self, x):
        """Predict at every level and flag predictions that are anomalous or
        violate the parent/child taxonomy.

        Returns:
            tuple: (predictions, confidences, is_unsure) — three parallel
            lists with one tensor per level (argmax index, softmax confidence,
            and a boolean "unsure" mask respectively).
        """
        outputs = self.forward(x)
        predictions = []
        confidences = []
        is_unsure = []

        # Level 1 (family) has no parent constraint — plain argmax + anomaly check.
        level1_output = outputs[0]
        level1_probs = torch.softmax(level1_output, dim=1)
        level1_pred = torch.argmax(level1_output, dim=1)
        level1_conf = torch.gather(level1_probs, 1, level1_pred.unsqueeze(1)).squeeze(1)

        # start_idx offsets per-level class indices into the flat
        # class_means/class_stds/output_history indexing.
        start_idx = 0
        level1_unsure = self.detect_anomalies(level1_output, level1_pred, start_idx)

        predictions.append(level1_pred)
        confidences.append(level1_conf)
        is_unsure.append(level1_unsure)

        # Check if taxonomy mappings are available
        if self.level_to_idx is None or self.parent_child_relationship is None:
            # Return basic predictions if taxonomy isn't available
            for level in range(1, self.num_levels):
                level_output = outputs[level]
                level_probs = torch.softmax(level_output, dim=1)
                level_pred = torch.argmax(level_output, dim=1)
                level_conf = torch.gather(level_probs, 1, level_pred.unsqueeze(1)).squeeze(1)
                start_idx += outputs[level-1].shape[1]
                level_unsure = self.detect_anomalies(level_output, level_pred, start_idx)

                predictions.append(level_pred)
                confidences.append(level_conf)
                is_unsure.append(level_unsure)

            return predictions, confidences, is_unsure

        # If taxonomy is available, use hierarchical constraints
        for level in range(1, self.num_levels):
            level_output = outputs[level]
            level_probs = torch.softmax(level_output, dim=1)
            level_pred = torch.argmax(level_output, dim=1)
            level_conf = torch.gather(level_probs, 1, level_pred.unsqueeze(1)).squeeze(1)

            start_idx += outputs[level-1].shape[1]
            level_unsure = self.detect_anomalies(level_output, level_pred, start_idx)

            # Mark samples whose predicted child is not a valid child of the
            # predicted parent at the previous level.
            level_unsure_hierarchy = torch.zeros_like(level_pred, dtype=torch.bool)
            for i in range(level_pred.shape[0]):
                prev_level_pred_idx = predictions[level-1][i].item()
                curr_level_pred_idx = level_pred[i].item()

                # Relies on dict insertion order of level_to_idx matching the
                # index order assigned in create_mappings (true for the dicts
                # built there) — NOTE(review): confirm for any external mappings.
                prev_level_label = list(self.level_to_idx[level])[prev_level_pred_idx]
                curr_level_label = list(self.level_to_idx[level+1])[curr_level_pred_idx]

                if (level+1, prev_level_label) in self.parent_child_relationship:
                    valid_children = self.parent_child_relationship[(level+1, prev_level_label)]
                    if curr_level_label not in valid_children:
                        level_unsure_hierarchy[i] = True
                else:
                    # Parent has no registered children: treat as a violation.
                    level_unsure_hierarchy[i] = True

            level_unsure = torch.logical_or(level_unsure, level_unsure_hierarchy)

            predictions.append(level_pred)
            confidences.append(level_conf)
            is_unsure.append(level_unsure)

        return predictions, confidences, is_unsure

    def detect_anomalies(self, outputs, predictions, start_idx):
        """Flag predictions whose winning logit falls far below that class's
        historical mean (mean - 2*std), collected during training.

        In training mode this only records logits into output_history and
        returns an all-False mask; in eval mode it performs the check.
        """
        unsure = torch.zeros_like(predictions, dtype=torch.bool)

        if self.training:
            for i in range(outputs.shape[0]):
                pred_class = predictions[i].item()
                class_idx = start_idx + pred_class
                self.output_history[class_idx].append(outputs[i, pred_class].item())
        else:
            # NOTE(review): eval path recomputes mean/std from output_history
            # instead of using the persisted class_means/class_stds buffers,
            # so a freshly loaded model (empty history) never flags anomalies.
            for i in range(outputs.shape[0]):
                pred_class = predictions[i].item()
                class_idx = start_idx + pred_class

                if len(self.output_history[class_idx]) > 0:
                    mean = np.mean(self.output_history[class_idx])
                    std = np.std(self.output_history[class_idx])
                    threshold = mean - 2 * std

                    if outputs[i, pred_class].item() < threshold:
                        unsure[i] = True

        return unsure

    def update_anomaly_stats(self):
        """Fold the collected per-class logit history into the persistent
        class_means/class_stds buffers (e.g. at the end of an epoch)."""
        for class_idx, outputs in self.output_history.items():
            if len(outputs) > 0:
                self.class_means[class_idx] = torch.tensor(np.mean(outputs))
                self.class_stds[class_idx] = torch.tensor(np.std(outputs))
|
|
680
|
+
|
|
681
|
+
class HierarchicalLoss(nn.Module):
    """
    Custom loss function for hierarchical classification.

    Combines per-level cross-entropy with an exponential "dependency"
    penalty for predictions that violate the parent/child taxonomy:
        total = alpha * CE + (1 - alpha) * dependency
    """

    def __init__(self, alpha=0.5, level_to_idx=None, parent_child_relationship=None):
        """
        Args:
            alpha (float): Weight of the cross-entropy term; (1 - alpha)
                weights the dependency term.
            level_to_idx (dict): Maps level -> {label: index}; optional.
            parent_child_relationship (dict): Maps (level, parent) -> [children]; optional.
        """
        super(HierarchicalLoss, self).__init__()
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.level_to_idx = level_to_idx
        self.parent_child_relationship = parent_child_relationship

    def forward(self, outputs, targets, predictions):
        """Compute the combined loss.

        Args:
            outputs (list[Tensor]): Logits per level, each (batch, classes).
            targets (Tensor): (batch, levels) ground-truth class indices.
            predictions (list[Tensor]): Per-level argmax predictions; only
                consulted when taxonomy mappings are present.

        Returns:
            tuple: (total_loss, total_ce_loss, total_dependency_loss)
        """
        # Cross-entropy summed over every taxonomic level.
        total_ce_loss = sum(
            self.ce_loss(logits, targets[:, lvl])
            for lvl, logits in enumerate(outputs)
        )

        # Without taxonomy mappings the dependency term cannot be computed.
        if self.level_to_idx is None or self.parent_child_relationship is None:
            return total_ce_loss, total_ce_loss, torch.zeros(1, device=outputs[0].device)

        device = outputs[0].device
        batch_size = targets.shape[0]
        dependency_losses = []

        for level in range(1, len(outputs)):
            # Hoist the index->label tables out of the per-sample loop.
            parent_labels = list(self.level_to_idx[level])
            child_labels = list(self.level_to_idx[level + 1])

            penalty = torch.zeros(1, device=device)
            for i in range(batch_size):
                parent = parent_labels[predictions[level - 1][i].item()]
                child = child_labels[predictions[level][i].item()]

                allowed = self.parent_child_relationship.get((level + 1, parent), [])
                mismatch = 0 if child in allowed else 1

                # exp(0) - 1 == 0 for a valid pair; exp(1) - 1 for a violation.
                penalty += torch.exp(torch.tensor(mismatch, device=device)) - 1

            dependency_losses.append(penalty / batch_size)

        total_dependency_loss = sum(dependency_losses) if dependency_losses else torch.zeros(1, device=device)

        total_loss = self.alpha * total_ce_loss + (1 - self.alpha) * total_dependency_loss

        return total_loss, total_ce_loss, total_dependency_loss
|
|
732
|
+
|
|
733
|
+
def get_transforms(is_training=True, img_size=640):
    """
    Creates image transformation pipelines.

    Training pipelines apply random geometric and photometric augmentation;
    validation pipelines only perform a deterministic fixed-size resize.
    Both end with tensor conversion and standard ImageNet normalization.

    Args:
        is_training (bool): select the augmented (True) or deterministic
            (False) pipeline.
        img_size (int): target spatial size of the output image.

    Returns:
        torchvision.transforms.Compose: the assembled transform pipeline.
    """
    # Standard ImageNet channel statistics.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    if is_training:
        steps = [
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomPerspective(distortion_scale=0.2),
        ]
    else:
        # Fixed size for all validation images.
        steps = [transforms.Resize((img_size, img_size))]

    steps.extend([transforms.ToTensor(), normalize])
    return transforms.Compose(steps)
|
|
754
|
+
|
|
755
|
+
def train_model(model, train_loader, val_loader, criterion, optimizer, level_to_idx, parent_child_relationship, taxonomy, species_list, num_epochs=10, patience=5, best_model_path='best_multitask.pt', backbone='resnet50'):
    """
    Trains the hierarchical classifier model.
    Implements early stopping, validation, and model checkpointing.

    Args:
        model: hierarchical classifier exposing ``num_levels``,
            ``update_anomaly_stats()`` and ``predict_with_hierarchy()``.
        train_loader: DataLoader yielding (images, labels) training batches.
        val_loader: validation DataLoader, or None to skip validation.
        criterion: HierarchicalLoss-style callable returning
            (total_loss, ce_loss, dependency_loss).
        optimizer: torch optimizer over the model's parameters.
        level_to_idx, parent_child_relationship, taxonomy, species_list:
            taxonomy metadata persisted into the checkpoint for inference.
        num_epochs (int): maximum number of epochs to run.
        patience (int): epochs without val-loss improvement before stopping.
        best_model_path (str): destination file for the best checkpoint.
        backbone (str): backbone identifier stored in the checkpoint.

    Returns:
        The model with its final-epoch weights; the best epoch's weights are
        in the checkpoint written to ``best_model_path``.
    """
    logger.info("Starting training")
    model.to(device)  # NOTE: relies on module-level `device` and `logger`

    def _save_checkpoint():
        # Single source of truth for the checkpoint layout (weights plus all
        # metadata needed to rebuild the model at inference time).
        torch.save({
            'model_state_dict': model.state_dict(),
            'taxonomy': taxonomy,
            'level_to_idx': level_to_idx,
            'parent_child_relationship': parent_child_relationship,
            'species_list': species_list,
            'backbone': backbone
        }, best_model_path)

    best_val_loss = float('inf')
    epochs_without_improvement = 0
    validation_enabled = val_loader is not None

    for epoch in range(num_epochs):
        # ---- training phase ----
        model.train()
        running_loss = 0.0
        running_ce_loss = 0.0
        running_dep_loss = 0.0
        correct_predictions = [0] * model.num_levels
        total_predictions = 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        for batch_idx, (images, labels) in enumerate(train_pbar):
            try:
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)

                # Hard (argmax) prediction per taxonomic level; the
                # dependency loss needs them in addition to the logits.
                predictions = [torch.argmax(output, dim=1) for output in outputs]

                loss, ce_loss, dep_loss = criterion(outputs, labels, predictions)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_ce_loss += ce_loss.item()
                running_dep_loss += dep_loss.item() if dep_loss.numel() > 0 else 0

                for level in range(model.num_levels):
                    correct_predictions[level] += (predictions[level] == labels[:, level]).sum().item()
                total_predictions += labels.size(0)

                train_pbar.set_postfix(loss=f"{loss.item():.4f}")

            except Exception as e:
                # Best-effort training: one corrupt batch must not abort the epoch.
                logger.error(f"Error in training batch {batch_idx}: {str(e)}")
                continue  # Skip this batch and continue with the next one

        # Guard denominators: every batch may have failed (they are skipped
        # above), which previously raised ZeroDivisionError here.
        num_train_batches = max(len(train_loader), 1)
        train_seen = max(total_predictions, 1)
        epoch_loss = running_loss / num_train_batches
        epoch_ce_loss = running_ce_loss / num_train_batches
        epoch_dep_loss = running_dep_loss / num_train_batches
        epoch_accuracies = [correct / train_seen for correct in correct_predictions]

        # Refresh per-class logit statistics used by anomaly ("unsure") detection.
        model.update_anomaly_stats()

        if not validation_enabled:
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {epoch_loss:.4f} (CE: {epoch_ce_loss:.4f}, Dep: {epoch_dep_loss:.4f})")
            print("Validation skipped (no valid data found).")
            print('-' * 60)
            continue

        # ---- validation phase ----
        model.eval()
        val_running_loss = 0.0
        val_correct_predictions = [0] * model.num_levels
        val_total_predictions = 0
        val_unsure_count = [0] * model.num_levels

        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Valid]")

        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(val_pbar):
                try:
                    images = images.to(device)
                    labels = labels.to(device)

                    predictions, confidences, is_unsure = model.predict_with_hierarchy(images)
                    outputs = model(images)

                    loss, _, _ = criterion(outputs, labels, predictions)
                    val_running_loss += loss.item()

                    for level in range(model.num_levels):
                        # Only confident AND correct predictions count as hits.
                        correct_mask = (predictions[level] == labels[:, level]) & ~is_unsure[level]
                        val_correct_predictions[level] += correct_mask.sum().item()
                        val_unsure_count[level] += is_unsure[level].sum().item()
                    val_total_predictions += labels.size(0)

                    val_pbar.set_postfix(loss=f"{loss.item():.4f}")

                except Exception as e:
                    logger.error(f"Error in validation batch {batch_idx}: {str(e)}")
                    continue

        # Same zero-division guards as the training phase.
        val_seen = max(val_total_predictions, 1)
        val_epoch_loss = val_running_loss / max(len(val_loader), 1)
        val_epoch_accuracies = [correct / val_seen for correct in val_correct_predictions]
        val_unsure_rates = [unsure / val_seen for unsure in val_unsure_count]

        # Print epoch summary
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {epoch_loss:.4f} (CE: {epoch_ce_loss:.4f}, Dep: {epoch_dep_loss:.4f})")
        print(f"Valid Loss: {val_epoch_loss:.4f}")

        for level in range(model.num_levels):
            print(f"Level {level+1} - Train Acc: {epoch_accuracies[level]:.4f}, "
                  f"Valid Acc: {val_epoch_accuracies[level]:.4f}, "
                  f"Unsure: {val_unsure_rates[level]:.4f}")
        print('-' * 60)

        # ---- checkpointing & early stopping ----
        if val_epoch_loss < best_val_loss:
            best_val_loss = val_epoch_loss
            epochs_without_improvement = 0
            _save_checkpoint()
            logger.info(f"Saved best model at epoch {epoch+1} with validation loss: {best_val_loss:.4f}")
        else:
            epochs_without_improvement += 1
            logger.info(f"No improvement for {epochs_without_improvement} epochs. Best val loss: {best_val_loss:.4f}")

        if epochs_without_improvement >= patience:
            logger.info(f"Early stopping triggered after {epoch+1} epochs")
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    if not validation_enabled:
        # No validation signal, so persist the final-epoch weights instead.
        _save_checkpoint()
        logger.info(f"Saved model (validation skipped) at {best_model_path}")

    logger.info("Training completed successfully")
    return model
|
|
905
|
+
|
|
906
|
+
if __name__ == '__main__':
    # Smoke-test entry point: small pollinator species set, two epochs only.
    species_list = [
        "Coccinella septempunctata",
        "Apis mellifera",
        "Bombus lapidarius",
        "Bombus terrestris",
        "Eupeodes corollae",
        "Episyrphus balteatus",
        "Aglais urticae",
        "Vespula vulgaris",
        "Eristalis tenax",
    ]
    train(species_list=species_list, epochs=2)
|
|
913
|
+
|