megadetector 10.0.2__py3-none-any.whl → 10.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic. Click here for more details.
- megadetector/data_management/animl_to_md.py +158 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/process_video.py +165 -946
- megadetector/detection/pytorch_detector.py +575 -276
- megadetector/detection/run_detector_batch.py +629 -202
- megadetector/detection/run_md_and_speciesnet.py +1319 -0
- megadetector/detection/video_utils.py +243 -107
- megadetector/postprocessing/classification_postprocessing.py +12 -1
- megadetector/postprocessing/combine_batch_outputs.py +2 -0
- megadetector/postprocessing/compare_batch_results.py +21 -2
- megadetector/postprocessing/merge_detections.py +16 -12
- megadetector/postprocessing/separate_detections_into_folders.py +1 -1
- megadetector/postprocessing/subset_json_detector_output.py +1 -3
- megadetector/postprocessing/validate_batch_results.py +25 -2
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/ct_utils.py +69 -5
- megadetector/utils/extract_frames_from_video.py +303 -0
- megadetector/utils/md_tests.py +583 -524
- megadetector/utils/path_utils.py +4 -15
- megadetector/utils/wi_utils.py +20 -4
- megadetector/visualization/visualization_utils.py +1 -1
- megadetector/visualization/visualize_db.py +8 -22
- megadetector/visualization/visualize_detector_output.py +7 -5
- megadetector/visualization/visualize_video_output.py +607 -0
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/METADATA +134 -135
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/RECORD +30 -23
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/licenses/LICENSE +0 -0
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/top_level.txt +0 -0
- {megadetector-10.0.2.dist-info → megadetector-10.0.4.dist-info}/WHEEL +0 -0
|
@@ -186,7 +186,8 @@ def merge_detections(source_files,target_file,output_file,options=None):
|
|
|
186
186
|
|
|
187
187
|
image_filename = source_im['file']
|
|
188
188
|
|
|
189
|
-
assert image_filename in fn_to_image,
|
|
189
|
+
assert image_filename in fn_to_image, \
|
|
190
|
+
'Image {} not in target image set'.format(image_filename)
|
|
190
191
|
target_im = fn_to_image[image_filename]
|
|
191
192
|
|
|
192
193
|
if 'detections' not in source_im or source_im['detections'] is None:
|
|
@@ -294,10 +295,15 @@ def merge_detections(source_files,target_file,output_file,options=None):
|
|
|
294
295
|
|
|
295
296
|
print('Saved merged results to {}'.format(output_file))
|
|
296
297
|
|
|
298
|
+
# ...def merge_detections(...)
|
|
299
|
+
|
|
297
300
|
|
|
298
301
|
#%% Command-line driver
|
|
299
302
|
|
|
300
|
-
def main():
|
|
303
|
+
def main():
|
|
304
|
+
"""
|
|
305
|
+
Command-line driver for merge_detections.py
|
|
306
|
+
"""
|
|
301
307
|
|
|
302
308
|
default_options = MergeDetectionsOptions()
|
|
303
309
|
|
|
@@ -305,7 +311,7 @@ def main(): # noqa
|
|
|
305
311
|
description='Merge detections from one or more MegaDetector results files into an existing reuslts file')
|
|
306
312
|
parser.add_argument(
|
|
307
313
|
'source_files',
|
|
308
|
-
nargs=
|
|
314
|
+
nargs='+',
|
|
309
315
|
help='Path to source .json file(s) to merge from')
|
|
310
316
|
parser.add_argument(
|
|
311
317
|
'target_file',
|
|
@@ -325,11 +331,11 @@ def main(): # noqa
|
|
|
325
331
|
default=default_options.min_detection_size,
|
|
326
332
|
type=float,
|
|
327
333
|
help='Ignore detections with an area smaller than this (as a fraction of ' + \
|
|
328
|
-
|
|
329
|
-
|
|
334
|
+
'image size) (default {})'.format(
|
|
335
|
+
default_options.min_detection_size))
|
|
330
336
|
parser.add_argument(
|
|
331
337
|
'--source_confidence_thresholds',
|
|
332
|
-
nargs=
|
|
338
|
+
nargs='+',
|
|
333
339
|
type=float,
|
|
334
340
|
default=default_options.source_confidence_thresholds,
|
|
335
341
|
help='List of thresholds for each source file (default {}). '.format(
|
|
@@ -340,19 +346,18 @@ def main(): # noqa
|
|
|
340
346
|
'--target_confidence_threshold',
|
|
341
347
|
type=float,
|
|
342
348
|
default=default_options.target_confidence_threshold,
|
|
343
|
-
help=
|
|
344
|
-
|
|
345
|
-
default_options.target_confidence_threshold))
|
|
349
|
+
help="Do not merge if target file detection confidence is already higher " + \
|
|
350
|
+
"than this (default {})".format(default_options.target_confidence_threshold))
|
|
346
351
|
parser.add_argument(
|
|
347
352
|
'--categories_to_include',
|
|
348
353
|
type=int,
|
|
349
|
-
nargs=
|
|
354
|
+
nargs='+',
|
|
350
355
|
default=None,
|
|
351
356
|
help='List of numeric detection category IDs to include')
|
|
352
357
|
parser.add_argument(
|
|
353
358
|
'--categories_to_exclude',
|
|
354
359
|
type=int,
|
|
355
|
-
nargs=
|
|
360
|
+
nargs='+',
|
|
356
361
|
default=None,
|
|
357
362
|
help='List of numeric detection categories to include')
|
|
358
363
|
parser.add_argument(
|
|
@@ -387,4 +392,3 @@ def main(): # noqa
|
|
|
387
392
|
|
|
388
393
|
if __name__ == '__main__':
|
|
389
394
|
main()
|
|
390
|
-
|
|
@@ -587,7 +587,7 @@ def separate_detections_into_folders(options):
|
|
|
587
587
|
# token = tokens[0]
|
|
588
588
|
for token in tokens:
|
|
589
589
|
subtokens = token.split('=')
|
|
590
|
-
assert len(subtokens) == 2 and is_float(subtokens[1]), \
|
|
590
|
+
assert (len(subtokens) == 2) and (is_float(subtokens[1])), \
|
|
591
591
|
'Illegal classification threshold {}'.format(token)
|
|
592
592
|
classification_thresholds[subtokens[0]] = float(subtokens[1])
|
|
593
593
|
|
|
@@ -221,7 +221,7 @@ def remove_classification_categories_below_count(data, options):
|
|
|
221
221
|
classification_category_ids_to_keep = set()
|
|
222
222
|
|
|
223
223
|
for classification_category_id in classification_category_id_to_count:
|
|
224
|
-
if classification_category_id_to_count[classification_category_id]
|
|
224
|
+
if classification_category_id_to_count[classification_category_id] >= \
|
|
225
225
|
options.remove_classification_categories_below_count:
|
|
226
226
|
classification_category_ids_to_keep.add(classification_category_id)
|
|
227
227
|
|
|
@@ -235,7 +235,6 @@ def remove_classification_categories_below_count(data, options):
|
|
|
235
235
|
if n_categories_removed == 0:
|
|
236
236
|
return data
|
|
237
237
|
|
|
238
|
-
|
|
239
238
|
# Filter the category list
|
|
240
239
|
output_classification_categories = {}
|
|
241
240
|
for category_id in data['classification_categories']:
|
|
@@ -245,7 +244,6 @@ def remove_classification_categories_below_count(data, options):
|
|
|
245
244
|
data['classification_categories'] = output_classification_categories
|
|
246
245
|
assert len(data['classification_categories']) == len(classification_category_ids_to_keep)
|
|
247
246
|
|
|
248
|
-
|
|
249
247
|
# If necessary, filter the category descriptions
|
|
250
248
|
if 'classification_category_descriptions' in data:
|
|
251
249
|
output_classification_category_descriptions = {}
|
|
@@ -172,7 +172,9 @@ def validate_batch_results(json_filename,options=None):
|
|
|
172
172
|
file = im['file']
|
|
173
173
|
|
|
174
174
|
if 'detections' in im and im['detections'] is not None:
|
|
175
|
+
|
|
175
176
|
for det in im['detections']:
|
|
177
|
+
|
|
176
178
|
assert 'category' in det, 'Image {} has a detection with no category'.format(file)
|
|
177
179
|
assert 'conf' in det, 'Image {} has a detection with no confidence'.format(file)
|
|
178
180
|
assert isinstance(det['conf'],float), \
|
|
@@ -182,6 +184,21 @@ def validate_batch_results(json_filename,options=None):
|
|
|
182
184
|
'Image {} has a detection with an unmapped category {}'.format(
|
|
183
185
|
file,det['category'])
|
|
184
186
|
|
|
187
|
+
if 'classifications' in det and det['classifications'] is not None:
|
|
188
|
+
for c in det['classifications']:
|
|
189
|
+
assert isinstance(c[0],str), \
|
|
190
|
+
'Image {} has an illegal classification category: {}'.format(file,c[0])
|
|
191
|
+
try:
|
|
192
|
+
_ = int(c[0])
|
|
193
|
+
except Exception:
|
|
194
|
+
raise ValueError('Image {} has an illegal classification category: {}'.format(
|
|
195
|
+
file,c[0]))
|
|
196
|
+
assert isinstance(c[1],float) or isinstance(c[1], int)
|
|
197
|
+
|
|
198
|
+
# ...for each detection
|
|
199
|
+
|
|
200
|
+
# ...if this image has a detections field
|
|
201
|
+
|
|
185
202
|
if options.check_image_existence:
|
|
186
203
|
|
|
187
204
|
if options.relative_path_base is None:
|
|
@@ -207,13 +224,19 @@ def validate_batch_results(json_filename,options=None):
|
|
|
207
224
|
if not isinstance(im['detections'],list):
|
|
208
225
|
raise ValueError('Invalid detections list for image {}'.format(im['file']))
|
|
209
226
|
|
|
227
|
+
if is_video_file(im['file']) and (format_version >= 1.5):
|
|
228
|
+
|
|
229
|
+
if 'frames_processed' not in im:
|
|
230
|
+
raise ValueError('Video without frames_processed field: {}'.format(im['file']))
|
|
231
|
+
|
|
210
232
|
if is_video_file(im['file']) and (format_version >= 1.4):
|
|
211
233
|
|
|
212
234
|
if 'frame_rate' not in im:
|
|
213
235
|
raise ValueError('Video without frame rate: {}'.format(im['file']))
|
|
214
236
|
if im['frame_rate'] < 0:
|
|
215
|
-
|
|
216
|
-
|
|
237
|
+
if 'failure' not in im:
|
|
238
|
+
raise ValueError('Video with illegal frame rate {}: {}'.format(
|
|
239
|
+
str(im['frame_rate']),im['file']))
|
|
217
240
|
if 'detections' in im and im['detections'] is not None:
|
|
218
241
|
for det in im['detections']:
|
|
219
242
|
if 'frame_number' not in det:
|
|
File without changes
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
Test script for validating NMS functionality with synthetic data.
|
|
4
|
+
|
|
5
|
+
This script creates synthetic detection scenarios where we know exactly which
|
|
6
|
+
boxes should be suppressed by NMS, allowing us to verify the correctness of
|
|
7
|
+
the NMS implementation.
|
|
8
|
+
|
|
9
|
+
This is an AI-generated test module.
|
|
10
|
+
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
#%% Imports
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from megadetector.detection.pytorch_detector import nms
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
#%% Support functions
|
|
22
|
+
|
|
23
|
+
def calculate_iou_boxes(box1, box2):
    """
    Compute the intersection-over-union of two axis-aligned boxes.

    Both boxes use corner format [x1, y1, x2, y2]. Python lists are converted to
    float tensors first; anything else (e.g. a tensor) is indexed as-is.

    Args:
        box1: torch.Tensor or list of [x1, y1, x2, y2]
        box2: torch.Tensor or list of [x1, y1, x2, y2]

    Returns:
        float: IoU value between 0 and 1; 0.0 when the boxes do not overlap or
        the union area is not positive
    """

    a = torch.tensor(box1, dtype=torch.float) if isinstance(box1, list) else box1
    b = torch.tensor(box2, dtype=torch.float) if isinstance(box2, list) else box2

    # Corners of the overlap rectangle
    left, top = max(a[0], b[0]), max(a[1], b[1])
    right, bottom = min(a[2], b[2]), min(a[3], b[3])

    # Empty (or degenerate, edge-touching) overlap
    if right <= left or bottom <= top:
        return 0.0

    overlap_area = (right - left) * (bottom - top)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union_area = area_a + area_b - overlap_area

    if union_area > 0:
        return float(overlap_area / union_area)
    return 0.0
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def create_synthetic_predictions():
|
|
60
|
+
"""
|
|
61
|
+
Create synthetic model predictions for testing NMS.
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
torch.Tensor: Synthetic predictions in the format expected by the NMS function
|
|
65
|
+
Shape: [batch_size=1, num_anchors, num_classes + 5]
|
|
66
|
+
|
|
67
|
+
Test scenarios:
|
|
68
|
+
1. Two highly overlapping boxes (IoU > 0.5) with different confidences - higher confidence should win
|
|
69
|
+
2. Two boxes with low overlap (IoU < 0.5) - both should be kept
|
|
70
|
+
3. Multiple boxes of different classes in same location - should be kept (class-independent NMS)
|
|
71
|
+
4. Three overlapping boxes with cascading confidences - highest confidence should win
|
|
72
|
+
"""
|
|
73
|
+
|
|
74
|
+
# We'll create predictions for a 640x640 image with 3 classes
|
|
75
|
+
# Format: [x_center, y_center, width, height, objectness, class0_conf, class1_conf, class2_conf]
|
|
76
|
+
|
|
77
|
+
synthetic_boxes = []
|
|
78
|
+
|
|
79
|
+
# Scenario 1: Two highly overlapping boxes (IoU > 0.8)
|
|
80
|
+
# Box A: center=(100, 100), size=80x80, high confidence for class 0
|
|
81
|
+
# Box B: center=(105, 105), size=80x80, low confidence for class 0 (smaller offset = higher IoU)
|
|
82
|
+
# Expected: Box A kept, Box B suppressed
|
|
83
|
+
synthetic_boxes.append([100, 100, 80, 80, 0.9, 0.8, 0.1, 0.1]) # Box A - should be kept
|
|
84
|
+
synthetic_boxes.append([105, 105, 80, 80, 0.9, 0.5, 0.1, 0.1]) # Box B - should be suppressed
|
|
85
|
+
|
|
86
|
+
# Scenario 1b: Two nearly identical boxes (IoU ≈ 0.95)
|
|
87
|
+
# Box A2: center=(200, 100), size=60x60, high confidence for class 0
|
|
88
|
+
# Box B2: center=(202, 102), size=60x60, lower confidence for class 0
|
|
89
|
+
# Expected: Box A2 kept, Box B2 suppressed
|
|
90
|
+
synthetic_boxes.append([200, 100, 60, 60, 0.9, 0.9, 0.05, 0.05]) # Box A2 - should be kept
|
|
91
|
+
synthetic_boxes.append([202, 102, 60, 60, 0.9, 0.7, 0.1, 0.1]) # Box B2 - should be suppressed
|
|
92
|
+
|
|
93
|
+
# Scenario 2: Two boxes with low overlap (IoU ≈ 0.1)
|
|
94
|
+
# Box C: center=(300, 100), size=60x60, class 0
|
|
95
|
+
# Box D: center=(380, 100), size=60x60, class 0
|
|
96
|
+
# Expected: Both kept
|
|
97
|
+
synthetic_boxes.append([300, 100, 60, 60, 0.9, 0.7, 0.1, 0.1]) # Box C - should be kept
|
|
98
|
+
synthetic_boxes.append([380, 100, 60, 60, 0.9, 0.6, 0.1, 0.1]) # Box D - should be kept
|
|
99
|
+
|
|
100
|
+
# Scenario 3: Same location, different classes
|
|
101
|
+
# Box E: center=(100, 300), size=70x70, class 0
|
|
102
|
+
# Box F: center=(100, 300), size=70x70, class 1
|
|
103
|
+
# Expected: Both kept (class-independent NMS)
|
|
104
|
+
synthetic_boxes.append([100, 300, 70, 70, 0.9, 0.7, 0.1, 0.1]) # Box E - class 0, should be kept
|
|
105
|
+
synthetic_boxes.append([100, 300, 70, 70, 0.9, 0.1, 0.7, 0.1]) # Box F - class 1, should be kept
|
|
106
|
+
|
|
107
|
+
# Scenario 4: Three cascading overlapping boxes
|
|
108
|
+
# Box G: center=(500, 300), size=80x80, highest confidence
|
|
109
|
+
# Box H: center=(510, 310), size=80x80, medium confidence
|
|
110
|
+
# Box I: center=(520, 320), size=80x80, lowest confidence
|
|
111
|
+
# Expected: Only Box G kept
|
|
112
|
+
synthetic_boxes.append([500, 300, 80, 80, 0.95, 0.9, 0.05, 0.05]) # Box G - highest conf, should be kept
|
|
113
|
+
synthetic_boxes.append([510, 310, 80, 80, 0.9, 0.7, 0.1, 0.1]) # Box H - should be suppressed
|
|
114
|
+
synthetic_boxes.append([520, 320, 80, 80, 0.85, 0.6, 0.15, 0.15]) # Box I - should be suppressed
|
|
115
|
+
|
|
116
|
+
# Add some low-confidence boxes that should be filtered out before NMS
|
|
117
|
+
synthetic_boxes.append([200, 500, 50, 50, 0.1, 0.05, 0.02, 0.03]) # Too low confidence
|
|
118
|
+
|
|
119
|
+
# Convert to tensor format expected by NMS function
|
|
120
|
+
# We need to pad to a reasonable number of anchors (let's use 20)
|
|
121
|
+
num_anchors = 20
|
|
122
|
+
num_classes = 3
|
|
123
|
+
|
|
124
|
+
predictions = torch.zeros(1, num_anchors, num_classes + 5) # batch_size=1
|
|
125
|
+
|
|
126
|
+
# Fill in our synthetic boxes
|
|
127
|
+
for i, box_data in enumerate(synthetic_boxes):
|
|
128
|
+
if i < num_anchors:
|
|
129
|
+
predictions[0, i, :] = torch.tensor(box_data)
|
|
130
|
+
|
|
131
|
+
return predictions
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
#%% Main test function
|
|
135
|
+
|
|
136
|
+
def _detections_in_region(detections, x_range, y_range, class_id=None):
    """
    Find the detections whose box centers fall inside a rectangular region.

    Args:
        detections (torch.Tensor): NMS output for one image; each row is unpacked as
            [x1, y1, x2, y2, confidence, class_id]
        x_range (tuple): inclusive (min, max) bounds on the box center's x coordinate
        y_range (tuple): inclusive (min, max) bounds on the box center's y coordinate
        class_id (int, optional): if not None, only return detections of this class

    Returns:
        list: one (row_index, center_x, center_y, confidence, class_id) tuple per match
    """

    matches = []
    for i_det, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        in_region = (x_range[0] <= center_x <= x_range[1]) and \
                    (y_range[0] <= center_y <= y_range[1])
        if in_region and ((class_id is None) or (int(cls) == class_id)):
            matches.append((i_det, center_x, center_y, conf, int(cls)))
    return matches


def test_nms_functionality():
    """
    Test the NMS function with synthetic data to verify correct behavior.

    Runs nms() on the boxes from create_synthetic_predictions(), then checks each
    scenario by collecting the detections whose centers fall in that scenario's
    region. Finally runs a small separate case with two identical boxes.

    Raises:
        AssertionError: if any scenario produces an unexpected set of detections
    """

    conf_threshold = 0.3
    iou_threshold = 0.5

    print("Testing NMS functionality with synthetic data...")

    # Generate synthetic predictions
    predictions = create_synthetic_predictions()
    print(f"Created synthetic predictions with shape: {predictions.shape}")

    results = nms(predictions, conf_thres=conf_threshold, iou_thres=iou_threshold, max_det=300)

    print(f"NMS returned {len(results)} batch results")
    detections = results[0]  # Get results for first (and only) image
    print(f"Number of detections after NMS: {detections.shape[0]}")

    assert detections.shape[0] != 0, 'NMS returned no detections'

    print("\nDetections after NMS:")
    print("Format: [x1, y1, x2, y2, confidence, class_id]")
    for i_det, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        width = x2 - x1
        height = y2 - y1
        print(f"Detection {i_det}: center=({center_x:.1f}, {center_y:.1f}), "
              f"size={width:.1f}x{height:.1f}, conf={conf:.3f}, class={int(cls)}")

    ## Scenarios 1 and 1b: highly overlapping same-class boxes

    scenario1_boxes = _detections_in_region(detections, (80, 130), (80, 130), class_id=0)
    scenario1b_boxes = _detections_in_region(detections, (180, 220), (80, 120), class_id=0)

    # Check each scenario on its own: checking only the combined total would accept
    # a (2, 0) or (0, 2) split, i.e. a failed suppression in one of the scenarios
    if (len(scenario1_boxes) != 1) or (len(scenario1b_boxes) != 1):
        print("Error: expected 2 detections in high-overlap scenarios (1 each), got {}".format(
            len(scenario1_boxes) + len(scenario1b_boxes)))
        print(f"  Scenario 1: {len(scenario1_boxes)} boxes")
        print(f"  Scenario 1b: {len(scenario1b_boxes)} boxes")
        raise AssertionError('Wrong number of detections in high-overlap scenarios')

    # Should be the high-confidence box (0.8 * 0.9 = 0.72)
    if scenario1_boxes[0][3] < 0.7:
        print("Error: wrong box kept in scenario 1. Expected conf > 0.7, got {}".format(
            scenario1_boxes[0][3]))
        raise AssertionError('Wrong box kept in scenario 1')

    # Should be the high-confidence box (0.9 * 0.9 = 0.81)
    if scenario1b_boxes[0][3] < 0.8:
        print("Error: wrong box kept in scenario 1b. Expected conf > 0.8, got {}".format(
            scenario1b_boxes[0][3]))
        raise AssertionError('Wrong box kept in scenario 1b')

    print("Scenarios 1 & 1b passed: High-confidence boxes kept, low-confidence overlapping boxes suppressed")

    # Confirm the test setup itself is valid: the suppressed pairs must really exceed
    # the IoU threshold (boxes converted from center+size to corners)

    # Scenario 1: Box A (100,100,80x80) vs Box B (105,105,80x80)
    iou_1 = calculate_iou_boxes([100-40, 100-40, 100+40, 100+40],
                                [105-40, 105-40, 105+40, 105+40])

    # Scenario 1b: Box A2 (200,100,60x60) vs Box B2 (202,102,60x60)
    iou_1b = calculate_iou_boxes([200-30, 100-30, 200+30, 100+30],
                                 [202-30, 102-30, 202+30, 102+30])

    print(f"  Theoretical IoU for scenario 1 boxes: {iou_1:.3f}")
    print(f"  Theoretical IoU for scenario 1b boxes: {iou_1b:.3f}")

    if iou_1 <= iou_threshold:
        print(f"Error: scenario 1 IoU {iou_1:.3f} is too low - test setup is invalid!")
        raise AssertionError('Invalid test setup for scenario 1')
    if iou_1b <= iou_threshold:
        print(f"Error: scenario 1b IoU {iou_1b:.3f} is too low - test setup is invalid!")
        raise AssertionError('Invalid test setup for scenario 1b')
    print("  High IoU confirmed - suppression was correct")

    ## Scenario 2: both non-overlapping boxes should be kept

    scenario2_boxes = _detections_in_region(detections, (270, 410), (70, 130), class_id=0)

    if len(scenario2_boxes) != 2:
        print(f"Error: expected 2 detections in scenario 2 area, got {len(scenario2_boxes)}")
        raise AssertionError('Wrong number of detections in scenario 2')
    print("Scenario 2 passed: Both non-overlapping boxes kept")

    ## Scenario 3: same location, different classes; both should be kept

    scenario3_boxes = _detections_in_region(detections, (65, 135), (265, 335))

    classes_found = set(box[4] for box in scenario3_boxes)
    if (len(scenario3_boxes) != 2) or (len(classes_found) != 2):
        print("Error: expected 2 detections of different classes, got {} detections of classes {}".format(
            len(scenario3_boxes), classes_found))
        raise AssertionError('Wrong detections in scenario 3')
    print("Scenario 3 passed: Both different-class boxes kept")

    ## Scenario 4: cascading overlapping boxes

    scenario4_boxes = _detections_in_region(detections, (460, 560), (260, 360), class_id=0)

    print(f"\nScenario 4 analysis: Found {len(scenario4_boxes)} boxes in cascading area:")
    for i_box, (_, cx, cy, conf, _) in enumerate(scenario4_boxes):
        print(f"  Box {i_box}: center=({cx:.1f}, {cy:.1f}), conf={conf:.3f}")

    # The highest-confidence box can never be suppressed, so at least one box must remain
    if len(scenario4_boxes) == 0:
        print("ERROR: Scenario 4 failed - no boxes kept in cascading area")
        raise AssertionError('No detections in scenario 4')

    # Any boxes that survive together must overlap each other by less than the threshold
    max_iou = 0.0
    for i in range(len(scenario4_boxes)):
        for j in range(i + 1, len(scenario4_boxes)):
            det1 = detections[scenario4_boxes[i][0]]
            det2 = detections[scenario4_boxes[j][0]]
            iou = calculate_iou_boxes(det1[:4], det2[:4])
            print(f"  IoU between box {i} and box {j}: {iou:.3f}")
            max_iou = max(max_iou, iou)

    if len(scenario4_boxes) == 1:
        print("Scenario 4 passed: Only highest confidence box kept")
    elif max_iou < iou_threshold:
        print(f"Scenario 4 passed: Multiple boxes kept due to low IoU (< {iou_threshold})")
    else:
        print(f"ERROR: Scenario 4 failed - boxes with IoU {max_iou:.3f} > {iou_threshold} were not suppressed!")
        raise AssertionError('Overlapping boxes not suppressed in scenario 4')

    ## Focused case: two identical boxes

    print("\n=== COMPREHENSIVE IoU VALIDATION TEST ===")

    # Create two identical boxes that should definitely be suppressed
    identical_box_a = [100, 100, 50, 50, 0.9, 0.9, 0.05, 0.05]  # High confidence
    identical_box_b = [100, 100, 50, 50, 0.9, 0.7, 0.1, 0.1]  # Lower confidence

    test_predictions = torch.zeros(1, 5, 8)  # Small batch for focused test
    test_predictions[0, 0, :] = torch.tensor(identical_box_a)
    test_predictions[0, 1, :] = torch.tensor(identical_box_b)

    test_results = nms(test_predictions, conf_thres=conf_threshold, iou_thres=iou_threshold, max_det=300)
    test_detections = test_results[0]

    print(f"Identical boxes test: Input 2 identical boxes, got {test_detections.shape[0]} detections")

    if test_detections.shape[0] != 1:
        print(f"Error: two identical boxes should result in 1 detection, got {test_detections.shape[0]}")
        raise AssertionError('Identical boxes were not merged into one detection')

    # Verify it kept the higher confidence box
    kept_conf = test_detections[0, 4].item()
    expected_conf = 0.9 * 0.9  # objectness * class_conf
    if abs(kept_conf - expected_conf) > 0.01:
        print(f"ERROR: Wrong box kept. Expected conf ≈ {expected_conf:.3f}, got {kept_conf:.3f}")
        raise AssertionError('Wrong box kept in identical-boxes test')
    print("Identical boxes test passed: Higher confidence box kept")

    print("\nNMS tests passed")
|
megadetector/utils/ct_utils.py
CHANGED
|
@@ -16,6 +16,8 @@ import builtins
|
|
|
16
16
|
import datetime
|
|
17
17
|
import tempfile
|
|
18
18
|
import shutil
|
|
19
|
+
import platform
|
|
20
|
+
import sys
|
|
19
21
|
import uuid
|
|
20
22
|
|
|
21
23
|
import jsonpickle
|
|
@@ -983,13 +985,15 @@ def dict_to_kvp_list(d,
|
|
|
983
985
|
return s
|
|
984
986
|
|
|
985
987
|
|
|
986
|
-
def parse_bool_string(s, strict=False):
    """
    Convert the strings "true" or "false" to boolean values. Case-insensitive, discards
    leading and trailing whitespace. If s is already a bool, returns s.

    Args:
        s (str or bool): the string to parse, or the bool to return; non-bool,
            non-str values are converted with str() first
        strict (bool, optional): only allow "true" or "false"; otherwise also
            handles "1"/"0", "yes"/"no", "y"/"n", and "t"/"f".

    Returns:
        bool: the parsed value

    Raises:
        ValueError: if s cannot be interpreted as a boolean
    """

    if isinstance(s,bool):
        return s

    s = str(s).lower().strip()

    # These must be real tuples: a parenthesized string like ('true') is just a str,
    # and "in" would then do substring matching (accepting "", "t", "ru", etc.)
    if strict:
        false_strings = ('false',)
        true_strings = ('true',)
    else:
        false_strings = ('no', 'false', 'f', 'n', '0')
        true_strings = ('yes', 'true', 't', 'y', '1')

    if s in true_strings:
        return True
    elif s in false_strings:
        return False
    else:
        raise ValueError('Cannot parse bool from string {}'.format(str(s)))
|
|
@@ -1044,6 +1055,57 @@ def make_test_folder(subfolder=None):
|
|
|
1044
1055
|
append_guid=True)
|
|
1045
1056
|
|
|
1046
1057
|
|
|
1058
|
+
#%% Environment utilities
|
|
1059
|
+
|
|
1060
|
+
def is_sphinx_build():
    """
    Report whether this code is being loaded as part of our Sphinx documentation build.

    The check is simply whether a ``__sphinx_build__`` attribute exists on the
    builtins module; presumably the docs configuration sets it (TODO confirm in
    the Sphinx conf.py).

    Returns:
        bool: True if builtins has a __sphinx_build__ attribute
    """

    return hasattr(builtins, '__sphinx_build__')
|
|
1070
|
+
|
|
1071
|
+
|
|
1072
|
+
def is_running_in_gha():
    """
    Determine whether we are running on a GitHub Actions runner.

    Checks the GITHUB_ACTIONS environment variable, which GitHub Actions sets to
    "true" on its runners. Values in os.environ are always strings, so a
    case-insensitive comparison against "true" is all that is needed.

    Returns:
        bool: True if we're running in a GHA runner
    """

    return os.environ.get('GITHUB_ACTIONS', '').lower() == 'true'
|
|
1092
|
+
|
|
1093
|
+
|
|
1094
|
+
def environment_is_wsl():
    """
    Determines whether we're running in WSL.

    Only Linux-like platforms are considered. On those, WSL is detected when the
    lowercased text of platform.uname() contains both "microsoft" and "wsl".

    Returns:
        bool: True if we're running in WSL
    """

    # NOTE(review): 'posix' is an os.name value, not a sys.platform value, so only
    # 'linux' can actually match here; kept as-is to preserve behavior.
    if sys.platform in ('linux','posix'):
        uname_text = ' '.join(platform.uname()).lower()
        # NOTE(review): requiring "wsl" may miss WSL1 kernels, whose release string
        # may mention only "Microsoft" - confirm whether WSL1 should count.
        return ('microsoft' in uname_text) and ('wsl' in uname_text)

    return False
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
|
|
1047
1109
|
#%% Tests
|
|
1048
1110
|
|
|
1049
1111
|
def test_write_json():
|
|
@@ -1649,6 +1711,8 @@ def test_string_parsing():
|
|
|
1649
1711
|
assert not parse_bool_string("false")
|
|
1650
1712
|
assert not parse_bool_string("False")
|
|
1651
1713
|
assert not parse_bool_string(" FALSE ")
|
|
1714
|
+
assert parse_bool_string("1", strict=False)
|
|
1715
|
+
assert not parse_bool_string("0", strict=False)
|
|
1652
1716
|
assert parse_bool_string(True) is True # Test with existing bool
|
|
1653
1717
|
assert parse_bool_string(False) is False
|
|
1654
1718
|
try:
|
|
@@ -1657,7 +1721,7 @@ def test_string_parsing():
|
|
|
1657
1721
|
except ValueError:
|
|
1658
1722
|
pass
|
|
1659
1723
|
try:
|
|
1660
|
-
parse_bool_string("1")
|
|
1724
|
+
parse_bool_string("1",strict=True)
|
|
1661
1725
|
raise AssertionError("ValueError not raised for '1'")
|
|
1662
1726
|
except ValueError:
|
|
1663
1727
|
pass
|