bplusplus 2.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,580 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms, models
5
+ from PIL import Image
6
+ import numpy as np
7
+ from collections import defaultdict
8
+ from tabulate import tabulate
9
+ from tqdm import tqdm
10
+ import requests
11
+ import logging
12
+ import sys
13
+ import argparse
14
+
15
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
16
+ logger = logging.getLogger(__name__)
17
+
18
def validate(species_list, validation_dir, hierarchical_weights, img_size=640, batch_size=32, backbone: str = "resnet50"):
    """
    Validate the hierarchical classifier against a directory of images grouped by species.

    Args:
        species_list (list): Species names in the same order used for training.
        validation_dir (str): Directory holding one subdirectory per species.
        hierarchical_weights (str): Path to the saved hierarchical classifier checkpoint.
        img_size (int): Side length images are resized to; should match training (default: 640).
        batch_size (int): Number of images per inference batch (default: 32).
        backbone (str): ResNet backbone ('resnet18', 'resnet50', 'resnet101'). Default: 'resnet50'.

    Returns:
        Dictionary of validation results produced by HierarchicalValidator.run.
    """
    runner = HierarchicalValidator(hierarchical_weights, species_list, img_size, batch_size, backbone)
    results = runner.run(validation_dir)
    print("\nValidation complete with metrics calculated at all taxonomic levels")
    return results
37
+
38
def cuda_cleanup():
    """Release cached CUDA memory and reset peak-memory stats; no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
43
+
44
def setup_gpu():
    """Select a torch device: probe CUDA and fall back to CPU on any failure."""
    if not torch.cuda.is_available():
        logger.warning("CUDA is not available on this system")
        return torch.device("cpu")

    try:
        n_devices = torch.cuda.device_count()
        logger.info(f"Found {n_devices} CUDA device(s)")

        for idx in range(n_devices):
            props = torch.cuda.get_device_properties(idx)
            logger.info(f"GPU {idx}: {props.name} with {props.total_memory / 1e9:.2f} GB memory")

        # Smoke-test the first device with a tiny allocation and op before committing to it.
        chosen = torch.device("cuda:0")
        probe = torch.ones(1, device=chosen)
        doubled = probe * 2
        del probe, doubled

        logger.info("CUDA initialization successful")
        return chosen
    except Exception as exc:
        logger.error(f"CUDA initialization error: {str(exc)}")
        logger.warning("Falling back to CPU")
        return torch.device("cpu")
69
+
70
# Backwards compatibility: newer torch versions restrict what torch.load will
# unpickle in weights-only mode; allow-list the tensor/storage globals that
# older checkpoints reference. ('torch.FloatStorage' was listed twice; the
# duplicate has been removed.)
# NOTE(review): torch.serialization.add_safe_globals normally expects the
# actual classes/callables rather than dotted-name strings — confirm against
# the minimum supported torch version.
if hasattr(torch.serialization, 'add_safe_globals'):
    torch.serialization.add_safe_globals([
        'torch.LongTensor',
        'torch.cuda.LongTensor',
        'torch.FloatStorage',
        'torch.cuda.FloatStorage',
    ])
79
+
80
class HierarchicalInsectClassifier(nn.Module):
    """Multi-head classifier: a shared ResNet trunk with one linear head per taxonomic level."""

    def __init__(self, num_classes_per_level, level_to_idx=None, parent_child_relationship=None, backbone: str = "resnet50"):
        """
        Args:
            num_classes_per_level (list): Number of classes at each taxonomic level.
            level_to_idx (dict): Mapping from level to label indices.
            parent_child_relationship (dict): Parent-child relationships in taxonomy.
            backbone (str): ResNet backbone ('resnet18', 'resnet50', 'resnet101').
        """
        super(HierarchicalInsectClassifier, self).__init__()

        self.backbone = self._build_backbone(backbone)
        self.backbone_name = backbone
        feature_dim = self.backbone.fc.in_features
        # Drop the ImageNet classification head so the trunk emits raw features.
        self.backbone.fc = nn.Identity()

        # One independent classification branch per taxonomic level.
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.Linear(feature_dim, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, n_cls),
            )
            for n_cls in num_classes_per_level
        )

        self.num_levels = len(num_classes_per_level)

        # Taxonomy bookkeeping supplied by the caller.
        self.level_to_idx = level_to_idx
        self.parent_child_relationship = parent_child_relationship

        total_classes = sum(num_classes_per_level)
        self.register_buffer('class_means', torch.zeros(total_classes))
        self.register_buffer('class_stds', torch.ones(total_classes))
        self.class_counts = [0] * total_classes
        self.output_history = defaultdict(list)

    @staticmethod
    def _build_backbone(backbone: str):
        """Instantiate the requested pretrained ResNet trunk or raise ValueError."""
        builders = {
            "resnet18": lambda: models.resnet18(weights=models.ResNet18_Weights.DEFAULT),
            "resnet50": lambda: models.resnet50(weights=models.ResNet50_Weights.DEFAULT),
            "resnet101": lambda: models.resnet101(weights=models.ResNet101_Weights.DEFAULT),
        }
        name = backbone.lower()
        if name not in builders:
            raise ValueError(f"Unsupported backbone '{backbone}'. Choose from 'resnet18', 'resnet50', 'resnet101'.")
        return builders[name]()

    def forward(self, x):
        """Return a list of per-level logit tensors for the input batch ``x``."""
        features = self.backbone(x)
        return [branch(features) for branch in self.branches]
136
+
137
def get_taxonomy(species_list):
    """
    Retrieve taxonomic information for a list of species from the GBIF API.

    Builds a hierarchical taxonomy dictionary with family, genus, and species
    relationships, printing a human-readable summary as it goes. Exits the
    process (sys.exit(1)) on any lookup failure so a misspelled species name
    is caught early rather than producing a broken taxonomy.

    Args:
        species_list (list): Species names; the literal name 'unknown'
            (case-insensitive) is handled locally and never sent to GBIF.

    Returns:
        tuple: (taxonomy, species_to_genus, genus_to_family), where taxonomy
        maps level 1 -> list of families, 2 -> {genus: family},
        3 -> {species: genus}.
    """
    taxonomy = {1: [], 2: {}, 3: {}}
    species_to_genus = {}
    genus_to_family = {}

    # 'unknown' is a synthetic placeholder class, not a real species.
    species_list_for_gbif = [s for s in species_list if s.lower() != 'unknown']
    has_unknown = len(species_list_for_gbif) != len(species_list)

    logger.info(f"Building taxonomy from GBIF for {len(species_list_for_gbif)} species")

    print("\nTaxonomy Results:")
    print("-" * 80)
    print(f"{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
    print("-" * 80)

    for species_name in species_list_for_gbif:
        url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
        try:
            # Timeout added so a stalled GBIF request cannot hang the run forever.
            response = requests.get(url, timeout=30)
            data = response.json()

            if data.get('status') == 'ACCEPTED' or data.get('status') == 'SYNONYM':
                family = data.get('family')
                genus = data.get('genus')

                if family and genus:
                    status = "OK"

                    print(f"{species_name:<30} {family:<20} {genus:<20} {status}")

                    species_to_genus[species_name] = genus
                    genus_to_family[genus] = family

                    if family not in taxonomy[1]:
                        taxonomy[1].append(family)

                    taxonomy[2][genus] = family
                    taxonomy[3][species_name] = genus
                else:
                    error_msg = f"Species '{species_name}' found in GBIF but family and genus not found, could be spelling error in species, check GBIF"
                    logger.error(error_msg)
                    print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                    print(f"Error: {error_msg}")
                    sys.exit(1)
            else:
                error_msg = f"Species '{species_name}' not found in GBIF, could be spelling error, check GBIF"
                logger.error(error_msg)
                print(f"{species_name:<30} {'Not found':<20} {'Not found':<20} ERROR")
                print(f"Error: {error_msg}")
                sys.exit(1)

        # SystemExit from the branches above is a BaseException, so it
        # propagates past this handler; only genuine request/parse errors land here.
        except Exception as e:
            error_msg = f"Error retrieving data for species '{species_name}': {str(e)}"
            logger.error(error_msg)
            print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
            print(f"Error: {error_msg}")
            sys.exit(1)

    # Attach the synthetic 'unknown' placeholder at every level if present.
    if has_unknown:
        unknown_family = "Unknown"
        unknown_genus = "Unknown"
        unknown_species = "unknown"

        if unknown_family not in taxonomy[1]:
            taxonomy[1].append(unknown_family)

        taxonomy[2][unknown_genus] = unknown_family
        taxonomy[3][unknown_species] = unknown_genus

        print(f"{unknown_species:<30} {unknown_family:<20} {unknown_genus:<20} {'OK'}")

    # Deduplicate and sort families so downstream index assignment is deterministic.
    taxonomy[1] = sorted(set(taxonomy[1]))
    print("-" * 80)

    num_families = len(taxonomy[1])
    num_genera = len(taxonomy[2])
    num_species = len(taxonomy[3])

    print("\nFamily indices:")
    for i, family in enumerate(taxonomy[1]):
        print(f" {i}: {family}")

    print("\nGenus indices:")
    for i, genus in enumerate(sorted(taxonomy[2].keys())):
        print(f" {i}: {genus}")

    print("\nSpecies indices:")
    for i, species in enumerate(species_list):
        print(f" {i}: {species}")

    logger.info(f"Taxonomy built: {num_families} families, {num_genera} genera, {num_species} species")
    return taxonomy, species_to_genus, genus_to_family
233
+
234
def create_mappings(taxonomy, species_list=None):
    """Build label<->index lookup tables for every taxonomic level.

    Args:
        taxonomy (dict): Level -> labels; a list for level 1 (families) and a
            dict for levels 2 (genus) and 3 (species).
        species_list (list): When given, fixes the index order for level 3.

    Returns:
        tuple: (level_to_idx, idx_to_level) dictionaries keyed by level.
    """
    level_to_idx = {}
    idx_to_level = {}

    for level, labels in taxonomy.items():
        if isinstance(labels, list):
            # Level 1 (family): keep the list's existing order.
            ordered = labels
        elif level == 3 and species_list is not None:
            # Level 3 (species): index order is dictated by species_list.
            ordered = species_list
        else:
            # Level 2 (genus), or fallback: deterministic alphabetical order.
            ordered = sorted(labels.keys())

        level_to_idx[level] = {label: idx for idx, label in enumerate(ordered)}
        idx_to_level[level] = dict(enumerate(ordered))

    return level_to_idx, idx_to_level
256
+
257
class HierarchicalValidator:
    """Load a trained hierarchical classifier checkpoint and evaluate it on a
    directory of images organized one subdirectory per species, reporting
    precision/recall/F1 at family, genus, and species level."""

    def __init__(self, hierarchical_model_path, species_names, img_size=640, batch_size=32, backbone: str = "resnet50"):
        """Load the checkpoint, reconstruct taxonomy mappings, build the model,
        and move it to the best available device.

        Args:
            hierarchical_model_path: Path to the checkpoint file for torch.load.
            species_names: Species names; order defines species-level indices.
            img_size: Square resize applied to every image (should match training).
            batch_size: Number of images per inference batch.
            backbone: ResNet backbone name; overridden by the checkpoint's own
                'backbone' entry when the checkpoint is in the new format.
        """
        cuda_cleanup()

        self.device = setup_gpu()
        logger.info(f"Using device: {self.device}")
        print(f"Using device: {self.device}")

        self.species_names = species_names
        self.img_size = img_size
        self.batch_size = batch_size
        self.backbone = backbone

        # Load to CPU first; the model is moved to the chosen device later.
        logger.info(f"Loading model from {hierarchical_model_path}")
        try:
            checkpoint = torch.load(hierarchical_model_path, map_location='cpu')
            logger.info("Model loaded to CPU successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

        # Extract taxonomy and model state. Two checkpoint formats are supported:
        # a new format wrapping weights in "model_state_dict" (possibly with
        # taxonomy/mappings saved alongside), and an old format that is the raw
        # state dict itself (taxonomy is then refetched from GBIF).
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
            # Prefer the backbone recorded at training time over the argument.
            checkpoint_backbone = checkpoint.get("backbone", backbone)
            self.backbone = checkpoint_backbone

            if "taxonomy" in checkpoint:
                print("Using taxonomy from saved model")
                taxonomy = checkpoint["taxonomy"]
                if "species_list" in checkpoint:
                    saved_species = checkpoint["species_list"]
                    print(f"Saved model was trained on: {', '.join(saved_species)}")

                # Construct mappings from saved taxonomy
                if "level_to_idx" in checkpoint:
                    level_to_idx = checkpoint["level_to_idx"]
                    # Create idx_to_level by inverting each level's label->idx dict.
                    idx_to_level = {}
                    for level, label_dict in level_to_idx.items():
                        idx_to_level[level] = {idx: label for label, idx in label_dict.items()}
                    species_to_genus = {species: genus for species, genus in taxonomy[3].items()}
                    genus_to_family = {genus: family for genus, family in taxonomy[2].items()}
                else:
                    # Fallback: create mappings from taxonomy
                    print("Warning: No level_to_idx in checkpoint, creating from taxonomy")
                    level_to_idx, idx_to_level = create_mappings(taxonomy, species_names)
                    species_to_genus = {species: genus for species, genus in taxonomy[3].items()}
                    genus_to_family = {genus: family for genus, family in taxonomy[2].items()}
            else:
                # Fetch from GBIF (requires network access).
                print("No taxonomy in checkpoint, fetching from GBIF")
                taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
                level_to_idx, idx_to_level = create_mappings(taxonomy, species_names)
        else:
            # Old format without model_state_dict wrapper
            state_dict = checkpoint
            taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
            level_to_idx, idx_to_level = create_mappings(taxonomy, species_names)
            # keep user-provided backbone when old checkpoint is used

        self.level_to_idx = level_to_idx
        self.idx_to_level = idx_to_level
        self.taxonomy = taxonomy
        self.species_to_genus = species_to_genus
        self.genus_to_family = genus_to_family

        # Get number of classes per level (list lengths for level 1, dict sizes
        # for levels 2 and 3).
        # NOTE(review): if taxonomy were ever not dict-like, num_classes_per_level
        # would be unbound and the print below would raise NameError — confirm
        # that every checkpoint path yields a dict taxonomy.
        if hasattr(taxonomy, "items"):
            num_classes_per_level = [len(classes) if isinstance(classes, list) else len(classes.keys())
            for level, classes in taxonomy.items()]

        print(f"Using model with class counts: {num_classes_per_level}")

        # Initialize and load model. parent_child_relationship is only stored
        # in new-format checkpoints.
        self.model = HierarchicalInsectClassifier(
            num_classes_per_level,
            level_to_idx=level_to_idx,
            parent_child_relationship=checkpoint.get("parent_child_relationship", None) if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint else None,
            backbone=self.backbone
        )

        # Try strict loading first; fall back to strict=False so checkpoints
        # with extra/missing keys still load what they can.
        try:
            self.model.load_state_dict(state_dict)
            print("Model weights loaded successfully")
        except Exception as e:
            print(f"Error loading model weights: {e}")
            print("Attempting to load with strict=False...")
            self.model.load_state_dict(state_dict, strict=False)
            print("Model weights loaded with strict=False")

        try:
            self.model.to(self.device)
            print(f"Model successfully transferred to {self.device}")
        except RuntimeError as e:
            logger.error(f"Error transferring model to {self.device}: {e}")
            print(f"Error transferring model to {self.device}, falling back to CPU")
            # NOTE(review): only self.device is reset here; the model is not
            # explicitly moved back with .to("cpu") — confirm a failed .to()
            # leaves the model fully on CPU.
            self.device = torch.device("cpu")

        self.model.eval()

        # Transform for validation (same as training validation transform):
        # resize, tensorize, then ImageNet mean/std normalization.
        self.transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print("Model successfully loaded")
        print(f"Using image size: {self.img_size}x{self.img_size}")
        print(f"Using species: {', '.join(species_names)}")

    def load_images_from_directory(self, validation_dir):
        """Load all images from subdirectories organized by species.

        Returns:
            tuple: (images, labels, image_paths) where each label is a
            [family_idx, genus_idx, species_idx] triple. Species with a
            missing directory or missing taxonomy entry are skipped with a
            warning; unreadable image files are skipped individually.
        """
        images = []
        labels = []
        image_paths = []

        print(f"\nLoading images from {validation_dir}")

        for species_name in self.species_names:
            species_dir = os.path.join(validation_dir, species_name)

            if not os.path.exists(species_dir):
                logger.warning(f"Directory not found for species: {species_name}")
                continue

            # Species index is the position in the user-supplied species list.
            species_idx = self.species_names.index(species_name)

            # Get taxonomy for this species
            if species_name in self.species_to_genus:
                genus_name = self.species_to_genus[species_name]
                family_name = self.genus_to_family[genus_name]

                genus_idx = self.level_to_idx[2][genus_name]
                family_idx = self.level_to_idx[1][family_name]
            else:
                logger.warning(f"Taxonomy not found for species: {species_name}")
                continue

            # Load all images for this species (by extension; case-insensitive).
            image_files = [f for f in os.listdir(species_dir)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

            for img_file in image_files:
                img_path = os.path.join(species_dir, img_file)
                try:
                    img = Image.open(img_path).convert('RGB')
                    images.append(img)
                    labels.append([family_idx, genus_idx, species_idx])
                    image_paths.append(img_path)
                except Exception as e:
                    logger.warning(f"Error loading image {img_path}: {e}")
                    continue

            print(f"  {species_name}: {len(image_files)} images")

        print(f"\nTotal images loaded: {len(images)}")
        return images, labels, image_paths

    def predict_images(self, images):
        """Run prediction on a list of PIL images.

        Returns:
            list: One [family_idx, genus_idx, species_idx] argmax prediction
            per input image, in input order.
        """
        predictions = []

        # Process in batches to bound memory use.
        for i in tqdm(range(0, len(images), self.batch_size), desc="Processing batches"):
            batch_images = images[i:i + self.batch_size]

            # Transform and batch
            batch_tensors = []
            for img in batch_images:
                tensor = self.transform(img)
                batch_tensors.append(tensor)

            batch_tensor = torch.stack(batch_tensors).to(self.device)

            # Predict (inference only; no gradients needed)
            with torch.no_grad():
                outputs = self.model(batch_tensor)

            # Get the argmax prediction at each taxonomic level for each image.
            for j in range(len(batch_images)):
                pred = []
                for level_output in outputs:
                    level_pred = level_output[j].argmax().item()
                    pred.append(level_pred)
                predictions.append(pred)

        return predictions

    def run(self, validation_dir):
        """Run validation on the dataset.

        Returns:
            dict: {'predictions', 'labels', 'image_paths'} on success, or
            None when the validation directory yields no images.
        """
        # Load images
        images, labels, image_paths = self.load_images_from_directory(validation_dir)

        if len(images) == 0:
            logger.error("No images found in validation directory")
            return None

        # Get predictions
        print("\nRunning predictions...")
        predictions = self.predict_images(images)

        # Calculate and print metrics at every taxonomic level.
        print("\n" + "="*80)
        print("VALIDATION RESULTS")
        print("="*80)
        self.calculate_metrics(predictions, labels)

        return {
            'predictions': predictions,
            'labels': labels,
            'image_paths': image_paths
        }

    def calculate_metrics(self, predictions, labels):
        """Print a classification report at each taxonomic level (family,
        genus, species). predictions/labels are per-image index triples."""
        # Convert to numpy arrays for easier column slicing.
        predictions = np.array(predictions)
        labels = np.array(labels)

        level_names = ['Family', 'Genus', 'Species']

        for level in range(3):
            print(f"\n{'='*80}")
            print(f"{level_names[level]}-level Metrics")
            print(f"{'='*80}")

            # Get predictions and labels for this level (column of the triple).
            level_preds = predictions[:, level]
            level_labels = labels[:, level]

            # Get label names for this level, ordered by index so positions
            # match the integer labels.
            if level == 0:  # Family
                label_names = [self.idx_to_level[1][i] for i in sorted(self.idx_to_level[1].keys())]
            elif level == 1:  # Genus
                label_names = [self.idx_to_level[2][i] for i in sorted(self.idx_to_level[2].keys())]
            else:  # Species
                label_names = self.species_names

            self.print_classification_report(level_preds, level_labels, label_names)

    def print_classification_report(self, predictions, labels, label_names):
        """Print a per-class precision/recall/F1/support table plus macro,
        weighted, and overall-accuracy summary rows (via tabulate)."""
        # Only classes present in the ground-truth labels get a row.
        unique_labels = sorted(set(labels))

        table_data = []
        total_correct = 0
        total_samples = 0

        for label_idx in unique_labels:
            # Get indices for this class
            class_mask = labels == label_idx
            class_preds = predictions[class_mask]
            class_labels = labels[class_mask]

            # Calculate metrics
            support = len(class_labels)
            correct = (class_preds == class_labels).sum()

            # Calculate precision, recall, f1 from tp/fp/fn counts; each ratio
            # is guarded against a zero denominator.
            tp = correct
            fp = ((predictions == label_idx) & (labels != label_idx)).sum()
            fn = support - correct

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            label_name = label_names[label_idx] if label_idx < len(label_names) else f"Label {label_idx}"
            table_data.append([label_name, f"{precision:.4f}", f"{recall:.4f}", f"{f1_score:.4f}", support])

            total_correct += correct
            total_samples += support

        # Calculate overall metrics
        overall_accuracy = total_correct / total_samples if total_samples > 0 else 0

        # Calculate macro and weighted averages by re-parsing the formatted
        # per-class rows (rows are [name, precision, recall, f1, support]).
        macro_precision = np.mean([float(row[1]) for row in table_data])
        macro_recall = np.mean([float(row[2]) for row in table_data])
        macro_f1 = np.mean([float(row[3]) for row in table_data])

        weighted_precision = np.sum([float(row[1]) * row[4] for row in table_data]) / total_samples
        weighted_recall = np.sum([float(row[2]) * row[4] for row in table_data]) / total_samples
        weighted_f1 = np.sum([float(row[3]) * row[4] for row in table_data]) / total_samples

        # Add summary rows (empty lists render as separator rows in tabulate).
        table_data.append([])
        table_data.append(["Macro avg", f"{macro_precision:.4f}", f"{macro_recall:.4f}", f"{macro_f1:.4f}", total_samples])
        table_data.append(["Weighted avg", f"{weighted_precision:.4f}", f"{weighted_recall:.4f}", f"{weighted_f1:.4f}", total_samples])
        table_data.append([])
        table_data.append(["Overall Accuracy", "", "", f"{overall_accuracy:.4f}", total_samples])

        headers = ["Label", "Precision", "Recall", "F1 Score", "Support"]
        print(tabulate(table_data, headers=headers, tablefmt="grid"))
554
+
555
if __name__ == "__main__":
    # Command-line entry point: each spec is (flag, add_argument kwargs),
    # registered in order so --help output is stable.
    parser = argparse.ArgumentParser(description='Validate hierarchical insect classifier')
    arg_specs = [
        ('--validation_dir', dict(type=str, required=True,
                                  help='Path to validation directory with subdirectories for each species')),
        ('--weights', dict(type=str, required=True,
                           help='Path to hierarchical model weights')),
        ('--species', dict(type=str, nargs='+', required=True,
                           help='List of species names (must match training order)')),
        ('--img_size', dict(type=int, default=640,
                            help='Image size for validation (should match training, default: 640)')),
        ('--batch_size', dict(type=int, default=32,
                              help='Batch size for processing (default: 32)')),
        ('--backbone', dict(type=str, default='resnet50', choices=['resnet18', 'resnet50', 'resnet101'],
                            help='ResNet backbone to use (default: resnet50)')),
    ]
    for flag, kwargs in arg_specs:
        parser.add_argument(flag, **kwargs)

    cli = parser.parse_args()

    validate(
        species_list=cli.species,
        validation_dir=cli.validation_dir,
        hierarchical_weights=cli.weights,
        img_size=cli.img_size,
        batch_size=cli.batch_size,
        backbone=cli.backbone
    )
580
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Titus Venverloo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.