bplusplus 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bplusplus might be problematic. Click here for more details.

Files changed (97) hide show
  1. bplusplus/__init__.py +4 -2
  2. bplusplus/collect.py +72 -3
  3. bplusplus/hierarchical/test.py +670 -0
  4. bplusplus/hierarchical/train.py +676 -0
  5. bplusplus/prepare.py +236 -71
  6. bplusplus/resnet/test.py +473 -0
  7. bplusplus/resnet/train.py +329 -0
  8. bplusplus-1.2.1.dist-info/METADATA +252 -0
  9. bplusplus-1.2.1.dist-info/RECORD +12 -0
  10. bplusplus/yolov5detect/__init__.py +0 -1
  11. bplusplus/yolov5detect/detect.py +0 -444
  12. bplusplus/yolov5detect/export.py +0 -1530
  13. bplusplus/yolov5detect/insect.yaml +0 -8
  14. bplusplus/yolov5detect/models/__init__.py +0 -0
  15. bplusplus/yolov5detect/models/common.py +0 -1109
  16. bplusplus/yolov5detect/models/experimental.py +0 -130
  17. bplusplus/yolov5detect/models/hub/anchors.yaml +0 -56
  18. bplusplus/yolov5detect/models/hub/yolov3-spp.yaml +0 -52
  19. bplusplus/yolov5detect/models/hub/yolov3-tiny.yaml +0 -42
  20. bplusplus/yolov5detect/models/hub/yolov3.yaml +0 -52
  21. bplusplus/yolov5detect/models/hub/yolov5-bifpn.yaml +0 -49
  22. bplusplus/yolov5detect/models/hub/yolov5-fpn.yaml +0 -43
  23. bplusplus/yolov5detect/models/hub/yolov5-p2.yaml +0 -55
  24. bplusplus/yolov5detect/models/hub/yolov5-p34.yaml +0 -42
  25. bplusplus/yolov5detect/models/hub/yolov5-p6.yaml +0 -57
  26. bplusplus/yolov5detect/models/hub/yolov5-p7.yaml +0 -68
  27. bplusplus/yolov5detect/models/hub/yolov5-panet.yaml +0 -49
  28. bplusplus/yolov5detect/models/hub/yolov5l6.yaml +0 -61
  29. bplusplus/yolov5detect/models/hub/yolov5m6.yaml +0 -61
  30. bplusplus/yolov5detect/models/hub/yolov5n6.yaml +0 -61
  31. bplusplus/yolov5detect/models/hub/yolov5s-LeakyReLU.yaml +0 -50
  32. bplusplus/yolov5detect/models/hub/yolov5s-ghost.yaml +0 -49
  33. bplusplus/yolov5detect/models/hub/yolov5s-transformer.yaml +0 -49
  34. bplusplus/yolov5detect/models/hub/yolov5s6.yaml +0 -61
  35. bplusplus/yolov5detect/models/hub/yolov5x6.yaml +0 -61
  36. bplusplus/yolov5detect/models/segment/yolov5l-seg.yaml +0 -49
  37. bplusplus/yolov5detect/models/segment/yolov5m-seg.yaml +0 -49
  38. bplusplus/yolov5detect/models/segment/yolov5n-seg.yaml +0 -49
  39. bplusplus/yolov5detect/models/segment/yolov5s-seg.yaml +0 -49
  40. bplusplus/yolov5detect/models/segment/yolov5x-seg.yaml +0 -49
  41. bplusplus/yolov5detect/models/tf.py +0 -797
  42. bplusplus/yolov5detect/models/yolo.py +0 -495
  43. bplusplus/yolov5detect/models/yolov5l.yaml +0 -49
  44. bplusplus/yolov5detect/models/yolov5m.yaml +0 -49
  45. bplusplus/yolov5detect/models/yolov5n.yaml +0 -49
  46. bplusplus/yolov5detect/models/yolov5s.yaml +0 -49
  47. bplusplus/yolov5detect/models/yolov5x.yaml +0 -49
  48. bplusplus/yolov5detect/utils/__init__.py +0 -97
  49. bplusplus/yolov5detect/utils/activations.py +0 -134
  50. bplusplus/yolov5detect/utils/augmentations.py +0 -448
  51. bplusplus/yolov5detect/utils/autoanchor.py +0 -175
  52. bplusplus/yolov5detect/utils/autobatch.py +0 -70
  53. bplusplus/yolov5detect/utils/aws/__init__.py +0 -0
  54. bplusplus/yolov5detect/utils/aws/mime.sh +0 -26
  55. bplusplus/yolov5detect/utils/aws/resume.py +0 -41
  56. bplusplus/yolov5detect/utils/aws/userdata.sh +0 -27
  57. bplusplus/yolov5detect/utils/callbacks.py +0 -72
  58. bplusplus/yolov5detect/utils/dataloaders.py +0 -1385
  59. bplusplus/yolov5detect/utils/docker/Dockerfile +0 -73
  60. bplusplus/yolov5detect/utils/docker/Dockerfile-arm64 +0 -40
  61. bplusplus/yolov5detect/utils/docker/Dockerfile-cpu +0 -42
  62. bplusplus/yolov5detect/utils/downloads.py +0 -136
  63. bplusplus/yolov5detect/utils/flask_rest_api/README.md +0 -70
  64. bplusplus/yolov5detect/utils/flask_rest_api/example_request.py +0 -17
  65. bplusplus/yolov5detect/utils/flask_rest_api/restapi.py +0 -49
  66. bplusplus/yolov5detect/utils/general.py +0 -1294
  67. bplusplus/yolov5detect/utils/google_app_engine/Dockerfile +0 -25
  68. bplusplus/yolov5detect/utils/google_app_engine/additional_requirements.txt +0 -6
  69. bplusplus/yolov5detect/utils/google_app_engine/app.yaml +0 -16
  70. bplusplus/yolov5detect/utils/loggers/__init__.py +0 -476
  71. bplusplus/yolov5detect/utils/loggers/clearml/README.md +0 -222
  72. bplusplus/yolov5detect/utils/loggers/clearml/__init__.py +0 -0
  73. bplusplus/yolov5detect/utils/loggers/clearml/clearml_utils.py +0 -230
  74. bplusplus/yolov5detect/utils/loggers/clearml/hpo.py +0 -90
  75. bplusplus/yolov5detect/utils/loggers/comet/README.md +0 -250
  76. bplusplus/yolov5detect/utils/loggers/comet/__init__.py +0 -551
  77. bplusplus/yolov5detect/utils/loggers/comet/comet_utils.py +0 -151
  78. bplusplus/yolov5detect/utils/loggers/comet/hpo.py +0 -126
  79. bplusplus/yolov5detect/utils/loggers/comet/optimizer_config.json +0 -135
  80. bplusplus/yolov5detect/utils/loggers/wandb/__init__.py +0 -0
  81. bplusplus/yolov5detect/utils/loggers/wandb/wandb_utils.py +0 -210
  82. bplusplus/yolov5detect/utils/loss.py +0 -259
  83. bplusplus/yolov5detect/utils/metrics.py +0 -381
  84. bplusplus/yolov5detect/utils/plots.py +0 -517
  85. bplusplus/yolov5detect/utils/segment/__init__.py +0 -0
  86. bplusplus/yolov5detect/utils/segment/augmentations.py +0 -100
  87. bplusplus/yolov5detect/utils/segment/dataloaders.py +0 -366
  88. bplusplus/yolov5detect/utils/segment/general.py +0 -160
  89. bplusplus/yolov5detect/utils/segment/loss.py +0 -198
  90. bplusplus/yolov5detect/utils/segment/metrics.py +0 -225
  91. bplusplus/yolov5detect/utils/segment/plots.py +0 -152
  92. bplusplus/yolov5detect/utils/torch_utils.py +0 -482
  93. bplusplus/yolov5detect/utils/triton.py +0 -90
  94. bplusplus-1.1.0.dist-info/METADATA +0 -179
  95. bplusplus-1.1.0.dist-info/RECORD +0 -92
  96. {bplusplus-1.1.0.dist-info → bplusplus-1.2.1.dist-info}/LICENSE +0 -0
  97. {bplusplus-1.1.0.dist-info → bplusplus-1.2.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,329 @@
1
+ # pip install ultralytics torchvision pillow numpy scikit-learn tabulate tqdm
2
+ #python3 train-resnet.py --data_dir '' --output_dir '' --arch resnet50 --img_size 956 --num_epochs 50 --batch_size 4
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import Dataset
8
+ from torchvision import datasets, models, transforms
9
+ import time
10
+ import os
11
+ import copy
12
+ from PIL import Image
13
+ import logging
14
+ from torchvision.models import ResNet152_Weights, ResNet50_Weights
15
+ import numpy as np
16
+ import matplotlib.pyplot as plt
17
+ from tqdm import tqdm
18
+ from collections import Counter
19
+ from datetime import datetime
20
+
21
def train_resnet(species_list, model_type='resnet152', batch_size=4, num_epochs=50, patience=5, output_dir=None, data_dir=None, img_size=1024):
    """Main entry point for training the model.

    Builds validated datasets for the given species, fine-tunes a pretrained
    ResNet backbone with a fresh classification head, and saves the weights of
    the best-validating epoch to ``output_dir``.
    """
    # Resolve directories: output (created if needed) plus the train/valid splits.
    output_dir = setup_directories(output_dir)
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'valid')

    # Build augmentation/eval pipelines for the requested image size.
    train_transforms, val_transforms = get_transforms(img_size)
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {num_epochs}, Patience: {patience}, Data directory: {data_dir}, Output directory: {output_dir}")

    # Datasets restricted to the requested species; corrupt images are skipped.
    train_dataset = InsectDataset(train_dir, species_list=species_list, transform=train_transforms)
    val_dataset = InsectDataset(val_dir, species_list=species_list, transform=val_transforms)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
    class_names = train_dataset.classes
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    logger.info(f"Using device: {device}")

    # Pick the backbone; anything other than 'resnet152' falls back to resnet50.
    logger.info(f"Initializing {model_type} model...")
    if model_type == 'resnet152':
        model = models.resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)
    else:
        model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # Swap the ImageNet head for dropout + a linear layer sized to our classes.
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(num_ftrs, len(class_names))
    )
    model = model.to(device)
    logger.info(f"Model structure initialized with {sum(p.numel() for p in model.parameters())} parameters")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Report per-class sample counts for both splits.
    print("\nTraining set class distribution:")
    for idx, count in Counter(train_dataset.targets).items():
        print(f" {class_names[idx]}: {count}")

    print("\nValidation set class distribution:")
    for idx, count in Counter(val_dataset.targets).items():
        print(f" {class_names[idx]}: {count}")

    logger.info("Starting training process...")
    model = train_model(model, criterion, optimizer, scheduler, num_epochs, device, train_loader, val_loader,
                        dataset_sizes, class_names, output_dir, patience=patience)

    # Persist the best weights next to any other outputs.
    model_filename = f'best_{model_type}.pt'
    model_path = os.path.join(output_dir, model_filename) if output_dir else model_filename

    torch.save(model.state_dict(), model_path)
    logger.info(f"Saved best model to {model_path}")
    print(f"Model saved successfully!")

    return model
95
class InsectDataset(Dataset):
    """Dataset for loading and processing insect images with validation."""

    def __init__(self, root_dir, species_list, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.targets = []
        self.classes = species_list
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(species_list)}

        logger = logging.getLogger(__name__)
        logger.info(f"Loading dataset from {root_dir} with {len(species_list)} species")

        # Pre-validate every candidate image so training never hits a corrupt file.
        invalid_count = 0
        for species_name in species_list:
            species_path = os.path.join(root_dir, species_name)
            if not os.path.isdir(species_path):
                logger.warning(f"Species directory not found: {species_path}")
                continue
            label = self.class_to_idx[species_name]
            for img_file in os.listdir(species_path):
                if not img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
                    continue
                img_path = os.path.join(species_path, img_file)
                try:
                    # Keep only files that actually decode as RGB images.
                    with Image.open(img_path) as img:
                        img.convert('RGB')
                    self.samples.append((img_path, label))
                    self.targets.append(label)
                except Exception as e:
                    invalid_count += 1
                    logger.warning(f"Skipping invalid image {img_path}: {str(e)}")

        logger.info(f"Dataset loaded with {len(self.samples)} valid images ({invalid_count} invalid images skipped)")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Get an image and its target in a simple approach."""
        img_path, target = self.samples[idx]
        # Direct load; validity was already checked in __init__.
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, target
146
def setup_directories(output_dir):
    """Ensure the output directory exists and return the path unchanged.

    A falsy path (``None`` or ``""``) means "no dedicated output directory",
    so nothing is created in that case.
    """
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    return output_dir
152
def get_transforms(img_size):
    """Return training and validation transforms based on the desired image size."""
    # ImageNet normalization constants — the backbone was pretrained with these.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Training: random geometric + photometric augmentation.
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomPerspective(distortion_scale=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # Validation: deterministic resize 20% above target, then center crop.
    val_transforms = transforms.Compose([
        transforms.Resize((int(img_size * 1.2), int(img_size * 1.2))),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    return train_transforms, val_transforms
173
def print_class_assignments(dataset, class_names, max_images_per_class=10):
    """
    Print information about which images belong to each class.

    Args:
        dataset: The dataset to analyze; must expose ``samples`` as a list of
            ``(path, target)`` tuples.
        class_names: List of class names, indexed by target.
        max_images_per_class: Maximum number of images to print per class.
    """
    print("\n==== Class Assignment Information ====")

    # Single pass over the samples: count per class and collect up to
    # max_images_per_class example paths. (The previous implementation
    # re-scanned all of dataset.samples once per class — O(classes * samples).)
    class_counts = Counter()
    sample_paths = {}
    for path, target in dataset.samples:
        class_name = class_names[target]
        class_counts[class_name] += 1
        bucket = sample_paths.setdefault(class_name, [])
        if len(bucket) < max_images_per_class:
            bucket.append(path)

    # Print summary of class distribution (Counter preserves first-seen
    # insertion order, matching the original output ordering).
    print("\nClass distribution:")
    for class_name, count in class_counts.items():
        print(f" {class_name}: {count} images")

    # Print sample file paths for each class
    print("\nSample images for each class:")
    for class_name in class_names:
        print(f"\nClass: {class_name}")
        paths = sample_paths.get(class_name, [])
        for i, path in enumerate(paths, 1):
            print(f" {i}. {path}")
        # Flag classes that contributed no images at all
        if not paths:
            print(" No images found for this class!")
220
def train_model(model, criterion, optimizer, scheduler, num_epochs, device, train_loader, val_loader, dataset_sizes, class_names, output_dir, patience=10):
    """Train the model and return the best model based on validation accuracy.

    Runs a standard train/val loop with mixed precision, keeps the weights of
    the epoch with the highest validation accuracy, and stops early once the
    validation accuracy has not improved for ``patience`` consecutive epochs.

    Args:
        model: Model to optimize (already on ``device``).
        criterion: Loss function.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per training phase.
        num_epochs: Maximum number of epochs.
        device: Device to move inputs/labels to.
        train_loader: DataLoader for the training phase.
        val_loader: DataLoader for the validation phase.
        dataset_sizes: Dict with 'train'/'val' dataset lengths.
        class_names: Unused here; kept for interface compatibility.
        output_dir: Unused here; kept for interface compatibility.
        patience: Epochs without val improvement before early stopping.

    Returns:
        The model with the best validation weights loaded.
    """
    since = time.time()  # Record start time
    best_model_wts = copy.deepcopy(model.state_dict())  # Store initial model state
    best_acc = 0.0  # Initialize best accuracy
    # GradScaler self-disables (with a warning) when CUDA is unavailable.
    scaler = torch.cuda.amp.GradScaler()  # For mixed precision training

    # Early stopping variables
    epochs_no_improve = 0
    early_stop = False

    for epoch in range(num_epochs):
        # Display current epoch
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 60)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
                loader = train_loader
            else:
                model.eval()  # Set model to evaluation mode
                loader = val_loader

            running_loss = 0.0  # Loss accumulator
            running_corrects = 0  # Correct-prediction accumulator

            # Progress bar with phase-specific description
            pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs} [{phase.capitalize()}]")

            for inputs, labels in pbar:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()  # Reset gradients

                # Gradients only in the train phase; autocast for mixed precision.
                with torch.set_grad_enabled(phase == 'train'):
                    with torch.cuda.amp.autocast():
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                    if phase == 'train':
                        scaler.scale(loss).backward()  # Backprop with loss scaling
                        scaler.step(optimizer)  # Update parameters
                        scaler.update()  # Adjust the scale factor

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                # Update progress bar with current loss
                pbar.set_postfix(loss=f"{loss.item():.4f}")

            if phase == 'train':
                scheduler.step()  # Advance the learning-rate schedule

            # Epoch metrics for this phase
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f"{phase.capitalize()} Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

            # Track the best model by validation accuracy
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0  # Reset counter since we improved
                print(f"Saved best model with validation accuracy: {best_acc:.4f}")
            elif phase == 'val':
                epochs_no_improve += 1
                print(f"No improvement for {epochs_no_improve} epochs. Best accuracy: {best_acc:.4f}")

                # Check early stopping condition
                if epochs_no_improve >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs without improvement")
                    early_stop = True
                    break

        torch.cuda.empty_cache()  # Clear GPU cache after each epoch

        # BUG FIX: the original only broke out of the inner phase loop, so
        # training kept running for all num_epochs even after the early-stop
        # message was printed. Exit the epoch loop here as well.
        if early_stop:
            break

    if early_stop:
        print(f"Training stopped early due to no improvement for {patience} epochs")

    time_elapsed = time.time() - since  # Calculate elapsed time
    print(f'Training complete in {int(time_elapsed // 60)}m {int(time_elapsed % 60)}s')
    print(f'Best val Accuracy: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)  # Load best model weights
    return model
311
if __name__ == "__main__":
    # Default demo species set (common European pollinators / flower visitors).
    DEFAULT_SPECIES_LIST = [
        "Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
        "Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
        "Eristalis tenax"
    ]

    # Short smoke-test configuration: small images, only two epochs.
    run_config = dict(
        species_list=DEFAULT_SPECIES_LIST,
        model_type='resnet50',
        batch_size=4,
        num_epochs=2,
        patience=5,
        output_dir='./output',
        data_dir='/mnt/nvme0n1p1/datasets/insect/bjerge-train2',
        img_size=256,
    )
    train_resnet(**run_config)
@@ -0,0 +1,252 @@
1
+ Metadata-Version: 2.1
2
+ Name: bplusplus
3
+ Version: 1.2.1
4
+ Summary: A simple method to create AI models for biodiversity, with collect and prepare pipeline
5
+ License: MIT
6
+ Author: Titus Venverloo
7
+ Author-email: tvenver@mit.edu
8
+ Requires-Python: >=3.9.0,<4.0.0
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.9
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Dist: prettytable (==3.7.0)
16
+ Requires-Dist: pygbif (>=0.6.4,<0.7.0)
17
+ Requires-Dist: requests (==2.25.1)
18
+ Requires-Dist: scikit-learn (>=1.6.1,<2.0.0)
19
+ Requires-Dist: tabulate (>=0.9.0,<0.10.0)
20
+ Requires-Dist: torch (==2.5.0)
21
+ Requires-Dist: ultralytics (==8.0.195)
22
+ Requires-Dist: validators (>=0.33.0,<0.34.0)
23
+ Description-Content-Type: text/markdown
24
+
25
+ # B++ repository
26
+
27
+ [![DOI](https://zenodo.org/badge/765250194.svg)](https://zenodo.org/badge/latestdoi/765250194)
28
+ [![PyPi version](https://img.shields.io/pypi/v/bplusplus.svg)](https://pypi.org/project/bplusplus/)
29
+ [![Python versions](https://img.shields.io/pypi/pyversions/bplusplus.svg)](https://pypi.org/project/bplusplus/)
30
+ [![License](https://img.shields.io/pypi/l/bplusplus.svg)](https://pypi.org/project/bplusplus/)
31
+ [![Downloads](https://static.pepy.tech/badge/bplusplus)](https://pepy.tech/project/bplusplus)
32
+ [![Downloads](https://static.pepy.tech/badge/bplusplus/month)](https://pepy.tech/project/bplusplus)
33
+ [![Downloads](https://static.pepy.tech/badge/bplusplus/week)](https://pepy.tech/project/bplusplus)
34
+
35
+ This repo can be used to quickly generate models for biodiversity monitoring, relying on the GBIF dataset.
36
+
37
+ # Three pipeline options
38
+
39
+ ## One stage YOLO
40
+
41
+ For the one stage pipeline, we first collect `collect()` the data from GBIF, then prepare the data for training by running the `prepare()` function, which adds bounding boxes to the images using a pretrained YOLO model. We then train the model with YOLOv8 using the `train()` function.
42
+
43
+ ## Two stage YOLO/Resnet
44
+
45
+ For the two stage pipeline, we first collect `collect()` the data from GBIF, then prepare `prepare()` this (classification) data for training by either size filtering (recommended "large") which also splits the data into train and valid. We then train the model with resnet using the `train_resnet()` function. The trained model is a resnet classification model which will then be paired with a pretrained YOLOv8 insect detection model (hence two stage).
46
+
47
+ ## Two stage YOLO/Multitask-Resnet
48
+
49
+ For the two stage pipeline, we first collect `collect()` the data from GBIF, then prepare `prepare()` this (classification) data for training by either size filtering (recommended "large") which also splits the data into train and valid. We then train the model with resnet using the `train_multitask()` function. The difference here is that it is training for species, order and family simultaneously. The trained model is a resnet classification model which will then be paired with a pretrained YOLOv8 insect detection model (hence two stage).
50
+
51
+ # Setup
52
+
53
+ ### Install package
54
+
55
+ ```python
56
+ pip install bplusplus
57
+ ```
58
+
59
+ ### bplusplus.collect() (All pipelines)
60
+
61
+ This function takes four arguments:
62
+ - **search_parameters: dict[str, Any]** - List of scientific names of the species you want to collect from the GBIF database
63
+ - **images_per_group: int** - Number of images per species collected for training. Max 9000.
64
+ - **output_directory: str** - Directory to store collected images
65
+ - **num_threads: int** - Number of threads you want to run for collecting images. We recommend using a moderate number (3-5) to avoid overwhelming the API server.
66
+
67
+ Example run:
68
+ ```python
69
+ import bplusplus
70
+
71
+ species_list=[ "Vanessa atalanta", "Gonepteryx rhamni", "Bombus hortorum"]
72
+ # convert to dict
73
+ search: dict[str, any] = {
74
+ "scientificName": species_list
75
+ }
76
+
77
+ images_per_group=20
78
+ output_directory="/dataset/selected-species"
79
+ num_threads=2
80
+
81
+
82
+ # Collect data from GBIF
83
+ bplusplus.collect(
84
+ search_parameters=search,
85
+ images_per_group=images_per_group,
86
+ output_directory=output_directory,
87
+ group_by_key=bplusplus.Group.scientificName,
88
+ num_threads=num_threads
89
+ )
90
+ ```
91
+
92
+ ### bplusplus.prepare() (All pipelines)
93
+
94
+ Prepares the dataset for training by performing the following steps:
95
+ 1. Copies images from the input directory to a temporary directory.
96
+ 2. Deletes corrupted images.
97
+ 3. Downloads YOLOv5 weights for *insect detection* if not already present.
98
+ 4. Runs YOLOv5 inference to generate labels for the images.
99
+ 5. Deletes orphaned images and inferences.
100
+ 6. Updates labels based on class mapping.
101
+ 7. Splits the data into train, test, and validation sets.
102
+ 8. Counts the total number of images across all splits.
103
+ 9. Makes a YAML configuration file for YOLOv8.
104
+
105
+ This function takes six arguments:
106
+ - **input_directory: str** - The path to the input directory containing the images.
107
+ - **output_directory: str** - The path to the output directory where the prepared dataset will be saved.
108
+ - **with_background: bool = False** - Set to False if you don't want to include/download background images
109
+ - **one_stage: bool = False** - Set to True if you want to train a one stage model
110
+ - **size_filter: bool = False** - Set to True if you want to filter by size of insect
111
+ - **sizes: list = None** - List of sizes to filter by. If None, all sizes will be used, ["large", "medium", "small"].
112
+
113
+ ```python
114
+ # Prepare data
115
+ bplusplus.prepare(
116
+ input_directory='/dataset/selected-species',
117
+ output_directory='/dataset/prepared-data',
118
+ with_background=False,
119
+ one_stage=False,
120
+ size_filter=True,
121
+ sizes=["large"]
122
+ )
123
+ ```
124
+
125
+ ### bplusplus.train() (One stage pipeline)
126
+
127
+ This function takes five arguments:
128
+ - **input_yaml: str** - yaml file created to train the model
129
+ - **output_directory: str**
130
+ - **epochs: int = 30** - Number of epochs to train the model
131
+ - **imgsz: int = 640** - Image size
132
+ - **batch: int = 16** - Batch size for training
133
+
134
+ ```python
135
+ # Train model
136
+ model = bplusplus.train(
137
+ input_yaml="/dataset/prepared-data/dataset.yaml", # Make sure to add the correct path
138
+ output_directory="trained-model",
139
+ epochs=30,
140
+ batch=16
141
+ )
142
+ ```
143
+
144
+ ### bplusplus.train_resnet() (Two stage (standard resnet) pipeline)
145
+
146
+ This function takes eight arguments:
147
+ - **species_list: list** - List of species to train the model on
148
+ - **model_type: str** - The type of resnet model to train. Options are "resnet50", "resnet152"
149
+ - **batch_size: int** - The batch size for training
150
+ - **num_epochs: int** - The number of epochs to train the model
151
+ - **patience: int** - The number of epochs to wait before early stopping
152
+ - **output_dir: str** - The path to the output directory where the trained model will be saved
153
+ - **data_dir: str** - The path to the directory containing the prepared data
154
+ - **img_size: int** - The size of the images to train the model on
155
+
156
+ ```python
157
+ # Train resnet model
158
+ bplusplus.train_resnet(
159
+ species_list=["Vanessa atalanta", "Gonepteryx rhamni", "Bombus hortorum"],
160
+ model_type="resnet50",
161
+ batch_size=16,
162
+ num_epochs=30,
163
+ patience=5,
164
+ output_dir="trained-model",
165
+ data_dir="prepared-data",
166
+ img_size=256
167
+ )
168
+ ```
169
+
170
+ ### bplusplus.train_multitask() (Two stage (multitask resnet) pipeline)
171
+
172
+ This function takes seven arguments:
173
+ - **batch_size: int** - The batch size for training
174
+ - **epochs: int** - The number of epochs to train the model
175
+ - **patience: int** - The number of epochs to wait before early stopping
176
+ - **img_size: int** - The size of the images to train the model on
177
+ - **data_dir: str** - The path to the directory containing the prepared data
178
+ - **output_dir: str** - The path to the output directory where the trained model will be saved
179
+ - **species_list: list** - List of species to train the model on
180
+
181
+ ```python
182
+ # Train multitask model
183
+ bplusplus.train_multitask(
184
+ batch_size=16,
185
+ epochs=30,
186
+ patience=5,
187
+ img_size=256,
188
+ data_dir="prepared-data",
189
+ output_dir="trained-model",
190
+ species_list=["Vanessa atalanta", "Gonepteryx rhamni", "Bombus hortorum"]
191
+ )
192
+ ```
193
+
194
+
195
+ ### bplusplus.validate() (One stage pipeline)
196
+
197
+ This function takes two arguments:
198
+ - **model** - The trained YOLO model
199
+ - **Path to yaml file**
200
+
201
+ ```python
202
+ metrics = bplusplus.validate(model, '/dataset/prepared-data/dataset.yaml')
203
+ print(metrics)
204
+ ```
205
+
206
+ ### bplusplus.test_resnet() (Two stage (standard resnet) pipeline)
207
+
208
+ This function takes six arguments:
209
+ - **data_path: str** - The path to the directory containing the test data
210
+ - **yolo_weights: str** - The path to the YOLO weights
211
+ - **resnet_weights: str** - The path to the resnet weights
212
+ - **model: str** - The type of resnet model to use
213
+ - **species_names: list** - The list of species names
214
+ - **output_dir: str** - The path to the output directory where the test results will be saved
215
+
216
+ ```python
217
+
218
+ bplusplus.test_resnet(
219
+ data_path=TEST_DATA_DIR,
220
+ yolo_weights=YOLO_WEIGHTS,
221
+ resnet_weights=RESNET_WEIGHTS,
222
+ model="resnet50",
223
+ species_names=species_list,
224
+ output_dir=TRAINED_MODEL_DIR
225
+ )
226
+ ```
227
+
228
+ ### bplusplus.test_multitask() (Two stage (multitask resnet) pipeline)
229
+
230
+ This function takes five arguments:
231
+ - **species_list: list** - List of species to test the model on
232
+ - **test_set: str** - The path to the directory containing the test data
233
+ - **yolo_weights: str** - The path to the YOLO weights
234
+ - **hierarchical_weights: str** - The path to the hierarchical weights
235
+ - **output_dir: str** - The path to the output directory where the test results will be saved
236
+
237
+
238
+ ```python
239
+ bplusplus.test_multitask(
240
+ species_list,
241
+ test_set=TEST_DATA_DIR,
242
+ yolo_weights=YOLO_WEIGHTS,
243
+ hierarchical_weights=RESNET_MULTITASK_WEIGHTS,
244
+ output_dir=TRAINED_MODEL_DIR
245
+ )
246
+ ```
247
+ # Citation
248
+
249
+ All information in this GitHub is available under MIT license, as long as credit is given to the authors.
250
+
251
+ **Venverloo, T., Duarte, F., B++: Towards Real-Time Monitoring of Insect Species. MIT Senseable City Laboratory, AMS Institute.**
252
+
@@ -0,0 +1,12 @@
1
+ bplusplus/__init__.py,sha256=JEFlFeUBotbHCrRELKl2k8cjGMuQDqVLncUBeO_7bv0,279
2
+ bplusplus/collect.py,sha256=w4G78oQY0vT0yBILNas2rdCaeVbtDSLJpNP2Sqqq_7Q,6960
3
+ bplusplus/hierarchical/test.py,sha256=l6j48NFD_M7gX-8Vx7EqUi6lbhYOP-5pYI0005mJO-k,29746
4
+ bplusplus/hierarchical/train.py,sha256=eAMInSYhk0zwgXRRO_r8awIX-fSsrOv-LV9NTCgMrOM,28115
5
+ bplusplus/prepare.py,sha256=CWAAnbgsUnC5oMTFplBieojr3AZA8LYcLPYJDrRx5OI,29459
6
+ bplusplus/resnet/test.py,sha256=jS1fWSlK-5uBkkRPDFrvhQbAVhDd1lSbO1rmaJsBg1M,21589
7
+ bplusplus/resnet/train.py,sha256=0_GOc2qRAkWbvOaDIhU4OQxp2Cc7m2wbriJccDS0VEk,13926
8
+ bplusplus/train_validate.py,sha256=uqWPXyknoAYXkFIg_YOynd1UnBraibI1fliFOK5vWwE,533
9
+ bplusplus-1.2.1.dist-info/LICENSE,sha256=rRkeHptDnlmviR0_WWgNT9t696eys_cjfVUU8FEO4k4,1071
10
+ bplusplus-1.2.1.dist-info/METADATA,sha256=yp-3t92xNILaoi1h3zuwokrYu2ttsH-mgUT1VCzss8w,9960
11
+ bplusplus-1.2.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
12
+ bplusplus-1.2.1.dist-info/RECORD,,
@@ -1 +0,0 @@
1
- from .detect import run