bplusplus 1.1.0-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bplusplus might be problematic.
- bplusplus/__init__.py +4 -2
- bplusplus/collect.py +69 -5
- bplusplus/hierarchical/test.py +670 -0
- bplusplus/hierarchical/train.py +676 -0
- bplusplus/prepare.py +228 -64
- bplusplus/resnet/test.py +473 -0
- bplusplus/resnet/train.py +329 -0
- bplusplus-1.2.0.dist-info/METADATA +249 -0
- bplusplus-1.2.0.dist-info/RECORD +12 -0
- bplusplus/yolov5detect/__init__.py +0 -1
- bplusplus/yolov5detect/detect.py +0 -444
- bplusplus/yolov5detect/export.py +0 -1530
- bplusplus/yolov5detect/insect.yaml +0 -8
- bplusplus/yolov5detect/models/__init__.py +0 -0
- bplusplus/yolov5detect/models/common.py +0 -1109
- bplusplus/yolov5detect/models/experimental.py +0 -130
- bplusplus/yolov5detect/models/hub/anchors.yaml +0 -56
- bplusplus/yolov5detect/models/hub/yolov3-spp.yaml +0 -52
- bplusplus/yolov5detect/models/hub/yolov3-tiny.yaml +0 -42
- bplusplus/yolov5detect/models/hub/yolov3.yaml +0 -52
- bplusplus/yolov5detect/models/hub/yolov5-bifpn.yaml +0 -49
- bplusplus/yolov5detect/models/hub/yolov5-fpn.yaml +0 -43
- bplusplus/yolov5detect/models/hub/yolov5-p2.yaml +0 -55
- bplusplus/yolov5detect/models/hub/yolov5-p34.yaml +0 -42
- bplusplus/yolov5detect/models/hub/yolov5-p6.yaml +0 -57
- bplusplus/yolov5detect/models/hub/yolov5-p7.yaml +0 -68
- bplusplus/yolov5detect/models/hub/yolov5-panet.yaml +0 -49
- bplusplus/yolov5detect/models/hub/yolov5l6.yaml +0 -61
- bplusplus/yolov5detect/models/hub/yolov5m6.yaml +0 -61
- bplusplus/yolov5detect/models/hub/yolov5n6.yaml +0 -61
- bplusplus/yolov5detect/models/hub/yolov5s-LeakyReLU.yaml +0 -50
- bplusplus/yolov5detect/models/hub/yolov5s-ghost.yaml +0 -49
- bplusplus/yolov5detect/models/hub/yolov5s-transformer.yaml +0 -49
- bplusplus/yolov5detect/models/hub/yolov5s6.yaml +0 -61
- bplusplus/yolov5detect/models/hub/yolov5x6.yaml +0 -61
- bplusplus/yolov5detect/models/segment/yolov5l-seg.yaml +0 -49
- bplusplus/yolov5detect/models/segment/yolov5m-seg.yaml +0 -49
- bplusplus/yolov5detect/models/segment/yolov5n-seg.yaml +0 -49
- bplusplus/yolov5detect/models/segment/yolov5s-seg.yaml +0 -49
- bplusplus/yolov5detect/models/segment/yolov5x-seg.yaml +0 -49
- bplusplus/yolov5detect/models/tf.py +0 -797
- bplusplus/yolov5detect/models/yolo.py +0 -495
- bplusplus/yolov5detect/models/yolov5l.yaml +0 -49
- bplusplus/yolov5detect/models/yolov5m.yaml +0 -49
- bplusplus/yolov5detect/models/yolov5n.yaml +0 -49
- bplusplus/yolov5detect/models/yolov5s.yaml +0 -49
- bplusplus/yolov5detect/models/yolov5x.yaml +0 -49
- bplusplus/yolov5detect/utils/__init__.py +0 -97
- bplusplus/yolov5detect/utils/activations.py +0 -134
- bplusplus/yolov5detect/utils/augmentations.py +0 -448
- bplusplus/yolov5detect/utils/autoanchor.py +0 -175
- bplusplus/yolov5detect/utils/autobatch.py +0 -70
- bplusplus/yolov5detect/utils/aws/__init__.py +0 -0
- bplusplus/yolov5detect/utils/aws/mime.sh +0 -26
- bplusplus/yolov5detect/utils/aws/resume.py +0 -41
- bplusplus/yolov5detect/utils/aws/userdata.sh +0 -27
- bplusplus/yolov5detect/utils/callbacks.py +0 -72
- bplusplus/yolov5detect/utils/dataloaders.py +0 -1385
- bplusplus/yolov5detect/utils/docker/Dockerfile +0 -73
- bplusplus/yolov5detect/utils/docker/Dockerfile-arm64 +0 -40
- bplusplus/yolov5detect/utils/docker/Dockerfile-cpu +0 -42
- bplusplus/yolov5detect/utils/downloads.py +0 -136
- bplusplus/yolov5detect/utils/flask_rest_api/README.md +0 -70
- bplusplus/yolov5detect/utils/flask_rest_api/example_request.py +0 -17
- bplusplus/yolov5detect/utils/flask_rest_api/restapi.py +0 -49
- bplusplus/yolov5detect/utils/general.py +0 -1294
- bplusplus/yolov5detect/utils/google_app_engine/Dockerfile +0 -25
- bplusplus/yolov5detect/utils/google_app_engine/additional_requirements.txt +0 -6
- bplusplus/yolov5detect/utils/google_app_engine/app.yaml +0 -16
- bplusplus/yolov5detect/utils/loggers/__init__.py +0 -476
- bplusplus/yolov5detect/utils/loggers/clearml/README.md +0 -222
- bplusplus/yolov5detect/utils/loggers/clearml/__init__.py +0 -0
- bplusplus/yolov5detect/utils/loggers/clearml/clearml_utils.py +0 -230
- bplusplus/yolov5detect/utils/loggers/clearml/hpo.py +0 -90
- bplusplus/yolov5detect/utils/loggers/comet/README.md +0 -250
- bplusplus/yolov5detect/utils/loggers/comet/__init__.py +0 -551
- bplusplus/yolov5detect/utils/loggers/comet/comet_utils.py +0 -151
- bplusplus/yolov5detect/utils/loggers/comet/hpo.py +0 -126
- bplusplus/yolov5detect/utils/loggers/comet/optimizer_config.json +0 -135
- bplusplus/yolov5detect/utils/loggers/wandb/__init__.py +0 -0
- bplusplus/yolov5detect/utils/loggers/wandb/wandb_utils.py +0 -210
- bplusplus/yolov5detect/utils/loss.py +0 -259
- bplusplus/yolov5detect/utils/metrics.py +0 -381
- bplusplus/yolov5detect/utils/plots.py +0 -517
- bplusplus/yolov5detect/utils/segment/__init__.py +0 -0
- bplusplus/yolov5detect/utils/segment/augmentations.py +0 -100
- bplusplus/yolov5detect/utils/segment/dataloaders.py +0 -366
- bplusplus/yolov5detect/utils/segment/general.py +0 -160
- bplusplus/yolov5detect/utils/segment/loss.py +0 -198
- bplusplus/yolov5detect/utils/segment/metrics.py +0 -225
- bplusplus/yolov5detect/utils/segment/plots.py +0 -152
- bplusplus/yolov5detect/utils/torch_utils.py +0 -482
- bplusplus/yolov5detect/utils/triton.py +0 -90
- bplusplus-1.1.0.dist-info/METADATA +0 -179
- bplusplus-1.1.0.dist-info/RECORD +0 -92
- {bplusplus-1.1.0.dist-info → bplusplus-1.2.0.dist-info}/LICENSE +0 -0
- {bplusplus-1.1.0.dist-info → bplusplus-1.2.0.dist-info}/WHEEL +0 -0
bplusplus/hierarchical/train.py
@@ -0,0 +1,676 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision.models as models
+import torchvision.transforms as transforms
+from torch.utils.data import Dataset, DataLoader
+from PIL import Image
+import os
+import numpy as np
+from collections import defaultdict
+import requests
+import time
+import logging
+from tqdm import tqdm
+import sys
+
+def train_multitask(batch_size=4, epochs=30, patience=3, img_size=640, data_dir='/mnt/nvme0n1p1/datasets/insect/bjerge-train2', output_dir='./output', species_list=None):
+    """
+    Main function to run the entire training pipeline.
+    Sets up datasets, model, training process and handles errors.
+    """
+    global logger, device
+
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+    logger = logging.getLogger(__name__)
+
+    logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {epochs}, Patience: {patience}, Image size: {img_size}, Data directory: {data_dir}, Output directory: {output_dir}")
+
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    torch.manual_seed(42)
+    np.random.seed(42)
+
+    learning_rate = 1.0e-4
+
+    train_dir = os.path.join(data_dir, 'train')
+    val_dir = os.path.join(data_dir, 'valid')
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    missing_species = []
+    for species in species_list:
+        species_dir = os.path.join(train_dir, species)
+        if not os.path.isdir(species_dir):
+            missing_species.append(species)
+
+    if missing_species:
+        raise ValueError(f"The following species directories were not found: {missing_species}")
+
+    logger.info(f"Using {len(species_list)} species in the specified order")
+
+    taxonomy = get_taxonomy(species_list)
+
+    level_to_idx, parent_child_relationship = create_mappings(taxonomy)
+
+    num_classes_per_level = [len(taxonomy[level]) if isinstance(taxonomy[level], list)
+                             else len(taxonomy[level].keys()) for level in sorted(taxonomy.keys())]
+
+    train_dataset = InsectDataset(
+        root_dir=train_dir,
+        transform=get_transforms(is_training=True, img_size=img_size),
+        taxonomy=taxonomy,
+        level_to_idx=level_to_idx
+    )
+
+    val_dataset = InsectDataset(
+        root_dir=val_dir,
+        transform=get_transforms(is_training=False, img_size=img_size),
+        taxonomy=taxonomy,
+        level_to_idx=level_to_idx
+    )
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=4
+    )
+
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=4
+    )
+
+    try:
+        logger.info("Initializing model...")
+        model = HierarchicalInsectClassifier(
+            num_classes_per_level=num_classes_per_level,
+            level_to_idx=level_to_idx,
+            parent_child_relationship=parent_child_relationship
+        )
+        logger.info(f"Model structure initialized with {sum(p.numel() for p in model.parameters())} parameters")
+
+        logger.info("Setting up loss function and optimizer...")
+        criterion = HierarchicalLoss(
+            alpha=0.5,
+            level_to_idx=level_to_idx,
+            parent_child_relationship=parent_child_relationship
+        )
+        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+        logger.info("Testing model with a dummy input...")
+        model.to(device)
+        dummy_input = torch.randn(1, 3, img_size, img_size).to(device)
+        with torch.no_grad():
+            try:
+                test_outputs = model(dummy_input)
+                logger.info(f"Forward pass test successful, output shapes: {[out.shape for out in test_outputs]}")
+            except Exception as e:
+                logger.error(f"Forward pass test failed: {str(e)}")
+                raise
+
+        logger.info("Starting training process...")
+        best_model_path = os.path.join(output_dir, 'best_multitask.pt')
+
+        trained_model = train_model(
+            model=model,
+            train_loader=train_loader,
+            val_loader=val_loader,
+            criterion=criterion,
+            optimizer=optimizer,
+            level_to_idx=level_to_idx,
+            parent_child_relationship=parent_child_relationship,
+            taxonomy=taxonomy,
+            species_list=species_list,
+            num_epochs=epochs,
+            patience=patience,
+            best_model_path=best_model_path
+        )
+
+        logger.info("Model saved successfully with taxonomy information!")
+        print("Model saved successfully with taxonomy information!")
+
+        return trained_model, taxonomy, level_to_idx, parent_child_relationship
+
+    except Exception as e:
+        logger.error(f"Critical error during model setup or training: {str(e)}")
+        logger.exception("Stack trace:")
+        raise
+
+def get_taxonomy(species_list):
+    """
+    Retrieves taxonomic information for a list of species from GBIF API.
+    Creates a hierarchical taxonomy dictionary with order, family, and species relationships.
+    """
+    taxonomy = {1: [], 2: {}, 3: {}}
+    species_to_family = {}
+    family_to_order = {}
+
+    logger.info(f"Building taxonomy from GBIF for {len(species_list)} species")
+
+    print("\nTaxonomy Results:")
+    print("-" * 80)
+    print(f"{'Species':<30} {'Order':<20} {'Family':<20} {'Status'}")
+    print("-" * 80)
+
+    for species_name in species_list:
+        url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
+        try:
+            response = requests.get(url)
+            data = response.json()
+
+            if data.get('status') == 'ACCEPTED' or data.get('status') == 'SYNONYM':
+                family = data.get('family')
+                order = data.get('order')
+
+                if family and order:
+                    status = "OK"
+
+                    print(f"{species_name:<30} {order:<20} {family:<20} {status}")
+
+                    species_to_family[species_name] = family
+                    family_to_order[family] = order
+
+                    if order not in taxonomy[1]:
+                        taxonomy[1].append(order)
+
+                    taxonomy[2][family] = order
+                    taxonomy[3][species_name] = family
+                else:
+                    error_msg = f"Species '{species_name}' found in GBIF but family and order not found, could be spelling error in species, check GBIF"
+                    logger.error(error_msg)
+                    print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                    print(f"Error: {error_msg}")
+                    sys.exit(1)  # Stop the script
+            else:
+                error_msg = f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF"
+                logger.error(error_msg)
+                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
+                print(f"Error: {error_msg}")
+                sys.exit(1)  # Stop the script
+
+        except Exception as e:
+            error_msg = f"Error retrieving data for species '{species_name}': {str(e)}"
+            logger.error(error_msg)
+            print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
+            print(f"Error: {error_msg}")
+            sys.exit(1)  # Stop the script
+
+    taxonomy[1] = sorted(list(set(taxonomy[1])))
+    print("-" * 80)
+
+    num_orders = len(taxonomy[1])
+    num_families = len(taxonomy[2])
+    num_species = len(taxonomy[3])
+
+    print("\nOrder indices:")
+    for i, order in enumerate(taxonomy[1]):
+        print(f" {i}: {order}")
+
+    print("\nFamily indices:")
+    for i, family in enumerate(taxonomy[2].keys()):
+        print(f" {i}: {family}")
+
+    print("\nSpecies indices:")
+    for i, species in enumerate(species_list):
+        print(f" {i}: {species}")
+
+    logger.info(f"Taxonomy built: {num_orders} orders, {num_families} families, {num_species} species")
+    return taxonomy
+
+def get_species_from_directory(train_dir):
+    """
+    Extracts a list of species names from subdirectories in the training directory.
+    Returns a sorted list of species names found.
+    """
+    if not os.path.exists(train_dir):
+        raise ValueError(f"Training directory does not exist: {train_dir}")
+
+    species_list = []
+    for item in os.listdir(train_dir):
+        item_path = os.path.join(train_dir, item)
+        if os.path.isdir(item_path):
+            species_list.append(item)
+
+    species_list.sort()
+
+    if not species_list:
+        raise ValueError(f"No species subdirectories found in {train_dir}")
+
+    logger.info(f"Found {len(species_list)} species in {train_dir}")
+    return species_list
+
+def create_mappings(taxonomy):
+    """
+    Creates mapping dictionaries from taxonomy data.
+    Returns level-to-index mapping and parent-child relationships between taxonomic levels.
+    """
+    level_to_idx = {}
+    parent_child_relationship = {}
+
+    for level, labels in taxonomy.items():
+        if isinstance(labels, list):
+            level_to_idx[level] = {label: idx for idx, label in enumerate(labels)}
+        else:
+            level_to_idx[level] = {label: idx for idx, label in enumerate(labels.keys())}
+            for child, parent in labels.items():
+                if (level, parent) not in parent_child_relationship:
+                    parent_child_relationship[(level, parent)] = []
+                parent_child_relationship[(level, parent)].append(child)
+
+    return level_to_idx, parent_child_relationship
+
+class InsectDataset(Dataset):
+    """
+    PyTorch dataset for loading and processing insect images.
+    Organizes data according to taxonomic hierarchy and validates images.
+    """
+    def __init__(self, root_dir, transform=None, taxonomy=None, level_to_idx=None):
+        self.root_dir = root_dir
+        self.transform = transform
+        self.taxonomy = taxonomy
+        self.level_to_idx = level_to_idx
+        self.samples = []
+
+        species_to_family = {species: family for species, family in taxonomy[3].items()}
+        family_to_order = {family: order for family, order in taxonomy[2].items()}
+
+        for species_name in os.listdir(root_dir):
+            species_path = os.path.join(root_dir, species_name)
+            if os.path.isdir(species_path):
+                if species_name in species_to_family:
+                    family_name = species_to_family[species_name]
+                    order_name = family_to_order[family_name]
+
+                    for img_file in os.listdir(species_path):
+                        if img_file.endswith(('.jpg', '.png', '.jpeg')):
+                            img_path = os.path.join(species_path, img_file)
+                            # Validate the image can be opened
+                            try:
+                                with Image.open(img_path) as img:
+                                    img.convert('RGB')
+                                # Only add valid images to samples
+                                self.samples.append({
+                                    'image_path': img_path,
+                                    'labels': [order_name, family_name, species_name]
+                                })
+
+                            except Exception as e:
+                                logger.warning(f"Skipping invalid image: {img_path} - Error: {str(e)}")
+                else:
+                    logger.warning(f"Warning: Species '{species_name}' not found in taxonomy. Skipping.")
+
+        # Log statistics about valid/invalid images
+        logger.info(f"Dataset loaded with {len(self.samples)} valid images")
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        sample = self.samples[idx]
+        image = Image.open(sample['image_path']).convert('RGB')
+
+        if self.transform:
+            image = self.transform(image)
+
+        label_indices = [self.level_to_idx[level+1][label] for level, label in enumerate(sample['labels'])]
+
+        return image, torch.tensor(label_indices)
+
+class HierarchicalInsectClassifier(nn.Module):
+    """
+    Deep learning model for hierarchical insect classification.
+    Uses ResNet50 backbone with multiple classification branches for different taxonomic levels.
+    Includes anomaly detection capabilities.
+    """
+    def __init__(self, num_classes_per_level, level_to_idx=None, parent_child_relationship=None):
+        super(HierarchicalInsectClassifier, self).__init__()
+
+        self.backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+        backbone_output_features = self.backbone.fc.in_features
+        self.backbone.fc = nn.Identity()
+
+        self.branches = nn.ModuleList()
+        for num_classes in num_classes_per_level:
+            branch = nn.Sequential(
+                nn.Linear(backbone_output_features, 512),
+                nn.ReLU(),
+                nn.Dropout(0.5),
+                nn.Linear(512, num_classes)
+            )
+            self.branches.append(branch)
+
+        self.num_levels = len(num_classes_per_level)
+
+        # Store the taxonomy mappings
+        self.level_to_idx = level_to_idx
+        self.parent_child_relationship = parent_child_relationship
+
+        self.register_buffer('class_means', torch.zeros(sum(num_classes_per_level)))
+        self.register_buffer('class_stds', torch.ones(sum(num_classes_per_level)))
+        self.class_counts = [0] * sum(num_classes_per_level)
+        self.output_history = defaultdict(list)
+
+    def forward(self, x):
+        R0 = self.backbone(x)
+
+        outputs = []
+        for branch in self.branches:
+            outputs.append(branch(R0))
+
+        return outputs
+
+    def predict_with_hierarchy(self, x):
+        outputs = self.forward(x)
+        predictions = []
+        confidences = []
+        is_unsure = []
+
+        level1_output = outputs[0]
+        level1_probs = torch.softmax(level1_output, dim=1)
+        level1_pred = torch.argmax(level1_output, dim=1)
+        level1_conf = torch.gather(level1_probs, 1, level1_pred.unsqueeze(1)).squeeze(1)
+
+        start_idx = 0
+        level1_unsure = self.detect_anomalies(level1_output, level1_pred, start_idx)
+
+        predictions.append(level1_pred)
+        confidences.append(level1_conf)
+        is_unsure.append(level1_unsure)
+
+        # Check if taxonomy mappings are available
+        if self.level_to_idx is None or self.parent_child_relationship is None:
+            # Return basic predictions if taxonomy isn't available
+            for level in range(1, self.num_levels):
+                level_output = outputs[level]
+                level_probs = torch.softmax(level_output, dim=1)
+                level_pred = torch.argmax(level_output, dim=1)
+                level_conf = torch.gather(level_probs, 1, level_pred.unsqueeze(1)).squeeze(1)
+                start_idx += outputs[level-1].shape[1]
+                level_unsure = self.detect_anomalies(level_output, level_pred, start_idx)
+
+                predictions.append(level_pred)
+                confidences.append(level_conf)
+                is_unsure.append(level_unsure)
+
+            return predictions, confidences, is_unsure
+
+        # If taxonomy is available, use hierarchical constraints
+        for level in range(1, self.num_levels):
+            level_output = outputs[level]
+            level_probs = torch.softmax(level_output, dim=1)
+            level_pred = torch.argmax(level_output, dim=1)
+            level_conf = torch.gather(level_probs, 1, level_pred.unsqueeze(1)).squeeze(1)
+
+            start_idx += outputs[level-1].shape[1]
+            level_unsure = self.detect_anomalies(level_output, level_pred, start_idx)
+
+            level_unsure_hierarchy = torch.zeros_like(level_pred, dtype=torch.bool)
+            for i in range(level_pred.shape[0]):
+                prev_level_pred_idx = predictions[level-1][i].item()
+                curr_level_pred_idx = level_pred[i].item()
+
+                prev_level_label = list(self.level_to_idx[level])[prev_level_pred_idx]
+                curr_level_label = list(self.level_to_idx[level+1])[curr_level_pred_idx]
+
+                if (level+1, prev_level_label) in self.parent_child_relationship:
+                    valid_children = self.parent_child_relationship[(level+1, prev_level_label)]
+                    if curr_level_label not in valid_children:
+                        level_unsure_hierarchy[i] = True
+                else:
+                    level_unsure_hierarchy[i] = True
+
+            level_unsure = torch.logical_or(level_unsure, level_unsure_hierarchy)
+
+            predictions.append(level_pred)
+            confidences.append(level_conf)
+            is_unsure.append(level_unsure)
+
+        return predictions, confidences, is_unsure
+
+    def detect_anomalies(self, outputs, predictions, start_idx):
+        unsure = torch.zeros_like(predictions, dtype=torch.bool)
+
+        if self.training:
+            for i in range(outputs.shape[0]):
+                pred_class = predictions[i].item()
+                class_idx = start_idx + pred_class
+                self.output_history[class_idx].append(outputs[i, pred_class].item())
+        else:
+            for i in range(outputs.shape[0]):
+                pred_class = predictions[i].item()
+                class_idx = start_idx + pred_class
+
+                if len(self.output_history[class_idx]) > 0:
+                    mean = np.mean(self.output_history[class_idx])
+                    std = np.std(self.output_history[class_idx])
+                    threshold = mean - 2 * std
+
+                    if outputs[i, pred_class].item() < threshold:
+                        unsure[i] = True
+
+        return unsure
+
+    def update_anomaly_stats(self):
+        for class_idx, outputs in self.output_history.items():
+            if len(outputs) > 0:
+                self.class_means[class_idx] = torch.tensor(np.mean(outputs))
+                self.class_stds[class_idx] = torch.tensor(np.std(outputs))
+
+class HierarchicalLoss(nn.Module):
+    """
+    Custom loss function for hierarchical classification.
+    Combines cross-entropy loss with dependency loss to enforce taxonomic constraints.
+    """
+    def __init__(self, alpha=0.5, level_to_idx=None, parent_child_relationship=None):
+        super(HierarchicalLoss, self).__init__()
+        self.alpha = alpha
+        self.ce_loss = nn.CrossEntropyLoss()
+        self.level_to_idx = level_to_idx
+        self.parent_child_relationship = parent_child_relationship
+
+    def forward(self, outputs, targets, predictions):
+        ce_losses = []
+        for level, output in enumerate(outputs):
+            ce_losses.append(self.ce_loss(output, targets[:, level]))
+
+        total_ce_loss = sum(ce_losses)
+
+        dependency_losses = []
+
+        # Skip dependency loss calculation if taxonomy isn't available
+        if self.level_to_idx is None or self.parent_child_relationship is None:
+            return total_ce_loss, total_ce_loss, torch.zeros(1, device=outputs[0].device)
+
+        for level in range(1, len(outputs)):
+            dependency_loss = torch.zeros(1, device=outputs[0].device)
+            for i in range(targets.shape[0]):
+                prev_level_pred_idx = predictions[level-1][i].item()
+                curr_level_pred_idx = predictions[level][i].item()
+
+                prev_level_label = list(self.level_to_idx[level])[prev_level_pred_idx]
+                curr_level_label = list(self.level_to_idx[level+1])[curr_level_pred_idx]
+
+                is_valid = False
+                if (level+1, prev_level_label) in self.parent_child_relationship:
+                    valid_children = self.parent_child_relationship[(level+1, prev_level_label)]
+                    if curr_level_label in valid_children:
+                        is_valid = True
+
+                D_l = 0 if is_valid else 1
+                dependency_loss += torch.exp(torch.tensor(D_l, device=outputs[0].device)) - 1
+
+            dependency_loss /= targets.shape[0]
+            dependency_losses.append(dependency_loss)
+
+        total_dependency_loss = sum(dependency_losses) if dependency_losses else torch.zeros(1, device=outputs[0].device)
+
+        total_loss = self.alpha * total_ce_loss + (1 - self.alpha) * total_dependency_loss
+
+        return total_loss, total_ce_loss, total_dependency_loss
+
+def get_transforms(is_training=True, img_size=640):
+    """
+    Creates image transformation pipelines.
+    Returns different transformations for training and validation data.
+    """
+    if is_training:
+        return transforms.Compose([
+            transforms.RandomResizedCrop(img_size),
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomVerticalFlip(),
+            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
+            transforms.RandomPerspective(distortion_scale=0.2),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+    else:
+        return transforms.Compose([
+            transforms.Resize((img_size, img_size)),  # Fixed size for all validation images
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+def train_model(model, train_loader, val_loader, criterion, optimizer, level_to_idx, parent_child_relationship, taxonomy, species_list, num_epochs=10, patience=5, best_model_path='best_multitask.pt'):
+    """
+    Trains the hierarchical classifier model.
+    Implements early stopping, validation, and model checkpointing.
+    """
+    logger.info("Starting training")
+    model.to(device)
+
+    best_val_loss = float('inf')
+    epochs_without_improvement = 0
+
+    for epoch in range(num_epochs):
+        model.train()
+        running_loss = 0.0
+        running_ce_loss = 0.0
+        running_dep_loss = 0.0
+        correct_predictions = [0] * model.num_levels
+        total_predictions = 0
+
+        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
+
+        for batch_idx, (images, labels) in enumerate(train_pbar):
+            try:
+                images = images.to(device)
+                labels = labels.to(device)
+
+                outputs = model(images)
+
+                predictions = []
+                for output in outputs:
+                    pred = torch.argmax(output, dim=1)
+                    predictions.append(pred)
+
+                loss, ce_loss, dep_loss = criterion(outputs, labels, predictions)
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                running_loss += loss.item()
+                running_ce_loss += ce_loss.item()
+                running_dep_loss += dep_loss.item() if dep_loss.numel() > 0 else 0
+
+                for level in range(model.num_levels):
+                    correct_predictions[level] += (predictions[level] == labels[:, level]).sum().item()
+                total_predictions += labels.size(0)
+
+                train_pbar.set_postfix(loss=f"{loss.item():.4f}")
+
+            except Exception as e:
+                logger.error(f"Error in training batch {batch_idx}: {str(e)}")
+                continue  # Skip this batch and continue with the next one
+
+        epoch_loss = running_loss / len(train_loader)
+        epoch_ce_loss = running_ce_loss / len(train_loader)
+        epoch_dep_loss = running_dep_loss / len(train_loader)
+        epoch_accuracies = [correct / total_predictions for correct in correct_predictions]
+
+        model.update_anomaly_stats()
+
+        model.eval()
+        val_running_loss = 0.0
+        val_correct_predictions = [0] * model.num_levels
+        val_total_predictions = 0
+        val_unsure_count = [0] * model.num_levels
+
+        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Valid]")
+
+        with torch.no_grad():
+            for batch_idx, (images, labels) in enumerate(val_pbar):
+                try:
+                    images = images.to(device)
+                    labels = labels.to(device)
+
+                    predictions, confidences, is_unsure = model.predict_with_hierarchy(images)
+                    outputs = model(images)
+
+                    loss, _, _ = criterion(outputs, labels, predictions)
+                    val_running_loss += loss.item()
+
+                    for level in range(model.num_levels):
+                        correct_mask = (predictions[level] == labels[:, level]) & ~is_unsure[level]
+                        val_correct_predictions[level] += correct_mask.sum().item()
+                        val_unsure_count[level] += is_unsure[level].sum().item()
+                    val_total_predictions += labels.size(0)
+
+                    val_pbar.set_postfix(loss=f"{loss.item():.4f}")
+
+                except Exception as e:
+                    logger.error(f"Error in validation batch {batch_idx}: {str(e)}")
+                    continue
+
+        val_epoch_loss = val_running_loss / len(val_loader)
+        val_epoch_accuracies = [correct / val_total_predictions for correct in val_correct_predictions]
+        val_unsure_rates = [unsure / val_total_predictions for unsure in val_unsure_count]
+
+        # Print epoch summary
+        print(f"\nEpoch {epoch+1}/{num_epochs}")
+        print(f"Train Loss: {epoch_loss:.4f} (CE: {epoch_ce_loss:.4f}, Dep: {epoch_dep_loss:.4f})")
+        print(f"Valid Loss: {val_epoch_loss:.4f}")
+
+        for level in range(model.num_levels):
+            print(f"Level {level+1} - Train Acc: {epoch_accuracies[level]:.4f}, "
+                  f"Valid Acc: {val_epoch_accuracies[level]:.4f}, "
+                  f"Unsure: {val_unsure_rates[level]:.4f}")
+        print('-' * 60)
+
+        if val_epoch_loss < best_val_loss:
+            best_val_loss = val_epoch_loss
+            epochs_without_improvement = 0
+
+            torch.save({
+                'model_state_dict': model.state_dict(),
+                'taxonomy': taxonomy,
+                'level_to_idx': level_to_idx,
+                'parent_child_relationship': parent_child_relationship,
+                'species_list': species_list
+            }, best_model_path)
+            logger.info(f"Saved best model at epoch {epoch+1} with validation loss: {best_val_loss:.4f}")
+        else:
+            epochs_without_improvement += 1
+            logger.info(f"No improvement for {epochs_without_improvement} epochs. Best val loss: {best_val_loss:.4f}")
+
+        if epochs_without_improvement >= patience:
+            logger.info(f"Early stopping triggered after {epoch+1} epochs")
+            print(f"Early stopping triggered after {epoch+1} epochs")
+            break
+
+    logger.info("Training completed successfully")
+    return model
+
+if __name__ == '__main__':
+    species_list = [
+        "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
+        "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
+        "Eristalis tenax"
+    ]
+    train_multitask(species_list=species_list, epochs=2)
+
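For reference, the objective implemented by HierarchicalLoss above works out to total = alpha * sum_l CE_l + (1 - alpha) * sum_l mean_i(exp(D_l,i) - 1), where CE_l is the cross-entropy at taxonomic level l and D_l,i is 1 when sample i's predicted label at level l is not a valid child of its predicted label one level up (0 otherwise). With alpha=0.5 as configured in train_multitask, each hierarchy violation contributes exp(1) - 1 ≈ 1.72 to the dependency term before batch averaging. One caveat: D_l is a Python int, and current PyTorch releases raise an error when torch.exp is applied to an integer (Long) tensor, so the dependency term likely needs float literals (0.0/1.0) to run as written.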
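get_taxonomy resolves each species against GBIF's public species-match endpoint and aborts on any name it cannot resolve. A minimal standalone sketch of the same lookup, using only the endpoint and response fields the code above already relies on (the example name comes from the __main__ block):

import requests

# Same endpoint and query parameters as get_taxonomy uses per species.
name = "Apis mellifera"
response = requests.get(f"https://api.gbif.org/v1/species/match?name={name}&verbose=true")
data = response.json()

# train.py accepts ACCEPTED or SYNONYM matches and reads only the
# 'order' and 'family' fields; anything else stops the script.
if data.get('status') in ('ACCEPTED', 'SYNONYM'):
    print(data.get('order'), data.get('family'))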
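Because the torch.save call above bundles the taxonomy mappings with the weights, inference can skip the GBIF step entirely. A minimal reload sketch, assuming the module import path matches the file listing in this release and using train_multitask's default output directory; the random tensor is a stand-in for a preprocessed image batch, not part of the package:

import torch
from bplusplus.hierarchical.train import HierarchicalInsectClassifier  # path as listed in this release

ckpt = torch.load('./output/best_multitask.pt', map_location='cpu')

# Recompute per-level class counts exactly as train_multitask does.
taxonomy = ckpt['taxonomy']
num_classes_per_level = [len(taxonomy[level]) if isinstance(taxonomy[level], list)
                         else len(taxonomy[level].keys()) for level in sorted(taxonomy.keys())]

model = HierarchicalInsectClassifier(
    num_classes_per_level=num_classes_per_level,
    level_to_idx=ckpt['level_to_idx'],
    parent_child_relationship=ckpt['parent_child_relationship'],
)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# Returns per-level predictions, softmax confidences, and "unsure" flags
# (anomaly threshold exceeded or taxonomy constraint violated).
with torch.no_grad():
    preds, confs, unsure = model.predict_with_hierarchy(torch.randn(1, 3, 640, 640))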