bplusplus 1.1.0-py3-none-any.whl → 1.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bplusplus might be problematic.

Files changed (97)
  1. bplusplus/__init__.py +4 -2
  2. bplusplus/collect.py +72 -3
  3. bplusplus/hierarchical/test.py +670 -0
  4. bplusplus/hierarchical/train.py +676 -0
  5. bplusplus/prepare.py +236 -71
  6. bplusplus/resnet/test.py +473 -0
  7. bplusplus/resnet/train.py +329 -0
  8. bplusplus-1.2.1.dist-info/METADATA +252 -0
  9. bplusplus-1.2.1.dist-info/RECORD +12 -0
  10. bplusplus/yolov5detect/__init__.py +0 -1
  11. bplusplus/yolov5detect/detect.py +0 -444
  12. bplusplus/yolov5detect/export.py +0 -1530
  13. bplusplus/yolov5detect/insect.yaml +0 -8
  14. bplusplus/yolov5detect/models/__init__.py +0 -0
  15. bplusplus/yolov5detect/models/common.py +0 -1109
  16. bplusplus/yolov5detect/models/experimental.py +0 -130
  17. bplusplus/yolov5detect/models/hub/anchors.yaml +0 -56
  18. bplusplus/yolov5detect/models/hub/yolov3-spp.yaml +0 -52
  19. bplusplus/yolov5detect/models/hub/yolov3-tiny.yaml +0 -42
  20. bplusplus/yolov5detect/models/hub/yolov3.yaml +0 -52
  21. bplusplus/yolov5detect/models/hub/yolov5-bifpn.yaml +0 -49
  22. bplusplus/yolov5detect/models/hub/yolov5-fpn.yaml +0 -43
  23. bplusplus/yolov5detect/models/hub/yolov5-p2.yaml +0 -55
  24. bplusplus/yolov5detect/models/hub/yolov5-p34.yaml +0 -42
  25. bplusplus/yolov5detect/models/hub/yolov5-p6.yaml +0 -57
  26. bplusplus/yolov5detect/models/hub/yolov5-p7.yaml +0 -68
  27. bplusplus/yolov5detect/models/hub/yolov5-panet.yaml +0 -49
  28. bplusplus/yolov5detect/models/hub/yolov5l6.yaml +0 -61
  29. bplusplus/yolov5detect/models/hub/yolov5m6.yaml +0 -61
  30. bplusplus/yolov5detect/models/hub/yolov5n6.yaml +0 -61
  31. bplusplus/yolov5detect/models/hub/yolov5s-LeakyReLU.yaml +0 -50
  32. bplusplus/yolov5detect/models/hub/yolov5s-ghost.yaml +0 -49
  33. bplusplus/yolov5detect/models/hub/yolov5s-transformer.yaml +0 -49
  34. bplusplus/yolov5detect/models/hub/yolov5s6.yaml +0 -61
  35. bplusplus/yolov5detect/models/hub/yolov5x6.yaml +0 -61
  36. bplusplus/yolov5detect/models/segment/yolov5l-seg.yaml +0 -49
  37. bplusplus/yolov5detect/models/segment/yolov5m-seg.yaml +0 -49
  38. bplusplus/yolov5detect/models/segment/yolov5n-seg.yaml +0 -49
  39. bplusplus/yolov5detect/models/segment/yolov5s-seg.yaml +0 -49
  40. bplusplus/yolov5detect/models/segment/yolov5x-seg.yaml +0 -49
  41. bplusplus/yolov5detect/models/tf.py +0 -797
  42. bplusplus/yolov5detect/models/yolo.py +0 -495
  43. bplusplus/yolov5detect/models/yolov5l.yaml +0 -49
  44. bplusplus/yolov5detect/models/yolov5m.yaml +0 -49
  45. bplusplus/yolov5detect/models/yolov5n.yaml +0 -49
  46. bplusplus/yolov5detect/models/yolov5s.yaml +0 -49
  47. bplusplus/yolov5detect/models/yolov5x.yaml +0 -49
  48. bplusplus/yolov5detect/utils/__init__.py +0 -97
  49. bplusplus/yolov5detect/utils/activations.py +0 -134
  50. bplusplus/yolov5detect/utils/augmentations.py +0 -448
  51. bplusplus/yolov5detect/utils/autoanchor.py +0 -175
  52. bplusplus/yolov5detect/utils/autobatch.py +0 -70
  53. bplusplus/yolov5detect/utils/aws/__init__.py +0 -0
  54. bplusplus/yolov5detect/utils/aws/mime.sh +0 -26
  55. bplusplus/yolov5detect/utils/aws/resume.py +0 -41
  56. bplusplus/yolov5detect/utils/aws/userdata.sh +0 -27
  57. bplusplus/yolov5detect/utils/callbacks.py +0 -72
  58. bplusplus/yolov5detect/utils/dataloaders.py +0 -1385
  59. bplusplus/yolov5detect/utils/docker/Dockerfile +0 -73
  60. bplusplus/yolov5detect/utils/docker/Dockerfile-arm64 +0 -40
  61. bplusplus/yolov5detect/utils/docker/Dockerfile-cpu +0 -42
  62. bplusplus/yolov5detect/utils/downloads.py +0 -136
  63. bplusplus/yolov5detect/utils/flask_rest_api/README.md +0 -70
  64. bplusplus/yolov5detect/utils/flask_rest_api/example_request.py +0 -17
  65. bplusplus/yolov5detect/utils/flask_rest_api/restapi.py +0 -49
  66. bplusplus/yolov5detect/utils/general.py +0 -1294
  67. bplusplus/yolov5detect/utils/google_app_engine/Dockerfile +0 -25
  68. bplusplus/yolov5detect/utils/google_app_engine/additional_requirements.txt +0 -6
  69. bplusplus/yolov5detect/utils/google_app_engine/app.yaml +0 -16
  70. bplusplus/yolov5detect/utils/loggers/__init__.py +0 -476
  71. bplusplus/yolov5detect/utils/loggers/clearml/README.md +0 -222
  72. bplusplus/yolov5detect/utils/loggers/clearml/__init__.py +0 -0
  73. bplusplus/yolov5detect/utils/loggers/clearml/clearml_utils.py +0 -230
  74. bplusplus/yolov5detect/utils/loggers/clearml/hpo.py +0 -90
  75. bplusplus/yolov5detect/utils/loggers/comet/README.md +0 -250
  76. bplusplus/yolov5detect/utils/loggers/comet/__init__.py +0 -551
  77. bplusplus/yolov5detect/utils/loggers/comet/comet_utils.py +0 -151
  78. bplusplus/yolov5detect/utils/loggers/comet/hpo.py +0 -126
  79. bplusplus/yolov5detect/utils/loggers/comet/optimizer_config.json +0 -135
  80. bplusplus/yolov5detect/utils/loggers/wandb/__init__.py +0 -0
  81. bplusplus/yolov5detect/utils/loggers/wandb/wandb_utils.py +0 -210
  82. bplusplus/yolov5detect/utils/loss.py +0 -259
  83. bplusplus/yolov5detect/utils/metrics.py +0 -381
  84. bplusplus/yolov5detect/utils/plots.py +0 -517
  85. bplusplus/yolov5detect/utils/segment/__init__.py +0 -0
  86. bplusplus/yolov5detect/utils/segment/augmentations.py +0 -100
  87. bplusplus/yolov5detect/utils/segment/dataloaders.py +0 -366
  88. bplusplus/yolov5detect/utils/segment/general.py +0 -160
  89. bplusplus/yolov5detect/utils/segment/loss.py +0 -198
  90. bplusplus/yolov5detect/utils/segment/metrics.py +0 -225
  91. bplusplus/yolov5detect/utils/segment/plots.py +0 -152
  92. bplusplus/yolov5detect/utils/torch_utils.py +0 -482
  93. bplusplus/yolov5detect/utils/triton.py +0 -90
  94. bplusplus-1.1.0.dist-info/METADATA +0 -179
  95. bplusplus-1.1.0.dist-info/RECORD +0 -92
  96. {bplusplus-1.1.0.dist-info → bplusplus-1.2.1.dist-info}/LICENSE +0 -0
  97. {bplusplus-1.1.0.dist-info → bplusplus-1.2.1.dist-info}/WHEEL +0 -0
bplusplus/hierarchical/train.py
@@ -0,0 +1,676 @@
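+ # Hierarchical (order -> family -> species) insect classifier: taxonomy is resolved via the
+ # GBIF species-match API, images are loaded from per-species folders, and a ResNet-50 backbone
+ # feeds one classification head per taxonomic level.
+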
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torchvision.models as models
+ import torchvision.transforms as transforms
+ from torch.utils.data import Dataset, DataLoader
+ from PIL import Image
+ import os
+ import numpy as np
+ from collections import defaultdict
+ import requests
+ import time
+ import logging
+ from tqdm import tqdm
+ import sys
+
+ def train_multitask(batch_size=4, epochs=30, patience=3, img_size=640, data_dir='/mnt/nvme0n1p1/datasets/insect/bjerge-train2', output_dir='./output', species_list=None):
+     """
+     Main function to run the entire training pipeline.
+     Sets up datasets, model, training process and handles errors.
+     """
+     global logger, device
+
+     logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+     logger = logging.getLogger(__name__)
+
+     logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {epochs}, Patience: {patience}, Image size: {img_size}, Data directory: {data_dir}, Output directory: {output_dir}")
+
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     torch.manual_seed(42)
+     np.random.seed(42)
+
+     learning_rate = 1.0e-4
+
+     train_dir = os.path.join(data_dir, 'train')
+     val_dir = os.path.join(data_dir, 'valid')
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Fall back to the species subdirectories found in train_dir when no explicit list is given.
+     if species_list is None:
+         species_list = get_species_from_directory(train_dir)
+
+     missing_species = []
+     for species in species_list:
+         species_dir = os.path.join(train_dir, species)
+         if not os.path.isdir(species_dir):
+             missing_species.append(species)
+
+     if missing_species:
+         raise ValueError(f"The following species directories were not found: {missing_species}")
+
+     logger.info(f"Using {len(species_list)} species in the specified order")
+
+     taxonomy = get_taxonomy(species_list)
+
+     level_to_idx, parent_child_relationship = create_mappings(taxonomy)
+
+     num_classes_per_level = [len(taxonomy[level]) if isinstance(taxonomy[level], list)
+                              else len(taxonomy[level].keys()) for level in sorted(taxonomy.keys())]
+
+     train_dataset = InsectDataset(
+         root_dir=train_dir,
+         transform=get_transforms(is_training=True, img_size=img_size),
+         taxonomy=taxonomy,
+         level_to_idx=level_to_idx
+     )
+
+     val_dataset = InsectDataset(
+         root_dir=val_dir,
+         transform=get_transforms(is_training=False, img_size=img_size),
+         taxonomy=taxonomy,
+         level_to_idx=level_to_idx
+     )
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=4
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=4
+     )
+
+     try:
+         logger.info("Initializing model...")
+         model = HierarchicalInsectClassifier(
+             num_classes_per_level=num_classes_per_level,
+             level_to_idx=level_to_idx,
+             parent_child_relationship=parent_child_relationship
+         )
+         logger.info(f"Model structure initialized with {sum(p.numel() for p in model.parameters())} parameters")
+
+         logger.info("Setting up loss function and optimizer...")
+         criterion = HierarchicalLoss(
+             alpha=0.5,
+             level_to_idx=level_to_idx,
+             parent_child_relationship=parent_child_relationship
+         )
+         optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+         logger.info("Testing model with a dummy input...")
+         model.to(device)
+         dummy_input = torch.randn(1, 3, img_size, img_size).to(device)
+         with torch.no_grad():
+             try:
+                 test_outputs = model(dummy_input)
+                 logger.info(f"Forward pass test successful, output shapes: {[out.shape for out in test_outputs]}")
+             except Exception as e:
+                 logger.error(f"Forward pass test failed: {str(e)}")
+                 raise
+
+         logger.info("Starting training process...")
+         best_model_path = os.path.join(output_dir, 'best_multitask.pt')
+
+         trained_model = train_model(
+             model=model,
+             train_loader=train_loader,
+             val_loader=val_loader,
+             criterion=criterion,
+             optimizer=optimizer,
+             level_to_idx=level_to_idx,
+             parent_child_relationship=parent_child_relationship,
+             taxonomy=taxonomy,
+             species_list=species_list,
+             num_epochs=epochs,
+             patience=patience,
+             best_model_path=best_model_path
+         )
+
+         logger.info("Model saved successfully with taxonomy information!")
+         print("Model saved successfully with taxonomy information!")
+
+         return trained_model, taxonomy, level_to_idx, parent_child_relationship
+
+     except Exception as e:
+         logger.error(f"Critical error during model setup or training: {str(e)}")
+         logger.exception("Stack trace:")
+         raise
+
+ def get_taxonomy(species_list):
+     """
+     Retrieves taxonomic information for a list of species from GBIF API.
+     Creates a hierarchical taxonomy dictionary with order, family, and species relationships.
+     """
+     taxonomy = {1: [], 2: {}, 3: {}}
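+     # Level 1 is a list of orders, level 2 maps family -> order, level 3 maps species -> family.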
+     species_to_family = {}
+     family_to_order = {}
+
+     logger.info(f"Building taxonomy from GBIF for {len(species_list)} species")
+
+     print("\nTaxonomy Results:")
+     print("-" * 80)
+     print(f"{'Species':<30} {'Order':<20} {'Family':<20} {'Status'}")
+     print("-" * 80)
+
+     for species_name in species_list:
+         url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
+         try:
+             response = requests.get(url)
+             data = response.json()
+
+             if data.get('status') == 'ACCEPTED' or data.get('status') == 'SYNONYM':
+                 family = data.get('family')
+                 order = data.get('order')
+
+                 if family and order:
+                     status = "OK"
+
+                     print(f"{species_name:<30} {order:<20} {family:<20} {status}")
+
+                     species_to_family[species_name] = family
+                     family_to_order[family] = order
+
+                     if order not in taxonomy[1]:
+                         taxonomy[1].append(order)
+
+                     taxonomy[2][family] = order
+                     taxonomy[3][species_name] = family
+                 else:
+                     error_msg = f"Species '{species_name}' found in GBIF but family and order not found, could be spelling error in species, check GBIF"
+                     logger.error(error_msg)
+                     print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                     print(f"Error: {error_msg}")
+                     sys.exit(1) # Stop the script
+             else:
+                 error_msg = f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF"
+                 logger.error(error_msg)
+                 print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                 print(f"Error: {error_msg}")
+                 sys.exit(1) # Stop the script
+
+         except Exception as e:
+             error_msg = f"Error retrieving data for species '{species_name}': {str(e)}"
+             logger.error(error_msg)
+             print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
+             print(f"Error: {error_msg}")
+             sys.exit(1) # Stop the script
+
+     taxonomy[1] = sorted(list(set(taxonomy[1])))
+     print("-" * 80)
+
+     num_orders = len(taxonomy[1])
+     num_families = len(taxonomy[2])
+     num_species = len(taxonomy[3])
+
+     print("\nOrder indices:")
+     for i, order in enumerate(taxonomy[1]):
+         print(f" {i}: {order}")
+
+     print("\nFamily indices:")
+     for i, family in enumerate(taxonomy[2].keys()):
+         print(f" {i}: {family}")
+
+     print("\nSpecies indices:")
+     for i, species in enumerate(species_list):
+         print(f" {i}: {species}")
+
+     logger.info(f"Taxonomy built: {num_orders} orders, {num_families} families, {num_species} species")
+     return taxonomy
+
+ def get_species_from_directory(train_dir):
+     """
+     Extracts a list of species names from subdirectories in the training directory.
+     Returns a sorted list of species names found.
+     """
+     if not os.path.exists(train_dir):
+         raise ValueError(f"Training directory does not exist: {train_dir}")
+
+     species_list = []
+     for item in os.listdir(train_dir):
+         item_path = os.path.join(train_dir, item)
+         if os.path.isdir(item_path):
+             species_list.append(item)
+
+     species_list.sort()
+
+     if not species_list:
+         raise ValueError(f"No species subdirectories found in {train_dir}")
+
+     logger.info(f"Found {len(species_list)} species in {train_dir}")
+     return species_list
+
+ def create_mappings(taxonomy):
+     """
+     Creates mapping dictionaries from taxonomy data.
+     Returns level-to-index mapping and parent-child relationships between taxonomic levels.
+     """
+     level_to_idx = {}
+     parent_child_relationship = {}
+
+     for level, labels in taxonomy.items():
+         if isinstance(labels, list):
+             level_to_idx[level] = {label: idx for idx, label in enumerate(labels)}
+         else:
+             level_to_idx[level] = {label: idx for idx, label in enumerate(labels.keys())}
+             for child, parent in labels.items():
+                 if (level, parent) not in parent_child_relationship:
+                     parent_child_relationship[(level, parent)] = []
+                 parent_child_relationship[(level, parent)].append(child)
+
+     return level_to_idx, parent_child_relationship
+
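+ # For a hypothetical one-species taxonomy, create_mappings would return, for example:
+ #   level_to_idx = {1: {'Diptera': 0}, 2: {'Syrphidae': 0}, 3: {'Eristalis tenax': 0}}
+ #   parent_child_relationship = {(2, 'Diptera'): ['Syrphidae'], (3, 'Syrphidae'): ['Eristalis tenax']}
+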
+ class InsectDataset(Dataset):
+     """
+     PyTorch dataset for loading and processing insect images.
+     Organizes data according to taxonomic hierarchy and validates images.
+     """
+     def __init__(self, root_dir, transform=None, taxonomy=None, level_to_idx=None):
+         self.root_dir = root_dir
+         self.transform = transform
+         self.taxonomy = taxonomy
+         self.level_to_idx = level_to_idx
+         self.samples = []
+
+         species_to_family = {species: family for species, family in taxonomy[3].items()}
+         family_to_order = {family: order for family, order in taxonomy[2].items()}
+
+         for species_name in os.listdir(root_dir):
+             species_path = os.path.join(root_dir, species_name)
+             if os.path.isdir(species_path):
+                 if species_name in species_to_family:
+                     family_name = species_to_family[species_name]
+                     order_name = family_to_order[family_name]
+
+                     for img_file in os.listdir(species_path):
+                         if img_file.endswith(('.jpg', '.png', '.jpeg')):
+                             img_path = os.path.join(species_path, img_file)
+                             # Validate the image can be opened
+                             try:
+                                 with Image.open(img_path) as img:
+                                     img.convert('RGB')
+                                 # Only add valid images to samples
+                                 self.samples.append({
+                                     'image_path': img_path,
+                                     'labels': [order_name, family_name, species_name]
+                                 })
+
+                             except Exception as e:
+                                 logger.warning(f"Skipping invalid image: {img_path} - Error: {str(e)}")
+                 else:
+                     logger.warning(f"Warning: Species '{species_name}' not found in taxonomy. Skipping.")
+
+         # Log statistics about valid/invalid images
+         logger.info(f"Dataset loaded with {len(self.samples)} valid images")
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         sample = self.samples[idx]
+         image = Image.open(sample['image_path']).convert('RGB')
+
+         if self.transform:
+             image = self.transform(image)
+
+         label_indices = [self.level_to_idx[level+1][label] for level, label in enumerate(sample['labels'])]
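+         # sample['labels'] is [order, family, species]; taxonomy levels are 1-indexed, so this
+         # yields [order_idx, family_idx, species_idx].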
+
+         return image, torch.tensor(label_indices)
+
+ class HierarchicalInsectClassifier(nn.Module):
+     """
+     Deep learning model for hierarchical insect classification.
+     Uses ResNet50 backbone with multiple classification branches for different taxonomic levels.
+     Includes anomaly detection capabilities.
+     """
+     def __init__(self, num_classes_per_level, level_to_idx=None, parent_child_relationship=None):
+         super(HierarchicalInsectClassifier, self).__init__()
+
+         self.backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+         backbone_output_features = self.backbone.fc.in_features
+         self.backbone.fc = nn.Identity()
+
+         self.branches = nn.ModuleList()
+         for num_classes in num_classes_per_level:
+             branch = nn.Sequential(
+                 nn.Linear(backbone_output_features, 512),
+                 nn.ReLU(),
+                 nn.Dropout(0.5),
+                 nn.Linear(512, num_classes)
+             )
+             self.branches.append(branch)
+
+         self.num_levels = len(num_classes_per_level)
+
+         # Store the taxonomy mappings
+         self.level_to_idx = level_to_idx
+         self.parent_child_relationship = parent_child_relationship
+
+         self.register_buffer('class_means', torch.zeros(sum(num_classes_per_level)))
+         self.register_buffer('class_stds', torch.ones(sum(num_classes_per_level)))
+         self.class_counts = [0] * sum(num_classes_per_level)
+         self.output_history = defaultdict(list)
+
+     def forward(self, x):
+         R0 = self.backbone(x)
+
+         outputs = []
+         for branch in self.branches:
+             outputs.append(branch(R0))
+
+         return outputs
+
+     def predict_with_hierarchy(self, x):
+         outputs = self.forward(x)
+         predictions = []
+         confidences = []
+         is_unsure = []
+
+         level1_output = outputs[0]
+         level1_probs = torch.softmax(level1_output, dim=1)
+         level1_pred = torch.argmax(level1_output, dim=1)
+         level1_conf = torch.gather(level1_probs, 1, level1_pred.unsqueeze(1)).squeeze(1)
+
+         start_idx = 0
+         level1_unsure = self.detect_anomalies(level1_output, level1_pred, start_idx)
+
+         predictions.append(level1_pred)
+         confidences.append(level1_conf)
+         is_unsure.append(level1_unsure)
+
+         # Check if taxonomy mappings are available
+         if self.level_to_idx is None or self.parent_child_relationship is None:
+             # Return basic predictions if taxonomy isn't available
+             for level in range(1, self.num_levels):
+                 level_output = outputs[level]
+                 level_probs = torch.softmax(level_output, dim=1)
+                 level_pred = torch.argmax(level_output, dim=1)
+                 level_conf = torch.gather(level_probs, 1, level_pred.unsqueeze(1)).squeeze(1)
+                 start_idx += outputs[level-1].shape[1]
+                 level_unsure = self.detect_anomalies(level_output, level_pred, start_idx)
+
+                 predictions.append(level_pred)
+                 confidences.append(level_conf)
+                 is_unsure.append(level_unsure)
+
+             return predictions, confidences, is_unsure
+
+         # If taxonomy is available, use hierarchical constraints
+         for level in range(1, self.num_levels):
+             level_output = outputs[level]
+             level_probs = torch.softmax(level_output, dim=1)
+             level_pred = torch.argmax(level_output, dim=1)
+             level_conf = torch.gather(level_probs, 1, level_pred.unsqueeze(1)).squeeze(1)
+
+             start_idx += outputs[level-1].shape[1]
+             level_unsure = self.detect_anomalies(level_output, level_pred, start_idx)
+
+             level_unsure_hierarchy = torch.zeros_like(level_pred, dtype=torch.bool)
+             for i in range(level_pred.shape[0]):
+                 prev_level_pred_idx = predictions[level-1][i].item()
+                 curr_level_pred_idx = level_pred[i].item()
+
+                 prev_level_label = list(self.level_to_idx[level])[prev_level_pred_idx]
+                 curr_level_label = list(self.level_to_idx[level+1])[curr_level_pred_idx]
+
+                 if (level+1, prev_level_label) in self.parent_child_relationship:
+                     valid_children = self.parent_child_relationship[(level+1, prev_level_label)]
+                     if curr_level_label not in valid_children:
+                         level_unsure_hierarchy[i] = True
+                 else:
+                     level_unsure_hierarchy[i] = True
+
+             level_unsure = torch.logical_or(level_unsure, level_unsure_hierarchy)
+
+             predictions.append(level_pred)
+             confidences.append(level_conf)
+             is_unsure.append(level_unsure)
+
+         return predictions, confidences, is_unsure
+
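+     # A prediction is flagged "unsure" either by detect_anomalies (the winning logit falls more
+     # than two standard deviations below that class's running mean from training) or, above, when
+     # the predicted child class is not a valid child of the previously predicted parent.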
+     def detect_anomalies(self, outputs, predictions, start_idx):
+         unsure = torch.zeros_like(predictions, dtype=torch.bool)
+
+         if self.training:
+             for i in range(outputs.shape[0]):
+                 pred_class = predictions[i].item()
+                 class_idx = start_idx + pred_class
+                 self.output_history[class_idx].append(outputs[i, pred_class].item())
+         else:
+             for i in range(outputs.shape[0]):
+                 pred_class = predictions[i].item()
+                 class_idx = start_idx + pred_class
+
+                 if len(self.output_history[class_idx]) > 0:
+                     mean = np.mean(self.output_history[class_idx])
+                     std = np.std(self.output_history[class_idx])
+                     threshold = mean - 2 * std
+
+                     if outputs[i, pred_class].item() < threshold:
+                         unsure[i] = True
+
+         return unsure
+
+     def update_anomaly_stats(self):
+         for class_idx, outputs in self.output_history.items():
+             if len(outputs) > 0:
+                 self.class_means[class_idx] = torch.tensor(np.mean(outputs))
+                 self.class_stds[class_idx] = torch.tensor(np.std(outputs))
+
+ class HierarchicalLoss(nn.Module):
+     """
+     Custom loss function for hierarchical classification.
+     Combines cross-entropy loss with dependency loss to enforce taxonomic constraints.
+     """
+     def __init__(self, alpha=0.5, level_to_idx=None, parent_child_relationship=None):
+         super(HierarchicalLoss, self).__init__()
+         self.alpha = alpha
+         self.ce_loss = nn.CrossEntropyLoss()
+         self.level_to_idx = level_to_idx
+         self.parent_child_relationship = parent_child_relationship
+
+     def forward(self, outputs, targets, predictions):
+         ce_losses = []
+         for level, output in enumerate(outputs):
+             ce_losses.append(self.ce_loss(output, targets[:, level]))
+
+         total_ce_loss = sum(ce_losses)
+
+         dependency_losses = []
+
+         # Skip dependency loss calculation if taxonomy isn't available
+         if self.level_to_idx is None or self.parent_child_relationship is None:
+             return total_ce_loss, total_ce_loss, torch.zeros(1, device=outputs[0].device)
+
+         for level in range(1, len(outputs)):
+             dependency_loss = torch.zeros(1, device=outputs[0].device)
+             for i in range(targets.shape[0]):
+                 prev_level_pred_idx = predictions[level-1][i].item()
+                 curr_level_pred_idx = predictions[level][i].item()
+
+                 prev_level_label = list(self.level_to_idx[level])[prev_level_pred_idx]
+                 curr_level_label = list(self.level_to_idx[level+1])[curr_level_pred_idx]
+
+                 is_valid = False
+                 if (level+1, prev_level_label) in self.parent_child_relationship:
+                     valid_children = self.parent_child_relationship[(level+1, prev_level_label)]
+                     if curr_level_label in valid_children:
+                         is_valid = True
+
+                 D_l = 0 if is_valid else 1
+                 dependency_loss += torch.exp(torch.tensor(D_l, device=outputs[0].device)) - 1
+
+             dependency_loss /= targets.shape[0]
+             dependency_losses.append(dependency_loss)
+
+         total_dependency_loss = sum(dependency_losses) if dependency_losses else torch.zeros(1, device=outputs[0].device)
+
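+         # Combined objective: alpha * (sum of per-level cross-entropies) plus
+         # (1 - alpha) * (per-sample mean of exp(D) - 1, summed over the family and species levels),
+         # where D is 0 for a taxonomy-consistent prediction and 1 otherwise.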
+         total_loss = self.alpha * total_ce_loss + (1 - self.alpha) * total_dependency_loss
+
+         return total_loss, total_ce_loss, total_dependency_loss
+
+ def get_transforms(is_training=True, img_size=640):
+     """
+     Creates image transformation pipelines.
+     Returns different transformations for training and validation data.
+     """
+     if is_training:
+         return transforms.Compose([
+             transforms.RandomResizedCrop(img_size),
+             transforms.RandomHorizontalFlip(),
+             transforms.RandomVerticalFlip(),
+             transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
+             transforms.RandomPerspective(distortion_scale=0.2),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+     else:
+         return transforms.Compose([
+             transforms.Resize((img_size, img_size)), # Fixed size for all validation images
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+ def train_model(model, train_loader, val_loader, criterion, optimizer, level_to_idx, parent_child_relationship, taxonomy, species_list, num_epochs=10, patience=5, best_model_path='best_multitask.pt'):
+     """
+     Trains the hierarchical classifier model.
+     Implements early stopping, validation, and model checkpointing.
+     """
+     logger.info("Starting training")
+     model.to(device)
+
+     best_val_loss = float('inf')
+     epochs_without_improvement = 0
+
+     for epoch in range(num_epochs):
+         model.train()
+         running_loss = 0.0
+         running_ce_loss = 0.0
+         running_dep_loss = 0.0
+         correct_predictions = [0] * model.num_levels
+         total_predictions = 0
+
+         train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
+
+         for batch_idx, (images, labels) in enumerate(train_pbar):
+             try:
+                 images = images.to(device)
+                 labels = labels.to(device)
+
+                 outputs = model(images)
+
+                 predictions = []
+                 for output in outputs:
+                     pred = torch.argmax(output, dim=1)
+                     predictions.append(pred)
+
+                 loss, ce_loss, dep_loss = criterion(outputs, labels, predictions)
+
+                 optimizer.zero_grad()
+                 loss.backward()
+                 optimizer.step()
+
+                 running_loss += loss.item()
+                 running_ce_loss += ce_loss.item()
+                 running_dep_loss += dep_loss.item() if dep_loss.numel() > 0 else 0
+
+                 for level in range(model.num_levels):
+                     correct_predictions[level] += (predictions[level] == labels[:, level]).sum().item()
+                 total_predictions += labels.size(0)
+
+                 train_pbar.set_postfix(loss=f"{loss.item():.4f}")
+
+             except Exception as e:
+                 logger.error(f"Error in training batch {batch_idx}: {str(e)}")
+                 continue # Skip this batch and continue with the next one
+
+         epoch_loss = running_loss / len(train_loader)
+         epoch_ce_loss = running_ce_loss / len(train_loader)
+         epoch_dep_loss = running_dep_loss / len(train_loader)
+         epoch_accuracies = [correct / total_predictions for correct in correct_predictions]
+
+         model.update_anomaly_stats()
+
+         model.eval()
+         val_running_loss = 0.0
+         val_correct_predictions = [0] * model.num_levels
+         val_total_predictions = 0
+         val_unsure_count = [0] * model.num_levels
+
+         val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Valid]")
+
+         with torch.no_grad():
+             for batch_idx, (images, labels) in enumerate(val_pbar):
+                 try:
+                     images = images.to(device)
+                     labels = labels.to(device)
+
+                     predictions, confidences, is_unsure = model.predict_with_hierarchy(images)
+                     outputs = model(images)
+
+                     loss, _, _ = criterion(outputs, labels, predictions)
+                     val_running_loss += loss.item()
+
+                     for level in range(model.num_levels):
+                         correct_mask = (predictions[level] == labels[:, level]) & ~is_unsure[level]
+                         val_correct_predictions[level] += correct_mask.sum().item()
+                         val_unsure_count[level] += is_unsure[level].sum().item()
+                     val_total_predictions += labels.size(0)
+
+                     val_pbar.set_postfix(loss=f"{loss.item():.4f}")
+
+                 except Exception as e:
+                     logger.error(f"Error in validation batch {batch_idx}: {str(e)}")
+                     continue
+
+         val_epoch_loss = val_running_loss / len(val_loader)
+         val_epoch_accuracies = [correct / val_total_predictions for correct in val_correct_predictions]
+         val_unsure_rates = [unsure / val_total_predictions for unsure in val_unsure_count]
+
+         # Print epoch summary
+         print(f"\nEpoch {epoch+1}/{num_epochs}")
+         print(f"Train Loss: {epoch_loss:.4f} (CE: {epoch_ce_loss:.4f}, Dep: {epoch_dep_loss:.4f})")
+         print(f"Valid Loss: {val_epoch_loss:.4f}")
+
+         for level in range(model.num_levels):
+             print(f"Level {level+1} - Train Acc: {epoch_accuracies[level]:.4f}, "
+                   f"Valid Acc: {val_epoch_accuracies[level]:.4f}, "
+                   f"Unsure: {val_unsure_rates[level]:.4f}")
+         print('-' * 60)
+
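+         # Checkpoint whenever validation loss improves; stop early after `patience` epochs without improvement.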
+         if val_epoch_loss < best_val_loss:
+             best_val_loss = val_epoch_loss
+             epochs_without_improvement = 0
+
+             torch.save({
+                 'model_state_dict': model.state_dict(),
+                 'taxonomy': taxonomy,
+                 'level_to_idx': level_to_idx,
+                 'parent_child_relationship': parent_child_relationship,
+                 'species_list': species_list
+             }, best_model_path)
+             logger.info(f"Saved best model at epoch {epoch+1} with validation loss: {best_val_loss:.4f}")
+         else:
+             epochs_without_improvement += 1
+             logger.info(f"No improvement for {epochs_without_improvement} epochs. Best val loss: {best_val_loss:.4f}")
+
+         if epochs_without_improvement >= patience:
+             logger.info(f"Early stopping triggered after {epoch+1} epochs")
+             print(f"Early stopping triggered after {epoch+1} epochs")
+             break
+
+     logger.info("Training completed successfully")
+     return model
+
+ if __name__ == '__main__':
+     species_list = [
+         "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
+         "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
+         "Eristalis tenax"
+     ]
+     train_multitask(species_list=species_list, epochs=2)
+
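For reference, a minimal sketch of loading the checkpoint that train_model writes above (not code from the package; the module path and checkpoint location are assumed from the file list and the defaults in train_multitask):

    import torch
    from bplusplus.hierarchical.train import HierarchicalInsectClassifier

    # Rebuild the model from the metadata saved alongside the weights.
    checkpoint = torch.load('./output/best_multitask.pt', map_location='cpu')
    taxonomy = checkpoint['taxonomy']
    num_classes_per_level = [len(taxonomy[level]) for level in sorted(taxonomy.keys())]
    model = HierarchicalInsectClassifier(
        num_classes_per_level=num_classes_per_level,
        level_to_idx=checkpoint['level_to_idx'],
        parent_child_relationship=checkpoint['parent_child_relationship'],
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()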