bplusplus 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bplusplus might be problematic. Click here for more details.
- bplusplus/__init__.py +13 -5
- bplusplus/inference.py +929 -0
- bplusplus/prepare.py +416 -648
- bplusplus/{hierarchical/test.py → test.py} +32 -9
- bplusplus/tracker.py +261 -0
- bplusplus/{hierarchical/train.py → train.py} +48 -14
- bplusplus-1.2.4.dist-info/METADATA +207 -0
- bplusplus-1.2.4.dist-info/RECORD +11 -0
- {bplusplus-1.2.2.dist-info → bplusplus-1.2.4.dist-info}/WHEEL +1 -1
- bplusplus/resnet/test.py +0 -473
- bplusplus/resnet/train.py +0 -329
- bplusplus/train_validate.py +0 -11
- bplusplus-1.2.2.dist-info/METADATA +0 -260
- bplusplus-1.2.2.dist-info/RECORD +0 -12
- {bplusplus-1.2.2.dist-info → bplusplus-1.2.4.dist-info}/LICENSE +0 -0
|
@@ -23,7 +23,7 @@ import sys
|
|
|
23
23
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
24
24
|
logger = logging.getLogger(__name__)
|
|
25
25
|
|
|
26
|
-
def
|
|
26
|
+
def test(species_list, test_set, yolo_weights, hierarchical_weights, output_dir="."):
|
|
27
27
|
"""
|
|
28
28
|
Run the two-stage classifier on a test set.
|
|
29
29
|
|
|
@@ -74,6 +74,16 @@ def setup_gpu():
|
|
|
74
74
|
logger.warning("Falling back to CPU")
|
|
75
75
|
return torch.device("cpu")
|
|
76
76
|
|
|
77
|
+
# Backwards compatibility: PyTorch versions that provide add_safe_globals use an
# allow-list when torch.load(..., weights_only=True) unpickles a checkpoint.
# The API takes the class objects themselves (a list of callables/classes), not
# dotted-name strings, so the classes are resolved here. getattr(..., None) keeps
# module import working even if a legacy class is absent in this torch build.
if hasattr(torch.serialization, 'add_safe_globals'):
    torch.serialization.add_safe_globals([
        cls
        for cls in (
            getattr(torch, 'LongTensor', None),
            getattr(torch.cuda, 'LongTensor', None),
            getattr(torch, 'FloatStorage', None),
            getattr(torch.cuda, 'FloatStorage', None),
        )
        if cls is not None
    ])
|
|
86
|
+
|
|
77
87
|
class HierarchicalInsectClassifier(nn.Module):
|
|
78
88
|
def __init__(self, num_classes_per_level):
|
|
79
89
|
"""
|
|
@@ -243,8 +253,15 @@ class TestTwoStage:
|
|
|
243
253
|
if "species_list" in checkpoint:
|
|
244
254
|
saved_species = checkpoint["species_list"]
|
|
245
255
|
print(f"Saved model was trained on: {', '.join(saved_species)}")
|
|
246
|
-
|
|
247
|
-
taxonomy
|
|
256
|
+
|
|
257
|
+
# Use saved taxonomy mappings if available
|
|
258
|
+
if "species_to_genus" in checkpoint and "genus_to_family" in checkpoint:
|
|
259
|
+
species_to_genus = checkpoint["species_to_genus"]
|
|
260
|
+
genus_to_family = checkpoint["genus_to_family"]
|
|
261
|
+
else:
|
|
262
|
+
# Fallback: fetch from GBIF but this may cause index mismatches
|
|
263
|
+
print("Warning: No taxonomy mappings in checkpoint, fetching from GBIF")
|
|
264
|
+
_, species_to_genus, genus_to_family = get_taxonomy(species_names)
|
|
248
265
|
else:
|
|
249
266
|
taxonomy, species_to_genus, genus_to_family = get_taxonomy(species_names)
|
|
250
267
|
else:
|
|
@@ -285,8 +302,6 @@ class TestTwoStage:
|
|
|
285
302
|
self.classification_model.eval()
|
|
286
303
|
|
|
287
304
|
self.classification_transform = transforms.Compose([
|
|
288
|
-
transforms.Resize((768, 768)), # Fixed size for all validation images
|
|
289
|
-
transforms.CenterCrop(640),
|
|
290
305
|
transforms.ToTensor(),
|
|
291
306
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
292
307
|
])
|
|
@@ -467,10 +482,18 @@ class TestTwoStage:
|
|
|
467
482
|
predicted_genus_frames, true_genus_frames,
|
|
468
483
|
predicted_family_frames, true_family_frames):
|
|
469
484
|
"""Calculate metrics at all taxonomic levels"""
|
|
470
|
-
# Get list of species, families and genera
|
|
485
|
+
# Get list of species, families and genera using the same order as model training
|
|
471
486
|
species_list = self.species_names
|
|
472
|
-
|
|
473
|
-
|
|
487
|
+
|
|
488
|
+
# Use the index mappings from the model to ensure consistency
|
|
489
|
+
if 1 in self.idx_to_level and 2 in self.idx_to_level:
|
|
490
|
+
family_list = [self.idx_to_level[1][i] for i in sorted(self.idx_to_level[1].keys())]
|
|
491
|
+
genus_list = [self.idx_to_level[2][i] for i in sorted(self.idx_to_level[2].keys())]
|
|
492
|
+
else:
|
|
493
|
+
# Fallback to sorted lists (may cause issues)
|
|
494
|
+
print("Warning: Using fallback sorted lists for taxonomy - this may cause index mismatches")
|
|
495
|
+
genus_list = sorted(list(set(self.species_to_genus.values())))
|
|
496
|
+
family_list = sorted(list(set(self.genus_to_family.values())))
|
|
474
497
|
|
|
475
498
|
# Print the index mappings we're using for evaluation
|
|
476
499
|
print("\nUsing the following index mappings for evaluation:")
|
|
@@ -665,4 +688,4 @@ if __name__ == "__main__":
|
|
|
665
688
|
hierarchical_model_path = "/mnt/nvme0n1p1/mit/two-stage-detection/hierarchical/hierarchical-weights.pth"
|
|
666
689
|
output_directory = "./output"
|
|
667
690
|
|
|
668
|
-
|
|
691
|
+
test(species_names, test_directory, yolo_model_path, hierarchical_model_path, output_directory)
|
bplusplus/tracker.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import uuid
|
|
3
|
+
from scipy.optimize import linear_sum_assignment
|
|
4
|
+
from collections import deque
|
|
5
|
+
|
|
6
|
+
class BoundingBox:
    """Axis-aligned detection box stored as top-left corner plus size.

    Attributes:
        x, y: Top-left corner.
        width, height: Extent of the box.
        area: ``width * height``, computed once at construction.
        frame_id: Frame the box was observed in.
        track_id: Track this box belongs to, or ``None`` until one is assigned.
    """

    def __init__(self, x, y, width, height, frame_id, track_id=None):
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.area = width * height
        self.frame_id = frame_id
        self.track_id = track_id

    def __repr__(self):
        return (f"{type(self).__name__}(x={self.x!r}, y={self.y!r}, "
                f"width={self.width!r}, height={self.height!r}, "
                f"frame_id={self.frame_id!r}, track_id={self.track_id!r})")

    def center(self):
        """Return the ``(cx, cy)`` center point of the box."""
        return (self.x + self.width / 2, self.y + self.height / 2)

    @classmethod
    def from_xyxy(cls, x1, y1, x2, y2, frame_id, track_id=None):
        """Create a BoundingBox from corner coordinates ``(x1, y1, x2, y2)``."""
        width = x2 - x1
        height = y2 - y1
        return cls(x1, y1, width, height, frame_id, track_id)


class InsectTracker:
    """Link per-frame detections into persistent tracks.

    Each new frame's boxes are matched against the previous frame's boxes and
    against recently lost tracks with the Hungarian algorithm
    (``scipy.optimize.linear_sum_assignment``). The cost mixes normalized
    center distance and relative area difference. A match whose cost is not
    below ``cost_threshold`` is rejected, and that detection starts a new track.

    Args:
        image_height, image_width: Frame size (same units as detection
            coordinates). The image diagonal normalizes center distances.
        max_frames: Number of frames kept in ``tracking_history``.
        w_dist, w_area: Weights of the distance and area terms of the cost.
        cost_threshold: Assignments with cost >= this value are rejected.
        track_memory_frames: Number of consecutive frames a track may be
            missing and still be recovered. Defaults to ``max_frames``.
        debug: Print diagnostic messages when True.
    """

    def __init__(self, image_height, image_width, max_frames=30, w_dist=0.7, w_area=0.3, cost_threshold=0.8, track_memory_frames=None, debug=False):
        self.image_height = image_height
        self.image_width = image_width
        self.max_dist = np.sqrt(image_height**2 + image_width**2)
        self.max_frames = max_frames
        self.w_dist = w_dist
        self.w_area = w_area
        self.cost_threshold = cost_threshold
        self.debug = debug

        # If track_memory_frames not specified, use max_frames (full history window)
        self.track_memory_frames = track_memory_frames if track_memory_frames is not None else max_frames
        if self.debug:
            print(f"DEBUG: Tracker initialized with max_frames={max_frames}, track_memory_frames={self.track_memory_frames}")

        self.tracking_history = deque(maxlen=max_frames)
        self.current_tracks = []
        self.lost_tracks = {}  # track_id -> {'box': BoundingBox, 'frames_lost': int}

    def _generate_track_id(self):
        """Generate a unique UUID string for a new track."""
        return str(uuid.uuid4())

    def calculate_cost(self, box1, box2):
        """Return the matching cost between two boxes; lower means a better match.

        cost = w_dist * (center distance / image diagonal)
             + w_area * (1 - smaller_area / larger_area)      (equation 4)
        """
        cx1, cy1 = box1.center()
        cx2, cy2 = box2.center()

        # Euclidean distance between centers (equation 1), normalized by the diagonal
        dist = np.sqrt((cx2 - cx1)**2 + (cy2 - cy1)**2)
        norm_dist = dist / self.max_dist

        # Area similarity ratio (equation 3); two zero-area boxes count as identical
        min_area = min(box1.area, box2.area)
        max_area = max(box1.area, box2.area)
        area_ratio = min_area / max_area if max_area > 0 else 1.0

        return (norm_dist * self.w_dist) + ((1 - area_ratio) * self.w_area)

    def build_cost_matrix(self, prev_boxes, curr_boxes):
        """Build a square cost matrix for the Hungarian algorithm.

        Returns ``(cost_matrix, n_prev, n_curr)``. The matrix is padded to
        ``max(n_prev, n_curr)`` on both sides with a large cost (999.0), so
        padding pairs never win over real ones.
        """
        n_prev = len(prev_boxes)
        n_curr = len(curr_boxes)
        n = max(n_prev, n_curr)

        cost_matrix = np.ones((n, n)) * 999.0
        for i in range(n_prev):
            for j in range(n_curr):
                cost_matrix[i, j] = self.calculate_cost(prev_boxes[i], curr_boxes[j])

        return cost_matrix, n_prev, n_curr

    def _mark_lost(self, track):
        """Put ``track`` into ``lost_tracks``; return True if it was newly added.

        The counter starts at 0 because ``_age_lost_tracks()`` runs later in the
        same ``update()`` call and brings it to 1 for the frame in which the
        track vanished. Starting at 1 double-counted that frame.
        """
        if track.track_id in self.lost_tracks:
            return False
        self.lost_tracks[track.track_id] = {'box': track, 'frames_lost': 0}
        return True

    def update(self, new_detections, frame_id):
        """
        Update tracking with new detections from YOLO.

        Args:
            new_detections: Iterable of detections whose first four values are
                ``x1, y1, x2, y2`` (e.g. a list of tuples or a 2-D NumPy array).
                ``None`` or an empty container means "no detections".
            frame_id: Current frame number

        Returns:
            List of track IDs (UUID strings), one per detection, in input order.
        """
        # Materialize as a list. Truth-testing a NumPy array with several
        # elements raises ValueError, so ``if not new_detections`` is unsafe.
        detections = [] if new_detections is None else list(new_detections)

        # Handle empty detection list (no detections in this frame)
        if not detections:
            if self.debug:
                print(f"DEBUG: Frame {frame_id} has no detections")
            for track in self.current_tracks:
                if self._mark_lost(track) and self.debug:
                    print(f"DEBUG: Moved track {track.track_id} to lost tracks")

            self._age_lost_tracks()
            self.current_tracks = []
            self.tracking_history.append([])
            return []

        # Convert YOLO detections to BoundingBox objects
        new_boxes = []
        for detection in detections:
            x1, y1, x2, y2 = detection[:4]
            new_boxes.append(BoundingBox.from_xyxy(x1, y1, x2, y2, frame_id))

        # Nothing to match against (first frame, or every track has expired):
        # every detection starts a new track.
        if not self.current_tracks and not self.lost_tracks:
            track_ids = []
            for box in new_boxes:
                box.track_id = self._generate_track_id()
                track_ids.append(box.track_id)
                if self.debug:
                    print(f"DEBUG: FIRST FRAME - Assigned track ID {box.track_id} to new detection")
            self.current_tracks = new_boxes
            self.tracking_history.append(new_boxes)
            return track_ids

        # Candidates: boxes from the previous frame plus not-yet-expired lost
        # tracks. This is non-empty here thanks to the early return above.
        all_previous_tracks = list(self.current_tracks)
        for track_id, lost_info in self.lost_tracks.items():
            lost_box = lost_info['box']
            lost_box.track_id = track_id  # Ensure track_id is preserved
            all_previous_tracks.append(lost_box)

        cost_matrix, n_prev, n_curr = self.build_cost_matrix(all_previous_tracks, new_boxes)
        row_indices, col_indices = linear_sum_assignment(cost_matrix)

        assigned_curr_indices = set()
        track_ids = [None] * n_curr
        recovered_tracks = set()  # Track IDs that were recovered from lost tracks

        if self.debug:
            print(f"DEBUG: Hungarian assignment - rows: {row_indices}, cols: {col_indices}")
            print(f"DEBUG: Cost threshold: {self.cost_threshold}")
            print(f"DEBUG: Current tracks: {len(self.current_tracks)}, Lost tracks: {len(self.lost_tracks)}")

        for i, j in zip(row_indices, col_indices):
            # Ignore pairs that involve the padding rows/columns of the matrix
            if i >= n_prev or j >= n_curr:
                continue
            cost = cost_matrix[i, j]
            if self.debug:
                print(f"DEBUG: Checking assignment {i}->{j}, cost: {cost:.3f}")
            if cost >= self.cost_threshold:
                if self.debug:
                    print(f"DEBUG: Cost {cost:.3f} above threshold {self.cost_threshold}, not assigning")
                continue

            # Carry the previous box's track ID over to the current box
            prev_track_id = all_previous_tracks[i].track_id
            new_boxes[j].track_id = prev_track_id
            track_ids[j] = prev_track_id
            assigned_curr_indices.add(j)

            if prev_track_id in self.lost_tracks:
                recovered_tracks.add(prev_track_id)
                if self.debug:
                    print(f"DEBUG: RECOVERED lost track ID {prev_track_id} for detection {j} (was lost for {self.lost_tracks[prev_track_id]['frames_lost']} frames)")
            elif self.debug:
                print(f"DEBUG: Continued track ID {prev_track_id} for detection {j}")

        for track_id in recovered_tracks:
            del self.lost_tracks[track_id]

        # Assign new track IDs to unassigned current boxes (new insects)
        for j in range(n_curr):
            if j not in assigned_curr_indices:
                new_boxes[j].track_id = self._generate_track_id()
                track_ids[j] = new_boxes[j].track_id
                if self.debug:
                    print(f"DEBUG: Assigned NEW track ID {new_boxes[j].track_id} to detection {j}")

        # Move unmatched current tracks to lost tracks (tracks that disappeared this frame).
        # Recovered IDs are a subset of the matched IDs.
        matched_track_ids = {track_ids[j] for j in assigned_curr_indices}
        for track in self.current_tracks:
            if track.track_id in matched_track_ids:
                continue
            if self._mark_lost(track) and self.debug:
                print(f"DEBUG: Track {track.track_id} disappeared, moved to lost tracks")

        self._age_lost_tracks()

        self.current_tracks = new_boxes
        self.tracking_history.append(new_boxes)
        return track_ids

    def _age_lost_tracks(self):
        """Increment every lost track's counter and drop those lost too long.

        A track is removed once its counter exceeds ``track_memory_frames``.
        """
        tracks_to_remove = []
        for track_id, lost_info in self.lost_tracks.items():
            lost_info['frames_lost'] += 1
            if lost_info['frames_lost'] > self.track_memory_frames:
                tracks_to_remove.append(track_id)
                if self.debug:
                    print(f"DEBUG: Permanently removing track {track_id} (lost for {lost_info['frames_lost']} frames)")

        for track_id in tracks_to_remove:
            del self.lost_tracks[track_id]

    def get_tracking_stats(self):
        """Get current tracking statistics for debugging/monitoring."""
        return {
            'active_tracks': len(self.current_tracks),
            'lost_tracks': len(self.lost_tracks),
            'active_track_ids': [track.track_id for track in self.current_tracks],
            'lost_track_ids': list(self.lost_tracks.keys()),
            'total_history_frames': len(self.tracking_history)
        }
|
|
@@ -14,18 +14,28 @@ import logging
|
|
|
14
14
|
from tqdm import tqdm
|
|
15
15
|
import sys
|
|
16
16
|
|
|
17
|
-
def
|
|
17
|
+
def train(batch_size=4, epochs=30, patience=3, img_size=640, data_dir='input', output_dir='./output', species_list=None, num_workers=4):
|
|
18
18
|
"""
|
|
19
19
|
Main function to run the entire training pipeline.
|
|
20
20
|
Sets up datasets, model, training process and handles errors.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
batch_size (int): Number of samples per batch. Default: 4
|
|
24
|
+
epochs (int): Maximum number of training epochs. Default: 30
|
|
25
|
+
patience (int): Early stopping patience (epochs without improvement). Default: 3
|
|
26
|
+
img_size (int): Target image size for training. Default: 640
|
|
27
|
+
data_dir (str): Directory containing train/valid subdirectories. Default: 'input'
|
|
28
|
+
output_dir (str): Directory to save trained model and logs. Default: './output'
|
|
29
|
+
species_list (list): List of species names for training. Required.
|
|
30
|
+
num_workers (int): Number of DataLoader worker processes.
|
|
31
|
+
Set to 0 to disable multiprocessing (most stable). Default: 4
|
|
21
32
|
"""
|
|
22
33
|
global logger, device
|
|
23
34
|
|
|
24
35
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
25
36
|
logger = logging.getLogger(__name__)
|
|
26
37
|
|
|
27
|
-
logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {epochs}, Patience: {patience}, Image size: {img_size}, Data directory: {data_dir}, Output directory: {output_dir}")
|
|
28
|
-
|
|
38
|
+
logger.info(f"Hyperparameters - Batch size: {batch_size}, Epochs: {epochs}, Patience: {patience}, Image size: {img_size}, Data directory: {data_dir}, Output directory: {output_dir}, Num workers: {num_workers}")
|
|
29
39
|
|
|
30
40
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
31
41
|
|
|
@@ -52,7 +62,7 @@ def train_multitask(batch_size=4, epochs=30, patience=3, img_size=640, data_dir=
|
|
|
52
62
|
|
|
53
63
|
taxonomy = get_taxonomy(species_list)
|
|
54
64
|
|
|
55
|
-
level_to_idx, parent_child_relationship = create_mappings(taxonomy)
|
|
65
|
+
level_to_idx, parent_child_relationship = create_mappings(taxonomy, species_list)
|
|
56
66
|
|
|
57
67
|
num_classes_per_level = [len(taxonomy[level]) if isinstance(taxonomy[level], list)
|
|
58
68
|
else len(taxonomy[level].keys()) for level in sorted(taxonomy.keys())]
|
|
@@ -75,14 +85,14 @@ def train_multitask(batch_size=4, epochs=30, patience=3, img_size=640, data_dir=
|
|
|
75
85
|
train_dataset,
|
|
76
86
|
batch_size=batch_size,
|
|
77
87
|
shuffle=True,
|
|
78
|
-
num_workers=
|
|
88
|
+
num_workers=num_workers
|
|
79
89
|
)
|
|
80
90
|
|
|
81
91
|
val_loader = DataLoader(
|
|
82
92
|
val_dataset,
|
|
83
93
|
batch_size=batch_size,
|
|
84
94
|
shuffle=False,
|
|
85
|
-
num_workers=
|
|
95
|
+
num_workers=num_workers
|
|
86
96
|
)
|
|
87
97
|
|
|
88
98
|
try:
|
|
@@ -150,14 +160,17 @@ def get_taxonomy(species_list):
|
|
|
150
160
|
species_to_genus = {}
|
|
151
161
|
genus_to_family = {}
|
|
152
162
|
|
|
153
|
-
|
|
163
|
+
species_list_for_gbif = [s for s in species_list if s.lower() != 'unknown']
|
|
164
|
+
has_unknown = len(species_list_for_gbif) != len(species_list)
|
|
165
|
+
|
|
166
|
+
logger.info(f"Building taxonomy from GBIF for {len(species_list_for_gbif)} species")
|
|
154
167
|
|
|
155
168
|
print("\nTaxonomy Results:")
|
|
156
169
|
print("-" * 80)
|
|
157
170
|
print(f"{'Species':<30} {'Family':<20} {'Genus':<20} {'Status'}")
|
|
158
171
|
print("-" * 80)
|
|
159
172
|
|
|
160
|
-
for species_name in
|
|
173
|
+
for species_name in species_list_for_gbif:
|
|
161
174
|
url = f"https://api.gbif.org/v1/species/match?name={species_name}&verbose=true"
|
|
162
175
|
try:
|
|
163
176
|
response = requests.get(url)
|
|
@@ -199,6 +212,19 @@ def get_taxonomy(species_list):
|
|
|
199
212
|
print(f"{species_name:<30} {'Error':<20} {'Error':<20} FAILED")
|
|
200
213
|
print(f"Error: {error_msg}")
|
|
201
214
|
sys.exit(1) # Stop the script
|
|
215
|
+
|
|
216
|
+
if has_unknown:
|
|
217
|
+
unknown_family = "Unknown"
|
|
218
|
+
unknown_genus = "Unknown"
|
|
219
|
+
unknown_species = "unknown"
|
|
220
|
+
|
|
221
|
+
if unknown_family not in taxonomy[1]:
|
|
222
|
+
taxonomy[1].append(unknown_family)
|
|
223
|
+
|
|
224
|
+
taxonomy[2][unknown_genus] = unknown_family
|
|
225
|
+
taxonomy[3][unknown_species] = unknown_genus
|
|
226
|
+
|
|
227
|
+
print(f"{unknown_species:<30} {unknown_family:<20} {unknown_genus:<20} {'OK'}")
|
|
202
228
|
|
|
203
229
|
taxonomy[1] = sorted(list(set(taxonomy[1])))
|
|
204
230
|
print("-" * 80)
|
|
@@ -212,7 +238,7 @@ def get_taxonomy(species_list):
|
|
|
212
238
|
print(f" {i}: {family}")
|
|
213
239
|
|
|
214
240
|
print("\nGenus indices:")
|
|
215
|
-
for i, genus in enumerate(taxonomy[2].keys()):
|
|
241
|
+
for i, genus in enumerate(sorted(taxonomy[2].keys())):
|
|
216
242
|
print(f" {i}: {genus}")
|
|
217
243
|
|
|
218
244
|
print("\nSpecies indices:")
|
|
@@ -244,7 +270,7 @@ def get_species_from_directory(train_dir):
|
|
|
244
270
|
logger.info(f"Found {len(species_list)} species in {train_dir}")
|
|
245
271
|
return species_list
|
|
246
272
|
|
|
247
|
-
def create_mappings(taxonomy):
|
|
273
|
+
def create_mappings(taxonomy, species_list=None):
|
|
248
274
|
"""
|
|
249
275
|
Creates mapping dictionaries from taxonomy data.
|
|
250
276
|
Returns level-to-index mapping and parent-child relationships between taxonomic levels.
|
|
@@ -254,9 +280,17 @@ def create_mappings(taxonomy):
|
|
|
254
280
|
|
|
255
281
|
for level, labels in taxonomy.items():
|
|
256
282
|
if isinstance(labels, list):
|
|
283
|
+
# Level 1: Family (already sorted)
|
|
257
284
|
level_to_idx[level] = {label: idx for idx, label in enumerate(labels)}
|
|
258
|
-
else:
|
|
259
|
-
|
|
285
|
+
else: # dict for levels 2 and 3
|
|
286
|
+
if level == 3 and species_list is not None:
|
|
287
|
+
# For species, the order is determined by species_list
|
|
288
|
+
level_to_idx[level] = {label: idx for idx, label in enumerate(species_list)}
|
|
289
|
+
else:
|
|
290
|
+
# For genus (and as a fallback for species), sort alphabetically
|
|
291
|
+
sorted_keys = sorted(labels.keys())
|
|
292
|
+
level_to_idx[level] = {label: idx for idx, label in enumerate(sorted_keys)}
|
|
293
|
+
|
|
260
294
|
for child, parent in labels.items():
|
|
261
295
|
if (level, parent) not in parent_child_relationship:
|
|
262
296
|
parent_child_relationship[(level, parent)] = []
|
|
@@ -670,7 +704,7 @@ if __name__ == '__main__':
|
|
|
670
704
|
species_list = [
|
|
671
705
|
"Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius", "Bombus terrestris",
|
|
672
706
|
"Eupeodes corollae", "Episyrphus balteatus", "Aglais urticae", "Vespula vulgaris",
|
|
673
|
-
"Eristalis tenax"
|
|
707
|
+
"Eristalis tenax", "unknown"
|
|
674
708
|
]
|
|
675
|
-
|
|
709
|
+
train(species_list=species_list, epochs=2)
|
|
676
710
|
|