bplusplus 2.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bplusplus/train.py ADDED
@@ -0,0 +1,913 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision.models as models
5
+ import torchvision.transforms as transforms
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from PIL import Image
8
+ import os
9
+ import numpy as np
10
+ from collections import defaultdict
11
+ import requests
12
+ import time
13
+ import logging
14
+ from tqdm import tqdm
15
+ import sys
16
+
17
def train(batch_size=4, epochs=30, patience=3, img_size=640, data_dir='input', output_dir='./output', species_list=None, num_workers=4, train_transforms=None, backbone: str = "resnet50"):
    """
    Main function to run the entire training pipeline.
    Sets up datasets, model, training process and handles errors.

    Args:
        batch_size (int): Number of samples per batch. Default: 4
        epochs (int): Maximum number of training epochs. Default: 30
        patience (int): Early stopping patience (epochs without improvement). Default: 3
        img_size (int): Target image size for training. Default: 640
        data_dir (str): Directory containing train/valid subdirectories. Default: 'input'
        output_dir (str): Directory to save trained model and logs. Default: './output'
        species_list (list): List of species names for training. Required.
        num_workers (int): Number of DataLoader worker processes.
            Set to 0 to disable multiprocessing (most stable). Default: 4
        train_transforms: Optional custom torchvision transforms for training data.
        backbone (str): ResNet backbone to use ('resnet18', 'resnet50', 'resnet101'). Default: 'resnet50'

    Returns:
        tuple: (trained_model, taxonomy, level_to_idx, parent_child_relationship)

    Raises:
        ValueError: If species_list is missing/empty, or if any species has no
            training subdirectory under data_dir/train.
    """
    global logger, device

    # Fail fast: species_list is documented as required, and iterating over
    # None further down would otherwise surface as an opaque TypeError after
    # side effects (output directory creation) had already happened.
    if not species_list:
        raise ValueError("species_list is required and must be a non-empty list of species names")

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {epochs}, Patience: {patience}, Image size: {img_size}, Data directory: {data_dir}, Output directory: {output_dir}, Num workers: {num_workers}, Backbone: {backbone}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Fixed seeds for reproducible sampling/augmentation.
    torch.manual_seed(42)
    np.random.seed(42)

    learning_rate = 1.0e-4

    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'valid')

    os.makedirs(output_dir, exist_ok=True)

    # Every requested species must have a training subdirectory before we
    # touch the network (GBIF) or build any dataset.
    missing_species = []
    for species in species_list:
        species_dir = os.path.join(train_dir, species)
        if not os.path.isdir(species_dir):
            missing_species.append(species)

    if missing_species:
        raise ValueError(f"The following species directories were not found: {missing_species}")

    logger.info(f"Using {len(species_list)} species in the specified order")

    taxonomy = get_taxonomy(species_list)

    level_to_idx, parent_child_relationship = create_mappings(taxonomy, species_list)

    # Output sizes per taxonomic level: level 1 (family) is a list, levels
    # 2/3 (genus/species) are child->parent dicts.
    num_classes_per_level = [len(taxonomy[level]) if isinstance(taxonomy[level], list)
                             else len(taxonomy[level].keys()) for level in sorted(taxonomy.keys())]

    def _has_images(path):
        # True if any file under `path` has a supported image extension.
        for _, _, files in os.walk(path):
            if any(f.lower().endswith(('.jpg', '.jpeg', '.png')) for f in files):
                return True
        return False

    train_dataset = InsectDataset(
        root_dir=train_dir,
        transform=train_transforms or get_transforms(is_training=True, img_size=img_size),
        taxonomy=taxonomy,
        level_to_idx=level_to_idx
    )

    # Analyze class balance and warn about imbalances
    balance_stats = analyze_class_balance(train_dataset, taxonomy, level_to_idx)

    # Log balance summary
    for level, stats in balance_stats.items():
        level_name = {1: "Family", 2: "Genus", 3: "Species"}[level]
        logger.info(f"{level_name} balance: {stats['severity']} (ratio: {stats['imbalance_ratio']:.1f}x, CV: {stats['cv']:.1f}%)")

    # Validation is optional: if the 'valid' directory is missing or empty,
    # training proceeds without it (train_model handles val_loader=None).
    validation_available = os.path.isdir(val_dir) and _has_images(val_dir)
    if not validation_available:
        logger.warning("Validation skipped: 'valid' directory missing or contains no images.")
        val_loader = None
    else:
        val_dataset = InsectDataset(
            root_dir=val_dir,
            transform=get_transforms(is_training=False, img_size=img_size),
            taxonomy=taxonomy,
            level_to_idx=level_to_idx
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    try:
        logger.info("Initializing model...")
        model = HierarchicalInsectClassifier(
            num_classes_per_level=num_classes_per_level,
            level_to_idx=level_to_idx,
            parent_child_relationship=parent_child_relationship,
            backbone=backbone
        )
        logger.info(f"Model structure initialized with {sum(p.numel() for p in model.parameters())} parameters")

        logger.info("Setting up loss function and optimizer...")
        criterion = HierarchicalLoss(
            alpha=0.5,
            level_to_idx=level_to_idx,
            parent_child_relationship=parent_child_relationship
        )
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Smoke-test a forward pass before committing to a full training run.
        logger.info("Testing model with a dummy input...")
        model.to(device)
        dummy_input = torch.randn(1, 3, img_size, img_size).to(device)
        with torch.no_grad():
            try:
                test_outputs = model(dummy_input)
                logger.info(f"Forward pass test successful, output shapes: {[out.shape for out in test_outputs]}")
            except Exception as e:
                logger.error(f"Forward pass test failed: {str(e)}")
                raise

        logger.info("Starting training process...")
        best_model_path = os.path.join(output_dir, 'best_multitask.pt')

        trained_model = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            level_to_idx=level_to_idx,
            parent_child_relationship=parent_child_relationship,
            taxonomy=taxonomy,
            species_list=species_list,
            num_epochs=epochs,
            patience=patience,
            best_model_path=best_model_path,
            backbone=backbone
        )

        logger.info("Model saved successfully with taxonomy information!")
        print("Model saved successfully with taxonomy information!")

        return trained_model, taxonomy, level_to_idx, parent_child_relationship

    except Exception as e:
        logger.error(f"Critical error during model setup or training: {str(e)}")
        logger.exception("Stack trace:")
        raise
175
+
176
def analyze_class_balance(dataset, taxonomy, level_to_idx):
    """
    Analyze class balance across all taxonomic levels and warn about imbalances.

    Prints a per-class histogram and summary statistics for each level, and a
    prominent warning block when severe imbalance (max/min ratio > 10x) is found.

    Args:
        dataset: InsectDataset instance (only `dataset.samples` is read).
        taxonomy: Taxonomy dictionary (currently unused; kept for interface stability).
        level_to_idx: Level to index mapping (currently unused; kept for interface stability).

    Returns:
        dict: Balance statistics for each level (1=family, 2=genus, 3=species):
            counts, min, max, mean, std, imbalance_ratio, cv, severity.
    """
    # Count samples per class at each level
    level_names = {1: "Family", 2: "Genus", 3: "Species"}
    level_counts = {1: defaultdict(int), 2: defaultdict(int), 3: defaultdict(int)}

    for sample in dataset.samples:
        family, genus, species = sample['labels']
        level_counts[1][family] += 1
        level_counts[2][genus] += 1
        level_counts[3][species] += 1

    # Analyze each level
    balance_stats = {}
    has_severe_imbalance = False

    print("\n" + "=" * 80)
    print("CLASS BALANCE ANALYSIS")
    print("=" * 80)

    for level in [1, 2, 3]:
        counts = level_counts[level]
        if not counts:
            continue

        level_name = level_names[level]
        values = list(counts.values())
        classes = list(counts.keys())

        total = sum(values)
        min_count = min(values)
        max_count = max(values)
        mean_count = np.mean(values)
        std_count = np.std(values)

        # Imbalance ratio: how many times larger is the biggest class vs smallest
        imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')

        # Coefficient of variation (CV): std/mean - higher means more imbalanced
        cv = (std_count / mean_count) * 100 if mean_count > 0 else 0

        # Find underrepresented and overrepresented classes
        underrepresented = [(c, n) for c, n in counts.items() if n < mean_count * 0.5]
        overrepresented = [(c, n) for c, n in counts.items() if n > mean_count * 2]

        # Determine severity
        if imbalance_ratio > 10:
            severity = "SEVERE"
            has_severe_imbalance = True
        elif imbalance_ratio > 5:
            severity = "MODERATE"
        elif imbalance_ratio > 2:
            severity = "MILD"
        else:
            severity = "BALANCED"

        balance_stats[level] = {
            'counts': dict(counts),
            'min': min_count,
            'max': max_count,
            'mean': mean_count,
            'std': std_count,
            'imbalance_ratio': imbalance_ratio,
            'cv': cv,
            'severity': severity,
        }

        # Print level summary
        print(f"\n{level_name.upper()} LEVEL ({len(classes)} classes, {total} total samples)")
        print("-" * 60)

        # Sort by count for display
        sorted_counts = sorted(counts.items(), key=lambda x: x[1], reverse=True)

        for class_name, count in sorted_counts:
            pct = (count / total) * 100
            # Bar length is relative to the largest class (30 chars max).
            bar_len = int((count / max_count) * 30)
            bar = "█" * bar_len + "░" * (30 - bar_len)

            # Mark under/over represented
            if count < mean_count * 0.5:
                marker = " ⚠️ LOW"
            elif count > mean_count * 2:
                marker = " ⚠️ HIGH"
            else:
                marker = ""

            print(f"  {class_name:<30} {bar} {count:>5} ({pct:>5.1f}%){marker}")

        print(f"\n  Statistics:")
        print(f"    Min: {min_count}, Max: {max_count}, Mean: {mean_count:.1f}, Std: {std_count:.1f}")
        print(f"    Imbalance ratio: {imbalance_ratio:.1f}x (max/min)")
        print(f"    Coefficient of variation: {cv:.1f}%")
        print(f"    Balance status: {severity}")

        # Detailed warnings
        if severity in ["SEVERE", "MODERATE"]:
            print(f"\n  ⚠️ WARNING: {severity} class imbalance detected at {level_name} level!")

            if underrepresented:
                print(f"\n  Underrepresented classes (< 50% of mean):")
                for c, n in sorted(underrepresented, key=lambda x: x[1]):
                    deficit = int(mean_count - n)
                    print(f"    • {c}: {n} samples (need ~{deficit} more for balance)")

            if overrepresented:
                print(f"\n  Overrepresented classes (> 200% of mean):")
                for c, n in sorted(overrepresented, key=lambda x: x[1], reverse=True):
                    excess = int(n - mean_count)
                    print(f"    • {c}: {n} samples ({excess} above mean)")

    print("\n" + "=" * 80)

    if has_severe_imbalance:
        print("\n⚠️ SEVERE CLASS IMBALANCE DETECTED!")
        print("-" * 60)
        print("This can cause the following problems:")
        print("  1. BIASED PREDICTIONS: Model will favor majority classes")
        print("  2. POOR MINORITY RECALL: Rare classes may be ignored entirely")
        print("  3. MISLEADING ACCURACY: High accuracy doesn't mean good performance")
        print("  4. UNSTABLE TRAINING: Loss may be dominated by majority classes")
        print("\nRecommended actions:")
        print("  • Collect more images for underrepresented classes")
        print("  • Use data augmentation more aggressively on minority classes")
        print("  • Consider weighted loss functions (not yet implemented)")
        print("  • Consider oversampling minority / undersampling majority classes")
        print("  • Evaluate using per-class metrics (precision, recall, F1)")
        print("-" * 60)
    else:
        print("\n✓ Class balance is acceptable for training.")

    print("=" * 80 + "\n")

    return balance_stats
323
+
324
+
325
def get_taxonomy(species_list):
    """
    Retrieves taxonomic information for a list of species from GBIF API.
    Creates a hierarchical taxonomy dictionary with family, genus, and species relationships.

    Args:
        species_list (list): Species names to resolve. The literal name
            'unknown' (any case) is handled locally and never sent to GBIF.

    Returns:
        dict: {1: sorted list of families,
               2: {genus: family},
               3: {species: genus}}

    Note:
        Exits the process via sys.exit(1) when a species cannot be resolved,
        matching this module's CLI-oriented error handling.
    """
    taxonomy = {1: [], 2: {}, 3: {}}

    species_list_for_gbif = [s for s in species_list if s.lower() != 'unknown']
    has_unknown = len(species_list_for_gbif) != len(species_list)

    logger.info(f"Building taxonomy from GBIF for {len(species_list_for_gbif)} species")

    print("\nTaxonomy Results:")
    print("-" * 80)
    print(f"{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
    print("-" * 80)

    for species_name in species_list_for_gbif:
        url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
        try:
            # A bounded timeout prevents the whole training pipeline from
            # hanging indefinitely if the GBIF API is slow or unreachable.
            response = requests.get(url, timeout=30)
            data = response.json()

            if data.get('status') == 'ACCEPTED' or data.get('status') == 'SYNONYM':
                family = data.get('family')
                genus = data.get('genus')

                if family and genus:
                    status = "OK"

                    print(f"{species_name:<30} {family:<20} {genus:<20} {status}")

                    if family not in taxonomy[1]:
                        taxonomy[1].append(family)

                    taxonomy[2][genus] = family
                    taxonomy[3][species_name] = genus
                else:
                    error_msg = f"Species '{species_name}' found in GBIF but family and genus not found, could be spelling error in species, check GBIF"
                    logger.error(error_msg)
                    print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                    print(f"Error: {error_msg}")
                    sys.exit(1)  # Stop the script
            else:
                error_msg = f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF"
                logger.error(error_msg)
                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                print(f"Error: {error_msg}")
                sys.exit(1)  # Stop the script

        except Exception as e:
            error_msg = f"Error retrieving data for species '{species_name}': {str(e)}"
            logger.error(error_msg)
            print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
            print(f"Error: {error_msg}")
            sys.exit(1)  # Stop the script

    if has_unknown:
        # Synthetic "Unknown" branch so an open-set class can be trained
        # without a GBIF lookup.
        unknown_family = "Unknown"
        unknown_genus = "Unknown"
        unknown_species = "unknown"

        if unknown_family not in taxonomy[1]:
            taxonomy[1].append(unknown_family)

        taxonomy[2][unknown_genus] = unknown_family
        taxonomy[3][unknown_species] = unknown_genus

        print(f"{unknown_species:<30} {unknown_family:<20} {unknown_genus:<20} {'OK'}")

    # Families are sorted so their indices are deterministic across runs.
    taxonomy[1] = sorted(list(set(taxonomy[1])))
    print("-" * 80)

    num_families = len(taxonomy[1])
    num_genera = len(taxonomy[2])
    num_species = len(taxonomy[3])

    print("\nFamily indices:")
    for i, family in enumerate(taxonomy[1]):
        print(f"  {i}: {family}")

    print("\nGenus indices:")
    for i, genus in enumerate(sorted(taxonomy[2].keys())):
        print(f"  {i}: {genus}")

    print("\nSpecies indices:")
    for i, species in enumerate(species_list):
        print(f"  {i}: {species}")

    logger.info(f"Taxonomy built: {num_families} families, {num_genera} genera, {num_species} species")
    return taxonomy
421
+
422
def get_species_from_directory(train_dir):
    """
    Extracts a list of species names from subdirectories in the training directory.
    Returns a sorted list of species names found.
    """
    if not os.path.exists(train_dir):
        raise ValueError(f"Training directory does not exist: {train_dir}")

    # Every immediate subdirectory is treated as one species class.
    species_list = sorted(
        entry for entry in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, entry))
    )

    if not species_list:
        raise ValueError(f"No species subdirectories found in {train_dir}")

    logger.info(f"Found {len(species_list)} species in {train_dir}")
    return species_list
443
+
444
def create_mappings(taxonomy, species_list=None):
    """
    Creates mapping dictionaries from taxonomy data.
    Returns level-to-index mapping and parent-child relationships between taxonomic levels.
    """
    level_to_idx = {}
    parent_child_relationship = {}

    for level, labels in taxonomy.items():
        if isinstance(labels, list):
            # Level 1: family names arrive pre-sorted; index them as-is.
            level_to_idx[level] = dict(zip(labels, range(len(labels))))
            continue

        # Levels 2 and 3 are child->parent dicts.
        if level == 3 and species_list is not None:
            # Species indices follow the caller-provided ordering.
            ordering = species_list
        else:
            # Genus (and species fallback) indices follow alphabetical order.
            ordering = sorted(labels.keys())
        level_to_idx[level] = dict(zip(ordering, range(len(ordering))))

        # Record, per parent label, the children seen at this level.
        for child, parent in labels.items():
            parent_child_relationship.setdefault((level, parent), []).append(child)

    return level_to_idx, parent_child_relationship
471
+
472
class InsectDataset(Dataset):
    """
    PyTorch dataset for loading and processing insect images.
    Organizes data according to taxonomic hierarchy and validates images.
    """
    def __init__(self, root_dir, transform=None, taxonomy=None, level_to_idx=None):
        """
        Args:
            root_dir (str): Directory containing one subdirectory per species.
            transform: Optional torchvision transform applied in __getitem__.
            taxonomy (dict): {1: [families], 2: {genus: family}, 3: {species: genus}}.
            level_to_idx (dict): Per-level {label: class index} mappings.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.taxonomy = taxonomy
        self.level_to_idx = level_to_idx
        self.samples = []

        species_to_genus = dict(taxonomy[3])
        genus_to_family = dict(taxonomy[2])

        # Sort directory listings so sample order is deterministic across
        # filesystems/runs (os.listdir order is unspecified).
        for species_name in sorted(os.listdir(root_dir)):
            species_path = os.path.join(root_dir, species_name)
            if os.path.isdir(species_path):
                if species_name in species_to_genus:
                    genus_name = species_to_genus[species_name]
                    family_name = genus_to_family[genus_name]

                    for img_file in sorted(os.listdir(species_path)):
                        # Case-insensitive extension check, consistent with the
                        # pre-flight image scan (_has_images) in train(); the old
                        # case-sensitive check silently dropped e.g. '.JPG' files.
                        if img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
                            img_path = os.path.join(species_path, img_file)
                            # Validate the image can be opened
                            try:
                                with Image.open(img_path) as img:
                                    img.convert('RGB')
                                    # Only add valid images to samples
                                    self.samples.append({
                                        'image_path': img_path,
                                        'labels': [family_name, genus_name, species_name]
                                    })

                            except Exception as e:
                                logger.warning(f"Skipping invalid image: {img_path} - Error: {str(e)}")
                else:
                    logger.warning(f"Warning: Species '{species_name}' not found in taxonomy. Skipping.")

        # Log statistics about valid/invalid images
        logger.info(f"Dataset loaded with {len(self.samples)} valid images")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (transformed_image, tensor([family_idx, genus_idx, species_idx]))."""
        sample = self.samples[idx]
        image = Image.open(sample['image_path']).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # sample['labels'] is ordered [family, genus, species] -> levels 1..3.
        label_indices = [self.level_to_idx[level+1][label] for level, label in enumerate(sample['labels'])]

        return image, torch.tensor(label_indices)
528
+
529
class HierarchicalInsectClassifier(nn.Module):
    """
    Deep learning model for hierarchical insect classification.
    Uses a ResNet backbone with multiple classification branches for different taxonomic levels.
    Includes anomaly detection capabilities.
    """
    def __init__(self, num_classes_per_level, level_to_idx=None, parent_child_relationship=None, backbone: str = "resnet50"):
        """
        Args:
            num_classes_per_level (list[int]): Number of classes at each taxonomic
                level, ordered family -> genus -> species.
            level_to_idx (dict): Per-level {label: index} mapping (optional; without
                it, hierarchy checks in predict_with_hierarchy are skipped).
            parent_child_relationship (dict): {(level, parent_label): [child_labels]}.
            backbone (str): 'resnet18', 'resnet50' or 'resnet101'.
        """
        super(HierarchicalInsectClassifier, self).__init__()

        self.backbone = self._build_backbone(backbone)
        self.backbone_name = backbone
        backbone_output_features = self.backbone.fc.in_features
        # Replace the ResNet classifier head so the backbone emits raw features.
        self.backbone.fc = nn.Identity()

        # One classification head per taxonomic level, all sharing the backbone.
        self.branches = nn.ModuleList()
        for num_classes in num_classes_per_level:
            branch = nn.Sequential(
                nn.Linear(backbone_output_features, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, num_classes)
            )
            self.branches.append(branch)

        self.num_levels = len(num_classes_per_level)

        # Store the taxonomy mappings
        self.level_to_idx = level_to_idx
        self.parent_child_relationship = parent_child_relationship

        # Per-class logit statistics for anomaly detection. Classes of all levels
        # are flattened into one index space; buffers so stats persist in
        # checkpoints via state_dict.
        self.register_buffer('class_means', torch.zeros(sum(num_classes_per_level)))
        self.register_buffer('class_stds', torch.ones(sum(num_classes_per_level)))
        self.class_counts = [0] * sum(num_classes_per_level)
        self.output_history = defaultdict(list)

    @staticmethod
    def _build_backbone(backbone: str):
        """Return a pretrained torchvision ResNet matching `backbone` (case-insensitive)."""
        name = backbone.lower()
        if name == "resnet18":
            return models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        if name == "resnet50":
            return models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        if name == "resnet101":
            return models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
        raise ValueError(f"Unsupported backbone '{backbone}'. Choose from 'resnet18', 'resnet50', 'resnet101'.")

    def forward(self, x):
        """Return a list of per-level logit tensors, one per classification branch."""
        R0 = self.backbone(x)  # shared feature vector for all branches

        outputs = []
        for branch in self.branches:
            outputs.append(branch(R0))

        return outputs

    def predict_with_hierarchy(self, x):
        """
        Run inference and post-process predictions with taxonomy constraints.

        Returns:
            (predictions, confidences, is_unsure): per-level lists of tensors —
            argmax class indices, their softmax probabilities, and boolean
            "unsure" flags from anomaly detection and (if taxonomy is set)
            parent/child consistency checks.
        """
        outputs = self.forward(x)
        predictions = []
        confidences = []
        is_unsure = []

        # Level 1 (family) has no parent constraint: plain argmax + confidence.
        level1_output = outputs[0]
        level1_probs = torch.softmax(level1_output, dim=1)
        level1_pred = torch.argmax(level1_output, dim=1)
        level1_conf = torch.gather(level1_probs, 1, level1_pred.unsqueeze(1)).squeeze(1)

        # start_idx is the offset of this level's classes within the flattened
        # class index space used by the anomaly statistics.
        start_idx = 0
        level1_unsure = self.detect_anomalies(level1_output, level1_pred, start_idx)

        predictions.append(level1_pred)
        confidences.append(level1_conf)
        is_unsure.append(level1_unsure)

        # Check if taxonomy mappings are available
        if self.level_to_idx is None or self.parent_child_relationship is None:
            # Return basic predictions if taxonomy isn't available
            for level in range(1, self.num_levels):
                level_output = outputs[level]
                level_probs = torch.softmax(level_output, dim=1)
                level_pred = torch.argmax(level_output, dim=1)
                level_conf = torch.gather(level_probs, 1, level_pred.unsqueeze(1)).squeeze(1)
                start_idx += outputs[level-1].shape[1]
                level_unsure = self.detect_anomalies(level_output, level_pred, start_idx)

                predictions.append(level_pred)
                confidences.append(level_conf)
                is_unsure.append(level_unsure)

            return predictions, confidences, is_unsure

        # If taxonomy is available, use hierarchical constraints
        for level in range(1, self.num_levels):
            level_output = outputs[level]
            level_probs = torch.softmax(level_output, dim=1)
            level_pred = torch.argmax(level_output, dim=1)
            level_conf = torch.gather(level_probs, 1, level_pred.unsqueeze(1)).squeeze(1)

            start_idx += outputs[level-1].shape[1]
            level_unsure = self.detect_anomalies(level_output, level_pred, start_idx)

            # Flag samples whose predicted class is not a valid child of the
            # class predicted at the parent level.
            level_unsure_hierarchy = torch.zeros_like(level_pred, dtype=torch.bool)
            for i in range(level_pred.shape[0]):
                prev_level_pred_idx = predictions[level-1][i].item()
                curr_level_pred_idx = level_pred[i].item()

                # Index -> label via dict insertion order; this matches the
                # enumerate-based index assignment in create_mappings().
                prev_level_label = list(self.level_to_idx[level])[prev_level_pred_idx]
                curr_level_label = list(self.level_to_idx[level+1])[curr_level_pred_idx]

                if (level+1, prev_level_label) in self.parent_child_relationship:
                    valid_children = self.parent_child_relationship[(level+1, prev_level_label)]
                    if curr_level_label not in valid_children:
                        level_unsure_hierarchy[i] = True
                else:
                    # Parent has no recorded children at all -> cannot be consistent.
                    level_unsure_hierarchy[i] = True

            # Unsure if either the anomaly detector or the hierarchy check fired.
            level_unsure = torch.logical_or(level_unsure, level_unsure_hierarchy)

            predictions.append(level_pred)
            confidences.append(level_conf)
            is_unsure.append(level_unsure)

        return predictions, confidences, is_unsure

    def detect_anomalies(self, outputs, predictions, start_idx):
        """
        Flag predictions whose raw logit is unusually low for the predicted class.

        In training mode this only records logits into `output_history`; in eval
        mode a sample is marked unsure when its logit falls more than two standard
        deviations below the recorded mean for that class.

        Args:
            outputs: (batch, num_classes) logits for one taxonomic level.
            predictions: (batch,) argmax class indices for that level.
            start_idx (int): Offset of this level's classes in the flattened index space.

        Returns:
            Bool tensor of shape (batch,), True where the prediction is unsure.
        """
        unsure = torch.zeros_like(predictions, dtype=torch.bool)

        if self.training:
            # NOTE(review): output_history grows for the lifetime of the model —
            # it is never cleared between epochs, so memory use increases per epoch.
            for i in range(outputs.shape[0]):
                pred_class = predictions[i].item()
                class_idx = start_idx + pred_class
                self.output_history[class_idx].append(outputs[i, pred_class].item())
        else:
            for i in range(outputs.shape[0]):
                pred_class = predictions[i].item()
                class_idx = start_idx + pred_class

                if len(self.output_history[class_idx]) > 0:
                    # Two-sigma lower bound from the logits seen during training.
                    mean = np.mean(self.output_history[class_idx])
                    std = np.std(self.output_history[class_idx])
                    threshold = mean - 2 * std

                    if outputs[i, pred_class].item() < threshold:
                        unsure[i] = True

        return unsure

    def update_anomaly_stats(self):
        """Copy the accumulated per-class logit mean/std into the registered buffers."""
        # NOTE(review): detect_anomalies recomputes statistics from output_history
        # directly; these buffers appear to exist mainly so the statistics are
        # persisted in checkpoints — confirm before relying on them at load time.
        for class_idx, outputs in self.output_history.items():
            if len(outputs) > 0:
                self.class_means[class_idx] = torch.tensor(np.mean(outputs))
                self.class_stds[class_idx] = torch.tensor(np.std(outputs))
680
+
681
class HierarchicalLoss(nn.Module):
    """
    Custom loss function for hierarchical classification.
    Combines cross-entropy loss with dependency loss to enforce taxonomic constraints.
    """
    def __init__(self, alpha=0.5, level_to_idx=None, parent_child_relationship=None):
        """
        Args:
            alpha (float): Weight of the cross-entropy term; the dependency
                term is weighted by (1 - alpha).
            level_to_idx (dict): Per-level {label: index} mappings.
            parent_child_relationship (dict): {(level, parent): [children]}.
        """
        super(HierarchicalLoss, self).__init__()
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.level_to_idx = level_to_idx
        self.parent_child_relationship = parent_child_relationship

    def forward(self, outputs, targets, predictions):
        """Return (total_loss, total_ce_loss, total_dependency_loss)."""
        # Cross-entropy term: one CE loss per taxonomic level, summed.
        total_ce_loss = sum(
            self.ce_loss(level_output, targets[:, level])
            for level, level_output in enumerate(outputs)
        )

        # Without taxonomy information the dependency term cannot be computed.
        if self.level_to_idx is None or self.parent_child_relationship is None:
            return total_ce_loss, total_ce_loss, torch.zeros(1, device=outputs[0].device)

        loss_device = outputs[0].device
        batch_size = targets.shape[0]
        dependency_losses = []

        for level in range(1, len(outputs)):
            dependency_loss = torch.zeros(1, device=loss_device)
            # Index -> label lookup tables (dict insertion order matches indices).
            parent_labels = list(self.level_to_idx[level])
            child_labels = list(self.level_to_idx[level + 1])

            for sample_idx in range(batch_size):
                parent_label = parent_labels[predictions[level - 1][sample_idx].item()]
                child_label = child_labels[predictions[level][sample_idx].item()]

                children = self.parent_child_relationship.get((level + 1, parent_label), [])
                mismatch = 0 if child_label in children else 1
                # exp(0) - 1 == 0 for consistent pairs; exp(1) - 1 penalizes violations.
                dependency_loss += torch.exp(torch.tensor(mismatch, device=loss_device)) - 1

            dependency_loss /= batch_size
            dependency_losses.append(dependency_loss)

        total_dependency_loss = sum(dependency_losses) if dependency_losses else torch.zeros(1, device=loss_device)

        total_loss = self.alpha * total_ce_loss + (1 - self.alpha) * total_dependency_loss

        return total_loss, total_ce_loss, total_dependency_loss
732
+
733
def get_transforms(is_training=True, img_size=640):
    """
    Creates image transformation pipelines.
    Returns different transformations for training and validation data.
    """
    # ImageNet normalization constants, matching the pretrained ResNet backbones.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if not is_training:
        # Deterministic pipeline: every validation image gets the same fixed size.
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            normalize,
        ])

    # Training pipeline with geometric and photometric augmentation.
    return transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomPerspective(distortion_scale=0.2),
        transforms.ToTensor(),
        normalize,
    ])
754
+
755
def train_model(model, train_loader, val_loader, criterion, optimizer, level_to_idx, parent_child_relationship, taxonomy, species_list, num_epochs=10, patience=5, best_model_path='best_multitask.pt', backbone='resnet50'):
    """
    Trains the hierarchical classifier model.
    Implements early stopping, validation, and model checkpointing.

    Args:
        model: HierarchicalInsectClassifier instance (moved to the global `device`).
        train_loader: DataLoader yielding (images, labels) training batches.
        val_loader: Validation DataLoader, or None to skip validation entirely.
        criterion: HierarchicalLoss returning (total, ce, dependency) losses.
        optimizer: Torch optimizer over model.parameters().
        level_to_idx / parent_child_relationship / taxonomy / species_list:
            Taxonomy metadata stored in the checkpoint alongside the weights.
        num_epochs (int): Maximum number of epochs. Default: 10
        patience (int): Early-stopping patience (epochs without val-loss
            improvement). Only applies when validation is enabled. Default: 5
        best_model_path (str): Checkpoint destination path. Default: 'best_multitask.pt'
        backbone (str): Backbone name stored in the checkpoint. Default: 'resnet50'

    Returns:
        The trained model (its final state; the best weights are in the checkpoint).
    """
    # NOTE(review): `logger` and `device` are module globals assigned in train();
    # calling this function standalone would raise NameError — confirm intended.
    logger.info("Starting training")
    model.to(device)

    best_val_loss = float('inf')
    epochs_without_improvement = 0
    validation_enabled = val_loader is not None

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_ce_loss = 0.0
        running_dep_loss = 0.0
        correct_predictions = [0] * model.num_levels
        total_predictions = 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        for batch_idx, (images, labels) in enumerate(train_pbar):
            try:
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)

                # Per-level argmax predictions, consumed by the dependency loss.
                predictions = []
                for output in outputs:
                    pred = torch.argmax(output, dim=1)
                    predictions.append(pred)

                loss, ce_loss, dep_loss = criterion(outputs, labels, predictions)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_ce_loss += ce_loss.item()
                running_dep_loss += dep_loss.item() if dep_loss.numel() > 0 else 0

                for level in range(model.num_levels):
                    correct_predictions[level] += (predictions[level] == labels[:, level]).sum().item()
                total_predictions += labels.size(0)

                train_pbar.set_postfix(loss=f"{loss.item():.4f}")

            except Exception as e:
                # A single bad batch should not abort the whole epoch.
                logger.error(f"Error in training batch {batch_idx}: {str(e)}")
                continue  # Skip this batch and continue with the next one

        epoch_loss = running_loss / len(train_loader)
        epoch_ce_loss = running_ce_loss / len(train_loader)
        epoch_dep_loss = running_dep_loss / len(train_loader)
        epoch_accuracies = [correct / total_predictions for correct in correct_predictions]

        # Fold the logits recorded this epoch into the model's anomaly buffers.
        model.update_anomaly_stats()

        if not validation_enabled:
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {epoch_loss:.4f} (CE: {epoch_ce_loss:.4f}, Dep: {epoch_dep_loss:.4f})")
            print("Validation skipped (no valid data found).")
            print('-' * 60)
            continue

        model.eval()
        val_running_loss = 0.0
        val_correct_predictions = [0] * model.num_levels
        val_total_predictions = 0
        val_unsure_count = [0] * model.num_levels

        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Valid]")

        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(val_pbar):
                try:
                    images = images.to(device)
                    labels = labels.to(device)

                    # Hierarchy-aware predictions (with unsure flags) for metrics,
                    # plus a second forward pass for the loss value.
                    # NOTE(review): this runs the backbone twice per batch —
                    # presumably acceptable for validation, but could be merged.
                    predictions, confidences, is_unsure = model.predict_with_hierarchy(images)
                    outputs = model(images)

                    loss, _, _ = criterion(outputs, labels, predictions)
                    val_running_loss += loss.item()

                    for level in range(model.num_levels):
                        # A prediction only counts as correct when the model
                        # is also not "unsure" about it.
                        correct_mask = (predictions[level] == labels[:, level]) & ~is_unsure[level]
                        val_correct_predictions[level] += correct_mask.sum().item()
                        val_unsure_count[level] += is_unsure[level].sum().item()
                    val_total_predictions += labels.size(0)

                    val_pbar.set_postfix(loss=f"{loss.item():.4f}")

                except Exception as e:
                    logger.error(f"Error in validation batch {batch_idx}: {str(e)}")
                    continue

        val_epoch_loss = val_running_loss / len(val_loader)
        val_epoch_accuracies = [correct / val_total_predictions for correct in val_correct_predictions]
        val_unsure_rates = [unsure / val_total_predictions for unsure in val_unsure_count]

        # Print epoch summary
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {epoch_loss:.4f} (CE: {epoch_ce_loss:.4f}, Dep: {epoch_dep_loss:.4f})")
        print(f"Valid Loss: {val_epoch_loss:.4f}")

        for level in range(model.num_levels):
            print(f"Level {level+1} - Train Acc: {epoch_accuracies[level]:.4f}, "
                  f"Valid Acc: {val_epoch_accuracies[level]:.4f}, "
                  f"Unsure: {val_unsure_rates[level]:.4f}")
        print('-' * 60)

        if val_epoch_loss < best_val_loss:
            best_val_loss = val_epoch_loss
            epochs_without_improvement = 0

            # Checkpoint weights together with everything needed for inference.
            torch.save({
                'model_state_dict': model.state_dict(),
                'taxonomy': taxonomy,
                'level_to_idx': level_to_idx,
                'parent_child_relationship': parent_child_relationship,
                'species_list': species_list,
                'backbone': backbone
            }, best_model_path)
            logger.info(f"Saved best model at epoch {epoch+1} with validation loss: {best_val_loss:.4f}")
        else:
            epochs_without_improvement += 1
            logger.info(f"No improvement for {epochs_without_improvement} epochs. Best val loss: {best_val_loss:.4f}")

            if epochs_without_improvement >= patience:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

    if not validation_enabled:
        # No best-by-validation checkpoint exists; save the final weights instead.
        torch.save({
            'model_state_dict': model.state_dict(),
            'taxonomy': taxonomy,
            'level_to_idx': level_to_idx,
            'parent_child_relationship': parent_child_relationship,
            'species_list': species_list,
            'backbone': backbone
        }, best_model_path)
        logger.info(f"Saved model (validation skipped) at {best_model_path}")

    logger.info("Training completed successfully")
    return model
905
+
906
if __name__ == '__main__':
    # Example/smoke-run configuration: a small set of pollinator species,
    # trained for only two epochs.
    species_list = [
        "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
        "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
        "Eristalis tenax"
    ]
    train(species_list=species_list, epochs=2)
913
+