bplusplus 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bplusplus/__init__.py +4 -2
- bplusplus/collect.py +72 -3
- bplusplus/hierarchical/test.py +670 -0
- bplusplus/hierarchical/train.py +676 -0
- bplusplus/prepare.py +236 -71
- bplusplus/resnet/test.py +473 -0
- bplusplus/resnet/train.py +329 -0
- bplusplus-1.2.1.dist-info/METADATA +252 -0
- bplusplus-1.2.1.dist-info/RECORD +12 -0
- bplusplus/yolov5detect/__init__.py +0 -1
- bplusplus/yolov5detect/detect.py +0 -444
- bplusplus/yolov5detect/export.py +0 -1530
- bplusplus/yolov5detect/insect.yaml +0 -8
- bplusplus/yolov5detect/models/__init__.py +0 -0
- bplusplus/yolov5detect/models/common.py +0 -1109
- bplusplus/yolov5detect/models/experimental.py +0 -130
- bplusplus/yolov5detect/models/hub/anchors.yaml +0 -56
- bplusplus/yolov5detect/models/hub/yolov3-spp.yaml +0 -52
- bplusplus/yolov5detect/models/hub/yolov3-tiny.yaml +0 -42
- bplusplus/yolov5detect/models/hub/yolov3.yaml +0 -52
- bplusplus/yolov5detect/models/hub/yolov5-bifpn.yaml +0 -49
- bplusplus/yolov5detect/models/hub/yolov5-fpn.yaml +0 -43
- bplusplus/yolov5detect/models/hub/yolov5-p2.yaml +0 -55
- bplusplus/yolov5detect/models/hub/yolov5-p34.yaml +0 -42
- bplusplus/yolov5detect/models/hub/yolov5-p6.yaml +0 -57
- bplusplus/yolov5detect/models/hub/yolov5-p7.yaml +0 -68
- bplusplus/yolov5detect/models/hub/yolov5-panet.yaml +0 -49
- bplusplus/yolov5detect/models/hub/yolov5l6.yaml +0 -61
- bplusplus/yolov5detect/models/hub/yolov5m6.yaml +0 -61
- bplusplus/yolov5detect/models/hub/yolov5n6.yaml +0 -61
- bplusplus/yolov5detect/models/hub/yolov5s-LeakyReLU.yaml +0 -50
- bplusplus/yolov5detect/models/hub/yolov5s-ghost.yaml +0 -49
- bplusplus/yolov5detect/models/hub/yolov5s-transformer.yaml +0 -49
- bplusplus/yolov5detect/models/hub/yolov5s6.yaml +0 -61
- bplusplus/yolov5detect/models/hub/yolov5x6.yaml +0 -61
- bplusplus/yolov5detect/models/segment/yolov5l-seg.yaml +0 -49
- bplusplus/yolov5detect/models/segment/yolov5m-seg.yaml +0 -49
- bplusplus/yolov5detect/models/segment/yolov5n-seg.yaml +0 -49
- bplusplus/yolov5detect/models/segment/yolov5s-seg.yaml +0 -49
- bplusplus/yolov5detect/models/segment/yolov5x-seg.yaml +0 -49
- bplusplus/yolov5detect/models/tf.py +0 -797
- bplusplus/yolov5detect/models/yolo.py +0 -495
- bplusplus/yolov5detect/models/yolov5l.yaml +0 -49
- bplusplus/yolov5detect/models/yolov5m.yaml +0 -49
- bplusplus/yolov5detect/models/yolov5n.yaml +0 -49
- bplusplus/yolov5detect/models/yolov5s.yaml +0 -49
- bplusplus/yolov5detect/models/yolov5x.yaml +0 -49
- bplusplus/yolov5detect/utils/__init__.py +0 -97
- bplusplus/yolov5detect/utils/activations.py +0 -134
- bplusplus/yolov5detect/utils/augmentations.py +0 -448
- bplusplus/yolov5detect/utils/autoanchor.py +0 -175
- bplusplus/yolov5detect/utils/autobatch.py +0 -70
- bplusplus/yolov5detect/utils/aws/__init__.py +0 -0
- bplusplus/yolov5detect/utils/aws/mime.sh +0 -26
- bplusplus/yolov5detect/utils/aws/resume.py +0 -41
- bplusplus/yolov5detect/utils/aws/userdata.sh +0 -27
- bplusplus/yolov5detect/utils/callbacks.py +0 -72
- bplusplus/yolov5detect/utils/dataloaders.py +0 -1385
- bplusplus/yolov5detect/utils/docker/Dockerfile +0 -73
- bplusplus/yolov5detect/utils/docker/Dockerfile-arm64 +0 -40
- bplusplus/yolov5detect/utils/docker/Dockerfile-cpu +0 -42
- bplusplus/yolov5detect/utils/downloads.py +0 -136
- bplusplus/yolov5detect/utils/flask_rest_api/README.md +0 -70
- bplusplus/yolov5detect/utils/flask_rest_api/example_request.py +0 -17
- bplusplus/yolov5detect/utils/flask_rest_api/restapi.py +0 -49
- bplusplus/yolov5detect/utils/general.py +0 -1294
- bplusplus/yolov5detect/utils/google_app_engine/Dockerfile +0 -25
- bplusplus/yolov5detect/utils/google_app_engine/additional_requirements.txt +0 -6
- bplusplus/yolov5detect/utils/google_app_engine/app.yaml +0 -16
- bplusplus/yolov5detect/utils/loggers/__init__.py +0 -476
- bplusplus/yolov5detect/utils/loggers/clearml/README.md +0 -222
- bplusplus/yolov5detect/utils/loggers/clearml/__init__.py +0 -0
- bplusplus/yolov5detect/utils/loggers/clearml/clearml_utils.py +0 -230
- bplusplus/yolov5detect/utils/loggers/clearml/hpo.py +0 -90
- bplusplus/yolov5detect/utils/loggers/comet/README.md +0 -250
- bplusplus/yolov5detect/utils/loggers/comet/__init__.py +0 -551
- bplusplus/yolov5detect/utils/loggers/comet/comet_utils.py +0 -151
- bplusplus/yolov5detect/utils/loggers/comet/hpo.py +0 -126
- bplusplus/yolov5detect/utils/loggers/comet/optimizer_config.json +0 -135
- bplusplus/yolov5detect/utils/loggers/wandb/__init__.py +0 -0
- bplusplus/yolov5detect/utils/loggers/wandb/wandb_utils.py +0 -210
- bplusplus/yolov5detect/utils/loss.py +0 -259
- bplusplus/yolov5detect/utils/metrics.py +0 -381
- bplusplus/yolov5detect/utils/plots.py +0 -517
- bplusplus/yolov5detect/utils/segment/__init__.py +0 -0
- bplusplus/yolov5detect/utils/segment/augmentations.py +0 -100
- bplusplus/yolov5detect/utils/segment/dataloaders.py +0 -366
- bplusplus/yolov5detect/utils/segment/general.py +0 -160
- bplusplus/yolov5detect/utils/segment/loss.py +0 -198
- bplusplus/yolov5detect/utils/segment/metrics.py +0 -225
- bplusplus/yolov5detect/utils/segment/plots.py +0 -152
- bplusplus/yolov5detect/utils/torch_utils.py +0 -482
- bplusplus/yolov5detect/utils/triton.py +0 -90
- bplusplus-1.1.0.dist-info/METADATA +0 -179
- bplusplus-1.1.0.dist-info/RECORD +0 -92
- {bplusplus-1.1.0.dist-info → bplusplus-1.2.1.dist-info}/LICENSE +0 -0
- {bplusplus-1.1.0.dist-info → bplusplus-1.2.1.dist-info}/WHEEL +0 -0
bplusplus/resnet/train.py
@@ -0,0 +1,329 @@

# pip install ultralytics torchvision pillow numpy scikit-learn tabulate tqdm
# python3 train-resnet.py --data_dir '' --output_dir '' --arch resnet50 --img_size 956 --num_epochs 50 --batch_size 4

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
import time
import os
import copy
from PIL import Image
import logging
from torchvision.models import ResNet152_Weights, ResNet50_Weights
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import Counter
from datetime import datetime

def train_resnet(species_list, model_type='resnet152', batch_size=4, num_epochs=50, patience=5, output_dir=None, data_dir=None, img_size=1024):
    """Main entry point for training the model."""
    # Setup output directory
    output_dir = setup_directories(output_dir)
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'valid')

    # Get transforms directly using the img_size parameter
    train_transforms, val_transforms = get_transforms(img_size)
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {num_epochs}, Patience: {patience}, Data directory: {data_dir}, Output directory: {output_dir}")

    # Use InsectDataset instead of OrderedImageFolder
    train_dataset = InsectDataset(train_dir, species_list=species_list, transform=train_transforms)
    val_dataset = InsectDataset(val_dir, species_list=species_list, transform=val_transforms)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
    class_names = train_dataset.classes
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    logger.info(f"Using device: {device}")

    # Initialize the model based on the model_type parameter
    logger.info(f"Initializing {model_type} model...")
    if model_type == 'resnet152':
        model = models.resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)
    else:
        model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(num_ftrs, len(class_names))
    )
    model = model.to(device)
    logger.info(f"Model structure initialized with {sum(p.numel() for p in model.parameters())} parameters")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Print class distributions as in hierarchical-train.py
    print("\nTraining set class distribution:")
    train_counts = Counter(train_dataset.targets)
    for idx, count in train_counts.items():
        print(f"  {class_names[idx]}: {count}")

    print("\nValidation set class distribution:")
    val_counts = Counter(val_dataset.targets)
    for idx, count in val_counts.items():
        print(f"  {class_names[idx]}: {count}")

    logger.info("Starting training process...")
    model = train_model(model, criterion, optimizer, scheduler, num_epochs, device, train_loader, val_loader,
                        dataset_sizes, class_names, output_dir, patience=patience)

    # Save the best model
    model_filename = f'best_{model_type}.pt'
    if output_dir:
        model_path = os.path.join(output_dir, model_filename)
    else:
        model_path = model_filename

    torch.save(model.state_dict(), model_path)
    logger.info(f"Saved best model to {model_path}")
    print("Model saved successfully!")

    return model

class InsectDataset(Dataset):
    """Dataset for loading and processing insect images with validation."""
    def __init__(self, root_dir, species_list, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.targets = []
        self.classes = species_list
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(species_list)}

        logger = logging.getLogger(__name__)
        logger.info(f"Loading dataset from {root_dir} with {len(species_list)} species")

        # Validate all images similar to SafeImageFolder
        invalid_count = 0
        for species_name in species_list:
            species_path = os.path.join(root_dir, species_name)
            if os.path.isdir(species_path):
                for img_file in os.listdir(species_path):
                    if img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
                        img_path = os.path.join(species_path, img_file)
                        try:
                            # Validate that the image can be opened
                            with Image.open(img_path) as img:
                                img.convert('RGB')
                            # Only add valid images to samples
                            class_idx = self.class_to_idx[species_name]
                            self.samples.append((img_path, class_idx))
                            self.targets.append(class_idx)
                        except Exception as e:
                            invalid_count += 1
                            logger.warning(f"Skipping invalid image {img_path}: {str(e)}")
            else:
                logger.warning(f"Species directory not found: {species_path}")

        logger.info(f"Dataset loaded with {len(self.samples)} valid images ({invalid_count} invalid images skipped)")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Get an image and its target in a simple approach."""
        img_path, target = self.samples[idx]
        # Simple direct loading
        image = Image.open(img_path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, target

def setup_directories(output_dir):
    """Create output directory if needed."""
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    return output_dir

def get_transforms(img_size):
    """Return training and validation transforms based on the desired image size."""
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomPerspective(distortion_scale=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    val_transforms = transforms.Compose([
        transforms.Resize((int(img_size * 1.2), int(img_size * 1.2))),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    return train_transforms, val_transforms

def print_class_assignments(dataset, class_names, max_images_per_class=10):
    """
    Print information about which images belong to each class.

    Args:
        dataset: The dataset to analyze
        class_names: List of class names
        max_images_per_class: Maximum number of images to print per class
    """
    print("\n==== Class Assignment Information ====")

    # Count images per class
    class_counts = {}
    for _, target in dataset.samples:
        class_name = class_names[target]
        if class_name not in class_counts:
            class_counts[class_name] = 0
        class_counts[class_name] += 1

    # Print summary of class distribution
    print("\nClass distribution:")
    for class_name, count in class_counts.items():
        print(f"  {class_name}: {count} images")

    # Print sample file paths for each class
    print("\nSample images for each class:")
    for class_name in class_names:
        print(f"\nClass: {class_name}")

        # Collect paths for this class
        paths = []
        for path, target in dataset.samples:
            if class_names[target] == class_name:
                paths.append(path)

            # Limit the number of paths to print
            if len(paths) >= max_images_per_class:
                break

        # Print paths
        for i, path in enumerate(paths, 1):
            print(f"  {i}. {path}")

        # Check if any images were found
        if not paths:
            print("  No images found for this class!")

def train_model(model, criterion, optimizer, scheduler, num_epochs, device, train_loader, val_loader, dataset_sizes, class_names, output_dir, patience=10):
    """Train the model and return the best model based on validation accuracy."""
    since = time.time()  # Record start time
    best_model_wts = copy.deepcopy(model.state_dict())  # Store initial model state
    best_acc = 0.0  # Initialize best accuracy
    scaler = torch.cuda.amp.GradScaler()  # For mixed precision training

    # Early stopping variables
    epochs_no_improve = 0
    early_stop = False

    for epoch in range(num_epochs):
        # Display current epoch
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 60)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
                loader = train_loader
            else:
                model.eval()  # Set model to evaluation mode
                loader = val_loader

            running_loss = 0.0  # Initialize loss accumulator
            running_corrects = 0  # Initialize correct predictions accumulator

            # Create progress bar with appropriate description
            pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs} [{phase.capitalize()}]")

            for inputs, labels in pbar:
                inputs = inputs.to(device)  # Move inputs to device
                labels = labels.to(device)  # Move labels to device

                optimizer.zero_grad()  # Reset gradients

                with torch.set_grad_enabled(phase == 'train'):  # Enable gradients only in train phase
                    with torch.cuda.amp.autocast():  # Enable mixed precision
                        outputs = model(inputs)  # Forward pass
                        _, preds = torch.max(outputs, 1)  # Get predictions
                        loss = criterion(outputs, labels)  # Compute loss

                    if phase == 'train':
                        scaler.scale(loss).backward()  # Backpropagation with scaling
                        scaler.step(optimizer)  # Update parameters
                        scaler.update()  # Update scaler for mixed precision

                running_loss += loss.item() * inputs.size(0)  # Accumulate loss
                running_corrects += torch.sum(preds == labels.data)  # Accumulate correct predictions

                # Update progress bar with current loss
                pbar.set_postfix(loss=f"{loss.item():.4f}")

            if phase == 'train':
                scheduler.step()  # Update learning rate scheduler

            # Calculate epoch metrics
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            # Print epoch results
            print(f"{phase.capitalize()} Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

            # Save best model if validation accuracy improved
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0  # Reset counter since we improved
                print(f"Saved best model with validation accuracy: {best_acc:.4f}")
            elif phase == 'val':
                epochs_no_improve += 1
                print(f"No improvement for {epochs_no_improve} epochs. Best accuracy: {best_acc:.4f}")

                # Check early stopping condition
                if epochs_no_improve >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs without improvement")
                    early_stop = True
                    break

        torch.cuda.empty_cache()  # Clear GPU cache after each epoch

        if early_stop:
            print(f"Training stopped early due to no improvement for {patience} epochs")
            break  # Leave the epoch loop too, so early stopping actually ends training

    time_elapsed = time.time() - since  # Calculate elapsed time
    print(f'Training complete in {int(time_elapsed // 60)}m {int(time_elapsed % 60)}s')
    print(f'Best val Accuracy: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)  # Load best model weights
    return model

if __name__ == "__main__":
    DEFAULT_SPECIES_LIST = [
        "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
        "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
        "Eristalis tenax"
    ]

    # Call main with default parameters
    train_resnet(
        species_list=DEFAULT_SPECIES_LIST,
        model_type='resnet50',
        batch_size=4,
        num_epochs=2,
        patience=5,
        output_dir='./output',
        data_dir='/mnt/nvme0n1p1/datasets/insect/bjerge-train2',
        img_size=256
    )
bplusplus-1.2.1.dist-info/METADATA
@@ -0,0 +1,252 @@

Metadata-Version: 2.1
Name: bplusplus
Version: 1.2.1
Summary: A simple method to create AI models for biodiversity, with collect and prepare pipeline
License: MIT
Author: Titus Venverloo
Author-email: tvenver@mit.edu
Requires-Python: >=3.9.0,<4.0.0
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Dist: prettytable (==3.7.0)
Requires-Dist: pygbif (>=0.6.4,<0.7.0)
Requires-Dist: requests (==2.25.1)
Requires-Dist: scikit-learn (>=1.6.1,<2.0.0)
Requires-Dist: tabulate (>=0.9.0,<0.10.0)
Requires-Dist: torch (==2.5.0)
Requires-Dist: ultralytics (==8.0.195)
Requires-Dist: validators (>=0.33.0,<0.34.0)
Description-Content-Type: text/markdown

# B++ repository

[](https://zenodo.org/badge/latestdoi/765250194)
[](https://pypi.org/project/bplusplus/)
[](https://pypi.org/project/bplusplus/)
[](https://pypi.org/project/bplusplus/)
[](https://pepy.tech/project/bplusplus)
[](https://pepy.tech/project/bplusplus)
[](https://pepy.tech/project/bplusplus)

This repo can be used to quickly generate models for biodiversity monitoring, relying on the GBIF dataset.

# Three pipeline options

## One stage YOLO

For the one stage pipeline, we first collect the data from GBIF with `collect()`, then prepare it for training with `prepare()`, which adds bounding boxes to the images using a pretrained YOLO model. We then train a YOLOv8 model with the `train()` function.
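For orientation, here is a minimal end-to-end sketch of the one stage pipeline that simply chains the calls documented under Setup below. All species names, counts, and paths are placeholders; see each function's section for the full argument list.

```python
import bplusplus

# Placeholder search: collect images for a few species from GBIF
search = {"scientificName": ["Vanessa atalanta", "Gonepteryx rhamni", "Bombus hortorum"]}

bplusplus.collect(
    search_parameters=search,
    images_per_group=100,
    output_directory="dataset/selected-species",
    group_by_key=bplusplus.Group.scientificName,
    num_threads=3,
)

# one_stage=True prepares the data for training the detector directly (see prepare() below)
bplusplus.prepare(
    input_directory="dataset/selected-species",
    output_directory="dataset/prepared-data",
    one_stage=True,
)

model = bplusplus.train(
    input_yaml="dataset/prepared-data/dataset.yaml",
    output_directory="trained-model",
    epochs=30,
    imgsz=640,
    batch=16,
)

metrics = bplusplus.validate(model, "dataset/prepared-data/dataset.yaml")
print(metrics)
```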

## Two stage YOLO/Resnet

For the two stage pipeline, we first collect the data from GBIF with `collect()`, then prepare this (classification) data for training with `prepare()`, filtering by insect size (we recommend "large"); this step also splits the data into train and valid sets. We then train a ResNet classifier with the `train_resnet()` function. The trained ResNet classification model is then paired with a pretrained YOLOv8 insect detection model (hence two stage).
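The two stage variant differs only after collection: `prepare()` is run with size filtering, and a ResNet classifier is trained instead of a YOLO detector. A sketch, reusing the placeholder dataset from above:

```python
import bplusplus

species_list = ["Vanessa atalanta", "Gonepteryx rhamni", "Bombus hortorum"]

# Size filtering (recommended "large") also splits the data into train/valid
bplusplus.prepare(
    input_directory="dataset/selected-species",
    output_directory="dataset/prepared-data",
    one_stage=False,
    size_filter=True,
    sizes=["large"],
)

bplusplus.train_resnet(
    species_list=species_list,
    model_type="resnet50",
    batch_size=16,
    num_epochs=30,
    patience=5,
    output_dir="trained-model",
    data_dir="dataset/prepared-data",
    img_size=256,
)
```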

## Two stage YOLO/Multitask-Resnet

For the two stage multitask pipeline, collection and preparation are the same: we first collect the data from GBIF with `collect()`, then prepare this (classification) data for training with `prepare()`, filtering by insect size (we recommend "large"); this step also splits the data into train and valid sets. We then train the model with the `train_multitask()` function. The difference is that this model is trained on species, order, and family simultaneously. The trained ResNet classification model is again paired with a pretrained YOLOv8 insect detection model (hence two stage).

# Setup

### Install package

```bash
pip install bplusplus
```

### bplusplus.collect() (All pipelines)

This function takes the following arguments:
- **search_parameters: dict[str, Any]** - GBIF search filters; map `"scientificName"` to the list of scientific names of the species you want to collect from the GBIF database
- **images_per_group: int** - Number of images collected per species for training. Max 9000.
- **output_directory: str** - Directory to store collected images
- **group_by_key: Group** - Key used to group the collected images (the example below groups by `bplusplus.Group.scientificName`)
- **num_threads: int** - Number of threads to run while collecting images. We recommend a moderate number (3-5) to avoid overwhelming the API server.

Example run:
```python
from typing import Any

import bplusplus

species_list = ["Vanessa atalanta", "Gonepteryx rhamni", "Bombus hortorum"]
# convert to dict
search: dict[str, Any] = {
    "scientificName": species_list
}

images_per_group = 20
output_directory = "/dataset/selected-species"
num_threads = 2

# Collect data from GBIF
bplusplus.collect(
    search_parameters=search,
    images_per_group=images_per_group,
    output_directory=output_directory,
    group_by_key=bplusplus.Group.scientificName,
    num_threads=num_threads
)
```

### bplusplus.prepare() (All pipelines)

Prepares the dataset for training by performing the following steps:
1. Copies images from the input directory to a temporary directory.
2. Deletes corrupted images.
3. Downloads YOLOv5 weights for *insect detection* if not already present.
4. Runs YOLOv5 inference to generate labels for the images.
5. Deletes orphaned images and inferences.
6. Updates labels based on the class mapping.
7. Splits the data into train, test, and validation sets.
8. Counts the total number of images across all splits.
9. Writes a YAML configuration file for YOLOv8.

This function takes the following arguments:
- **input_directory: str** - The path to the input directory containing the images.
- **output_directory: str** - The path to the output directory where the prepared dataset will be saved.
- **with_background: bool = False** - Set to False if you don't want to include/download background images
- **one_stage: bool = False** - Set to True if you want to train a one stage model
- **size_filter: bool = False** - Set to True if you want to filter by insect size
- **sizes: list = None** - List of sizes to filter by, out of ["large", "medium", "small"]. If None, all sizes are used.

```python
# Prepare data
bplusplus.prepare(
    input_directory='/dataset/selected-species',
    output_directory='/dataset/prepared-data',
    with_background=False,
    one_stage=False,
    size_filter=True,
    sizes=["large"]
)
```

### bplusplus.train() (One stage pipeline)

This function takes five arguments:
- **input_yaml: str** - YAML file created by `prepare()`, used to train the model
- **output_directory: str** - Directory where the trained model is saved
- **epochs: int = 30** - Number of epochs to train the model
- **imgsz: int = 640** - Image size
- **batch: int = 16** - Batch size for training

```python
# Train model
model = bplusplus.train(
    input_yaml="/dataset/prepared-data/dataset.yaml",  # Make sure to add the correct path
    output_directory="trained-model",
    epochs=30,
    batch=16
)
```

### bplusplus.train_resnet() (Two stage (standard ResNet) pipeline)

This function takes eight arguments:
- **species_list: list** - List of species to train the model on
- **model_type: str** - The type of ResNet model to train. Options are "resnet50" and "resnet152"
- **batch_size: int** - The batch size for training
- **num_epochs: int** - The number of epochs to train the model
- **patience: int** - The number of epochs without improvement to wait before early stopping
- **output_dir: str** - The path to the output directory where the trained model will be saved
- **data_dir: str** - The path to the directory containing the prepared data
- **img_size: int** - The size of the images to train the model on

```python
# Train ResNet model
bplusplus.train_resnet(
    species_list=["Vanessa atalanta", "Gonepteryx rhamni", "Bombus hortorum"],
    model_type="resnet50",
    batch_size=16,
    num_epochs=30,
    patience=5,
    output_dir="trained-model",
    data_dir="prepared-data",
    img_size=256
)
```
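Based on how `bplusplus/resnet/train.py` loads the data (see the diff above), `data_dir` is expected to contain `train` and `valid` folders with one subfolder per species, each holding `.jpg`/`.jpeg`/`.png` images; unreadable images are skipped automatically. Roughly:

```
prepared-data/
├── train/
│   ├── Vanessa atalanta/
│   │   ├── 0001.jpg
│   │   └── ...
│   └── Gonepteryx rhamni/
│       └── ...
└── valid/
    ├── Vanessa atalanta/
    │   └── ...
    └── Gonepteryx rhamni/
        └── ...
```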

### bplusplus.train_multitask() (Two stage (multitask ResNet) pipeline)

This function takes seven arguments:
- **batch_size: int** - The batch size for training
- **epochs: int** - The number of epochs to train the model
- **patience: int** - The number of epochs without improvement to wait before early stopping
- **img_size: int** - The size of the images to train the model on
- **data_dir: str** - The path to the directory containing the prepared data
- **output_dir: str** - The path to the output directory where the trained model will be saved
- **species_list: list** - List of species to train the model on

```python
# Train multitask model
bplusplus.train_multitask(
    batch_size=16,
    epochs=30,
    patience=5,
    img_size=256,
    data_dir="prepared-data",
    output_dir="trained-model",
    species_list=["Vanessa atalanta", "Gonepteryx rhamni", "Bombus hortorum"]
)
```

### bplusplus.validate() (One stage pipeline)

This function takes two arguments:
- **model** - The trained YOLO model
- **Path to the YAML file** - The dataset YAML created by `prepare()`

```python
metrics = bplusplus.validate(model, '/dataset/prepared-data/dataset.yaml')
print(metrics)
```

### bplusplus.test_resnet() (Two stage (standard ResNet) pipeline)

This function takes six arguments:
- **data_path: str** - The path to the directory containing the test data
- **yolo_weights: str** - The path to the YOLO weights
- **resnet_weights: str** - The path to the ResNet weights
- **model: str** - The type of ResNet model to use
- **species_names: list** - The list of species names
- **output_dir: str** - The path to the output directory where the test results will be saved

```python
bplusplus.test_resnet(
    data_path=TEST_DATA_DIR,
    yolo_weights=YOLO_WEIGHTS,
    resnet_weights=RESNET_WEIGHTS,
    model="resnet50",
    species_names=species_list,
    output_dir=TRAINED_MODEL_DIR
)
```

### bplusplus.test_multitask() (Two stage (multitask ResNet) pipeline)

This function takes five arguments:
- **species_list: list** - List of species to test the model on
- **test_set: str** - The path to the directory containing the test data
- **yolo_weights: str** - The path to the YOLO weights
- **hierarchical_weights: str** - The path to the hierarchical (multitask) weights
- **output_dir: str** - The path to the output directory where the test results will be saved

```python
bplusplus.test_multitask(
    species_list,
    test_set=TEST_DATA_DIR,
    yolo_weights=YOLO_WEIGHTS,
    hierarchical_weights=RESNET_MULTITASK_WEIGHTS,
    output_dir=TRAINED_MODEL_DIR
)
```
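The test functions above already run the two stage pairing (YOLO detection followed by ResNet classification) for you. For illustration only, here is a hypothetical sketch of what that pairing looks like at inference time; it uses standard ultralytics and torchvision calls rather than any bplusplus API, all weight and image paths are placeholders, and the classifier head mirrors the one built in `bplusplus/resnet/train.py`:

```python
import torch
from PIL import Image
from torchvision import models, transforms
from ultralytics import YOLO

species_list = ["Vanessa atalanta", "Gonepteryx rhamni", "Bombus hortorum"]

# Stage 1: detect insects with pretrained YOLO weights (placeholder path)
detector = YOLO("yolo-insect-weights.pt")
results = detector("field-photo.jpg")

# Stage 2: classify each detected crop with the trained ResNet
classifier = models.resnet50()
classifier.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.4),
    torch.nn.Linear(classifier.fc.in_features, len(species_list)),
)
classifier.load_state_dict(torch.load("trained-model/best_resnet50.pt", map_location="cpu"))
classifier.eval()

preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

image = Image.open("field-photo.jpg").convert("RGB")
for box in results[0].boxes.xyxy:  # one (x1, y1, x2, y2) box per detection
    x1, y1, x2, y2 = (int(v) for v in box.tolist())
    crop = preprocess(image.crop((x1, y1, x2, y2))).unsqueeze(0)
    with torch.no_grad():
        pred = classifier(crop).argmax(dim=1).item()
    print(species_list[pred])
```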
# Citation

All information in this repository is available under the MIT license, as long as credit is given to the authors.

**Venverloo, T., Duarte, F., B++: Towards Real-Time Monitoring of Insect Species. MIT Senseable City Laboratory, AMS Institute.**
bplusplus-1.2.1.dist-info/RECORD
@@ -0,0 +1,12 @@

bplusplus/__init__.py,sha256=JEFlFeUBotbHCrRELKl2k8cjGMuQDqVLncUBeO_7bv0,279
bplusplus/collect.py,sha256=w4G78oQY0vT0yBILNas2rdCaeVbtDSLJpNP2Sqqq_7Q,6960
bplusplus/hierarchical/test.py,sha256=l6j48NFD_M7gX-8Vx7EqUi6lbhYOP-5pYI0005mJO-k,29746
bplusplus/hierarchical/train.py,sha256=eAMInSYhk0zwgXRRO_r8awIX-fSsrOv-LV9NTCgMrOM,28115
bplusplus/prepare.py,sha256=CWAAnbgsUnC5oMTFplBieojr3AZA8LYcLPYJDrRx5OI,29459
bplusplus/resnet/test.py,sha256=jS1fWSlK-5uBkkRPDFrvhQbAVhDd1lSbO1rmaJsBg1M,21589
bplusplus/resnet/train.py,sha256=0_GOc2qRAkWbvOaDIhU4OQxp2Cc7m2wbriJccDS0VEk,13926
bplusplus/train_validate.py,sha256=uqWPXyknoAYXkFIg_YOynd1UnBraibI1fliFOK5vWwE,533
bplusplus-1.2.1.dist-info/LICENSE,sha256=rRkeHptDnlmviR0_WWgNT9t696eys_cjfVUU8FEO4k4,1071
bplusplus-1.2.1.dist-info/METADATA,sha256=yp-3t92xNILaoi1h3zuwokrYu2ttsH-mgUT1VCzss8w,9960
bplusplus-1.2.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
bplusplus-1.2.1.dist-info/RECORD,,
bplusplus/yolov5detect/__init__.py
@@ -1 +0,0 @@
-from .detect import run