megadetector 10.0.2-py3-none-any.whl → 10.0.3-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of megadetector might be problematic.
- megadetector/detection/process_video.py +120 -913
- megadetector/detection/pytorch_detector.py +572 -263
- megadetector/detection/run_detector_batch.py +525 -143
- megadetector/detection/run_md_and_speciesnet.py +1301 -0
- megadetector/detection/video_utils.py +240 -105
- megadetector/postprocessing/classification_postprocessing.py +12 -1
- megadetector/postprocessing/compare_batch_results.py +21 -2
- megadetector/postprocessing/merge_detections.py +16 -12
- megadetector/postprocessing/validate_batch_results.py +25 -2
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/ct_utils.py +16 -5
- megadetector/utils/extract_frames_from_video.py +303 -0
- megadetector/utils/md_tests.py +578 -520
- megadetector/utils/wi_utils.py +20 -4
- megadetector/visualization/visualize_db.py +8 -22
- megadetector/visualization/visualize_detector_output.py +1 -1
- megadetector/visualization/visualize_video_output.py +607 -0
- {megadetector-10.0.2.dist-info → megadetector-10.0.3.dist-info}/METADATA +134 -135
- {megadetector-10.0.2.dist-info → megadetector-10.0.3.dist-info}/RECORD +23 -18
- {megadetector-10.0.2.dist-info → megadetector-10.0.3.dist-info}/licenses/LICENSE +0 -0
- {megadetector-10.0.2.dist-info → megadetector-10.0.3.dist-info}/top_level.txt +0 -0
- {megadetector-10.0.2.dist-info → megadetector-10.0.3.dist-info}/WHEEL +0 -0

megadetector/postprocessing/merge_detections.py
CHANGED

@@ -186,7 +186,8 @@ def merge_detections(source_files,target_file,output_file,options=None):

         image_filename = source_im['file']

-        assert image_filename in fn_to_image,
+        assert image_filename in fn_to_image, \
+            'Image {} not in target image set'.format(image_filename)
         target_im = fn_to_image[image_filename]

         if 'detections' not in source_im or source_im['detections'] is None:

@@ -294,10 +295,15 @@ def merge_detections(source_files,target_file,output_file,options=None):

     print('Saved merged results to {}'.format(output_file))

+# ...def merge_detections(...)
+

 #%% Command-line driver

-def main():
+def main():
+    """
+    Command-line driver for merge_detections.py
+    """

     default_options = MergeDetectionsOptions()

@@ -305,7 +311,7 @@ def main(): # noqa

         description='Merge detections from one or more MegaDetector results files into an existing reuslts file')
     parser.add_argument(
         'source_files',
-        nargs=
+        nargs='+',
         help='Path to source .json file(s) to merge from')
     parser.add_argument(
         'target_file',

@@ -325,11 +331,11 @@ def main(): # noqa

         default=default_options.min_detection_size,
         type=float,
         help='Ignore detections with an area smaller than this (as a fraction of ' + \
-
-
+             'image size) (default {})'.format(
+             default_options.min_detection_size))
     parser.add_argument(
         '--source_confidence_thresholds',
-        nargs=
+        nargs='+',
         type=float,
         default=default_options.source_confidence_thresholds,
         help='List of thresholds for each source file (default {}). '.format(

@@ -340,19 +346,18 @@ def main(): # noqa

         '--target_confidence_threshold',
         type=float,
         default=default_options.target_confidence_threshold,
-        help=
-
-        default_options.target_confidence_threshold))
+        help="Do not merge if target file detection confidence is already higher " + \
+             "than this (default {})".format(default_options.target_confidence_threshold))
     parser.add_argument(
         '--categories_to_include',
         type=int,
-        nargs=
+        nargs='+',
         default=None,
         help='List of numeric detection category IDs to include')
     parser.add_argument(
         '--categories_to_exclude',
         type=int,
-        nargs=
+        nargs='+',
         default=None,
         help='List of numeric detection categories to include')
     parser.add_argument(

@@ -387,4 +392,3 @@ def main(): # noqa

 if __name__ == '__main__':
     main()
-
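
The nargs='+' values restored above tell argparse to collect one or more command-line values into a list, which is what lets the script accept multiple source files and per-source thresholds. A minimal sketch of that behavior (the argument names below are illustrative, not the exact merge_detections.py parser):

import argparse

# Minimal sketch of nargs='+' behavior; argument names are illustrative.
parser = argparse.ArgumentParser()
parser.add_argument('source_files', nargs='+',
                    help='One or more source .json files')
parser.add_argument('--source_confidence_thresholds', nargs='+', type=float,
                    default=None, help='Optional per-source confidence thresholds')

args = parser.parse_args(['a.json', 'b.json',
                          '--source_confidence_thresholds', '0.2', '0.3'])
print(args.source_files)                  # ['a.json', 'b.json']
print(args.source_confidence_thresholds)  # [0.2, 0.3]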

megadetector/postprocessing/validate_batch_results.py
CHANGED

@@ -172,7 +172,9 @@ def validate_batch_results(json_filename,options=None):

         file = im['file']

         if 'detections' in im and im['detections'] is not None:
+
             for det in im['detections']:
+
                 assert 'category' in det, 'Image {} has a detection with no category'.format(file)
                 assert 'conf' in det, 'Image {} has a detection with no confidence'.format(file)
                 assert isinstance(det['conf'],float), \

@@ -182,6 +184,21 @@ def validate_batch_results(json_filename,options=None):

                     'Image {} has a detection with an unmapped category {}'.format(
                     file,det['category'])

+                if 'classifications' in det and det['classifications'] is not None:
+                    for c in det['classifications']:
+                        assert isinstance(c[0],str), \
+                            'Image {} has an illegal classification category: {}'.format(file,c[0])
+                        try:
+                            _ = int(c[0])
+                        except Exception:
+                            raise ValueError('Image {} has an illegal classification category: {}'.format(
+                                file,c[0]))
+                        assert isinstance(c[1],float) or isinstance(c[1], int)
+
+            # ...for each detection
+
+        # ...if this image has a detections field
+
         if options.check_image_existence:

             if options.relative_path_base is None:

@@ -207,13 +224,19 @@ def validate_batch_results(json_filename,options=None):

         if not isinstance(im['detections'],list):
             raise ValueError('Invalid detections list for image {}'.format(im['file']))

+        if is_video_file(im['file']) and (format_version >= 1.5):
+
+            if 'frames_processed' not in im:
+                raise ValueError('Video without frames_processed field: {}'.format(im['file']))
+
         if is_video_file(im['file']) and (format_version >= 1.4):

             if 'frame_rate' not in im:
                 raise ValueError('Video without frame rate: {}'.format(im['file']))
             if im['frame_rate'] < 0:
-
-
+                if 'failure' not in im:
+                    raise ValueError('Video with illegal frame rate {}: {}'.format(
+                        str(im['frame_rate']),im['file']))
         if 'detections' in im and im['detections'] is not None:
             for det in im['detections']:
                 if 'frame_number' not in det:
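
The new classification checks above expect each classification entry to be a pair whose first element is a category ID expressed as a string of digits and whose second element is a numeric confidence. A small illustrative record (hypothetical values) that would pass those checks:

# Illustrative detection record with hypothetical values; the assertions mirror
# the checks added to validate_batch_results above.
det = {
    'category': '1',
    'conf': 0.92,
    'bbox': [0.1, 0.2, 0.3, 0.4],
    'classifications': [
        ['3', 0.81],    # category ID as a digit string, numeric confidence
        ['17', 0.07],
    ],
}

for c in det['classifications']:
    assert isinstance(c[0], str)
    _ = int(c[0])                              # must parse as an integer
    assert isinstance(c[1], float) or isinstance(c[1], int)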
File without changes

megadetector/tests/test_nms_synthetic.py
ADDED

@@ -0,0 +1,335 @@
"""

Test script for validating NMS functionality with synthetic data.

This script creates synthetic detection scenarios where we know exactly which
boxes should be suppressed by NMS, allowing us to verify the correctness of
the NMS implementation.

This is an AI-generated test module.

"""


#%% Imports

import torch

from megadetector.detection.pytorch_detector import nms


#%% Support functions

def calculate_iou_boxes(box1, box2):
    """
    Calculate IoU between two boxes in [x1, y1, x2, y2] format.

    Args:
        box1: torch.Tensor or list of [x1, y1, x2, y2]
        box2: torch.Tensor or list of [x1, y1, x2, y2]

    Returns:
        float: IoU value between 0 and 1
    """

    if isinstance(box1, list):
        box1 = torch.tensor(box1, dtype=torch.float)
    if isinstance(box2, list):
        box2 = torch.tensor(box2, dtype=torch.float)

    # Calculate intersection area
    x1_inter = max(box1[0], box2[0])
    y1_inter = max(box1[1], box2[1])
    x2_inter = min(box1[2], box2[2])
    y2_inter = min(box1[3], box2[3])

    if x2_inter <= x1_inter or y2_inter <= y1_inter:
        return 0.0

    intersection = (x2_inter - x1_inter) * (y2_inter - y1_inter)

    # Calculate union area
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection

    return float(intersection / union) if union > 0 else 0.0


def create_synthetic_predictions():
    """
    Create synthetic model predictions for testing NMS.

    Returns:
        torch.Tensor: Synthetic predictions in the format expected by the NMS function
            Shape: [batch_size=1, num_anchors, num_classes + 5]

    Test scenarios:
        1. Two highly overlapping boxes (IoU > 0.5) with different confidences - higher confidence should win
        2. Two boxes with low overlap (IoU < 0.5) - both should be kept
        3. Multiple boxes of different classes in same location - should be kept (class-independent NMS)
        4. Three overlapping boxes with cascading confidences - highest confidence should win
    """

    # We'll create predictions for a 640x640 image with 3 classes
    # Format: [x_center, y_center, width, height, objectness, class0_conf, class1_conf, class2_conf]

    synthetic_boxes = []

    # Scenario 1: Two highly overlapping boxes (IoU > 0.8)
    # Box A: center=(100, 100), size=80x80, high confidence for class 0
    # Box B: center=(105, 105), size=80x80, low confidence for class 0 (smaller offset = higher IoU)
    # Expected: Box A kept, Box B suppressed
    synthetic_boxes.append([100, 100, 80, 80, 0.9, 0.8, 0.1, 0.1])  # Box A - should be kept
    synthetic_boxes.append([105, 105, 80, 80, 0.9, 0.5, 0.1, 0.1])  # Box B - should be suppressed

    # Scenario 1b: Two nearly identical boxes (IoU ≈ 0.95)
    # Box A2: center=(200, 100), size=60x60, high confidence for class 0
    # Box B2: center=(202, 102), size=60x60, lower confidence for class 0
    # Expected: Box A2 kept, Box B2 suppressed
    synthetic_boxes.append([200, 100, 60, 60, 0.9, 0.9, 0.05, 0.05])  # Box A2 - should be kept
    synthetic_boxes.append([202, 102, 60, 60, 0.9, 0.7, 0.1, 0.1])  # Box B2 - should be suppressed

    # Scenario 2: Two boxes with low overlap (IoU ≈ 0.1)
    # Box C: center=(300, 100), size=60x60, class 0
    # Box D: center=(380, 100), size=60x60, class 0
    # Expected: Both kept
    synthetic_boxes.append([300, 100, 60, 60, 0.9, 0.7, 0.1, 0.1])  # Box C - should be kept
    synthetic_boxes.append([380, 100, 60, 60, 0.9, 0.6, 0.1, 0.1])  # Box D - should be kept

    # Scenario 3: Same location, different classes
    # Box E: center=(100, 300), size=70x70, class 0
    # Box F: center=(100, 300), size=70x70, class 1
    # Expected: Both kept (class-independent NMS)
    synthetic_boxes.append([100, 300, 70, 70, 0.9, 0.7, 0.1, 0.1])  # Box E - class 0, should be kept
    synthetic_boxes.append([100, 300, 70, 70, 0.9, 0.1, 0.7, 0.1])  # Box F - class 1, should be kept

    # Scenario 4: Three cascading overlapping boxes
    # Box G: center=(500, 300), size=80x80, highest confidence
    # Box H: center=(510, 310), size=80x80, medium confidence
    # Box I: center=(520, 320), size=80x80, lowest confidence
    # Expected: Only Box G kept
    synthetic_boxes.append([500, 300, 80, 80, 0.95, 0.9, 0.05, 0.05])  # Box G - highest conf, should be kept
    synthetic_boxes.append([510, 310, 80, 80, 0.9, 0.7, 0.1, 0.1])  # Box H - should be suppressed
    synthetic_boxes.append([520, 320, 80, 80, 0.85, 0.6, 0.15, 0.15])  # Box I - should be suppressed

    # Add some low-confidence boxes that should be filtered out before NMS
    synthetic_boxes.append([200, 500, 50, 50, 0.1, 0.05, 0.02, 0.03])  # Too low confidence

    # Convert to tensor format expected by NMS function
    # We need to pad to a reasonable number of anchors (let's use 20)
    num_anchors = 20
    num_classes = 3

    predictions = torch.zeros(1, num_anchors, num_classes + 5)  # batch_size=1

    # Fill in our synthetic boxes
    for i, box_data in enumerate(synthetic_boxes):
        if i < num_anchors:
            predictions[0, i, :] = torch.tensor(box_data)

    return predictions


#%% Main test function

def test_nms_functionality():
    """
    Test the NMS function with synthetic data to verify correct behavior.
    """

    print("Testing NMS functionality with synthetic data...")

    # Generate synthetic predictions
    predictions = create_synthetic_predictions()
    print(f"Created synthetic predictions with shape: {predictions.shape}")

    # Run NMS with IoU threshold = 0.5 and confidence threshold = 0.3
    results = nms(predictions, conf_thres=0.3, iou_thres=0.5, max_det=300)

    print(f"NMS returned {len(results)} batch results")
    detections = results[0]  # Get results for first (and only) image
    print(f"Number of detections after NMS: {detections.shape[0]}")

    assert detections.shape[0] != 0

    print("\nDetections after NMS:")
    print("Format: [x1, y1, x2, y2, confidence, class_id]")
    for i, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        width = x2 - x1
        height = y2 - y1
        print(f"Detection {i}: center=({center_x:.1f}, {center_y:.1f}), "
              f"size={width:.1f}x{height:.1f}, conf={conf:.3f}, class={int(cls)}")

    # Verify expected results

    # Verify that high-confidence boxes are kept over low-confidence overlapping ones
    # Look for the scenario 1 boxes (around center 100,100 area)
    scenario1_boxes = []
    for i, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        if 80 <= center_x <= 130 and 80 <= center_y <= 130 and int(cls) == 0:
            scenario1_boxes.append((i, center_x, center_y, conf))

    # Check scenario 1b (around center 200,100 area)
    scenario1b_boxes = []
    for i, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        if 180 <= center_x <= 220 and 80 <= center_y <= 120 and int(cls) == 0:
            scenario1b_boxes.append((i, center_x, center_y, conf))

    # Both scenario 1 and 1b should have exactly 1 detection each
    total_high_overlap_boxes = len(scenario1_boxes) + len(scenario1b_boxes)
    if total_high_overlap_boxes != 2:
        print("Error: expected 2 detections in high-overlap scenarios (1 each), got {}".format(
            total_high_overlap_boxes
        ))
        print(f" Scenario 1: {len(scenario1_boxes)} boxes")
        print(f" Scenario 1b: {len(scenario1b_boxes)} boxes")
        raise AssertionError()
    # Should be the high-confidence box (0.8 * 0.9 = 0.72)
    elif len(scenario1_boxes) == 1 and scenario1_boxes[0][3] < 0.7:
        print("Error: wrong box kept in scenario 1. Expected conf > 0.7, got {}".format(
            scenario1_boxes[0][3]
        ))
        raise AssertionError()
    # Should be the high-confidence box (0.9 * 0.9 = 0.81)
    elif len(scenario1b_boxes) == 1 and scenario1b_boxes[0][3] < 0.8:
        print("Error: wrong box kept in scenario 1b. Expected conf > 0.8, got {}".format(
            scenario1b_boxes[0][3]
        ))
        raise AssertionError()
    else:
        print("Scenarios 1 & 1b passed: High-confidence boxes kept, low-confidence overlapping boxes suppressed")

    # Verify IoU calculations and ensure suppression actually works
    if len(scenario1_boxes) == 1 and len(scenario1b_boxes) == 1:
        # Calculate what the IoU would have been between the boxes that were supposed to overlap
        # Scenario 1: Box A (100,100,80x80) vs Box B (105,105,80x80)
        box_a = [100-40, 100-40, 100+40, 100+40]  # Convert center+size to corners
        box_b = [105-40, 105-40, 105+40, 105+40]
        iou_1 = calculate_iou_boxes(box_a, box_b)

        # Scenario 1b: Box A2 (200,100,60x60) vs Box B2 (202,102,60x60)
        box_a2 = [200-30, 100-30, 200+30, 100+30]
        box_b2 = [202-30, 102-30, 202+30, 102+30]
        iou_1b = calculate_iou_boxes(box_a2, box_b2)

        print(f" Theoretical IoU for scenario 1 boxes: {iou_1:.3f}")
        print(f" Theoretical IoU for scenario 1b boxes: {iou_1b:.3f}")

        # If IoU > threshold, suppression should have occurred
        if iou_1 <= 0.5:
            print(f"Error: scenario 1 IoU {iou_1:.3f} is too low - test setup is invalid!")
            raise AssertionError()
        elif iou_1b <= 0.5:
            print(f"Error: scenario 1b IoU {iou_1b:.3f} is too low - test setup is invalid!")
            raise AssertionError()
        else:
            print(" High IoU confirmed - suppression was correct")

    # Verify scenario 2 - both non-overlapping boxes should be kept
    scenario2_boxes = []
    for i, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        if 270 <= center_x <= 410 and 70 <= center_y <= 130 and int(cls) == 0:
            scenario2_boxes.append((i, center_x, center_y, conf))

    if len(scenario2_boxes) != 2:
        print(f"Error: expected 2 detections in scenario 2 area, got {len(scenario2_boxes)}")
        raise AssertionError()
    else:
        print("Scenario 2 passed: Both non-overlapping boxes kept")

    # Verify scenario 3 - different classes should both be kept
    scenario3_boxes = []
    for i, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        if 65 <= center_x <= 135 and 265 <= center_y <= 335:
            scenario3_boxes.append((i, center_x, center_y, conf, int(cls)))

    classes_found = set(box[4] for box in scenario3_boxes)
    if (len(scenario3_boxes) != 2) or (len(classes_found) != 2):
        print("Error: expected 2 detections of different classes , got {} detections of classes {}".format(
            len(scenario3_boxes),classes_found
        ))
        raise AssertionError()
    else:
        print("Scenario 3 passed: Both different-class boxes kept")

    # Verify scenario 4 - cascading overlapping boxes (only highest confidence should remain)
    scenario4_boxes = []
    for i, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        if 460 <= center_x <= 560 and 260 <= center_y <= 360 and int(cls) == 0:
            scenario4_boxes.append((i, center_x, center_y, conf))

    print(f"\nScenario 4 analysis: Found {len(scenario4_boxes)} boxes in cascading area:")
    for i, (det_idx, cx, cy, conf) in enumerate(scenario4_boxes):
        print(f" Box {i}: center=({cx:.1f}, {cy:.1f}), conf={conf:.3f}")

    # Check IoU between remaining boxes to ensure proper suppression
    if len(scenario4_boxes) >= 2:
        max_iou = 0
        for i in range(len(scenario4_boxes)):
            for j in range(i+1, len(scenario4_boxes)):
                det1 = detections[scenario4_boxes[i][0]]
                det2 = detections[scenario4_boxes[j][0]]
                iou = calculate_iou_boxes(det1[:4], det2[:4])
                print(f" IoU between box {i} and box {j}: {iou:.3f}")
                max_iou = max(max_iou, iou)

    if len(scenario4_boxes) == 1:
        print("Scenario 4 passed: Only highest confidence box kept")
    else:
        # This is only OK if IoU < threshold
        if max_iou < 0.5:  # Our IoU threshold
            print("Scenario 4 passed: Multiple boxes kept due to low IoU (< 0.5)")
        else:
            print(f"ERROR: Scenario 4 failed - boxes with IoU {max_iou:.3f} > 0.5 were not suppressed!")
            raise AssertionError()

    # Create a scenario that requires IoU calculation
    print("\n=== COMPREHENSIVE IoU VALIDATION TEST ===")

    # Create two identical boxes that should definitely be suppressed
    identical_box_a = [100, 100, 50, 50, 0.9, 0.9, 0.05, 0.05]  # High confidence
    identical_box_b = [100, 100, 50, 50, 0.9, 0.7, 0.1, 0.1]  # Lower confidence

    test_predictions = torch.zeros(1, 5, 8)  # Small batch for focused test
    test_predictions[0, 0, :] = torch.tensor(identical_box_a)
    test_predictions[0, 1, :] = torch.tensor(identical_box_b)

    # Run NMS on this simple case
    test_results = nms(test_predictions, conf_thres=0.3, iou_thres=0.5, max_det=300)
    test_detections = test_results[0]

    print(f"Identical boxes test: Input 2 identical boxes, got {test_detections.shape[0]} detections")

    if test_detections.shape[0] != 1:
        print(f"Error Two identical boxes should result in 1 detection, got {test_detections.shape[0]}")
        raise AssertionError()
    else:
        # Verify it kept the higher confidence box
        kept_conf = test_detections[0, 4].item()
        expected_conf = 0.9 * 0.9  # objectness * class_conf
        if abs(kept_conf - expected_conf) > 0.01:
            print(f"ERROR: Wrong box kept. Expected conf ≈ {expected_conf:.3f}, got {kept_conf:.3f}")
            raise AssertionError()
        else:
            print("Identical boxes test passed: Higher confidence box kept")

    print("\nNMS tests passed")
megadetector/utils/ct_utils.py
CHANGED

@@ -983,13 +983,15 @@ def dict_to_kvp_list(d,

     return s


-def parse_bool_string(s):
+def parse_bool_string(s, strict=False):
     """
     Convert the strings "true" or "false" to boolean values. Case-insensitive, discards
     leading and trailing whitespace. If s is already a bool, returns s.

     Args:
         s (str or bool): the string to parse, or the bool to return
+        strict (bool, optional): only allow "true" or "false", otherwise
+            handles "1", "0", "yes", and "no".

     Returns:
         bool: the parsed value

@@ -997,10 +999,17 @@ def parse_bool_string(s):

     if isinstance(s,bool):
         return s
-    s = s.lower().strip()
-
+    s = str(s).lower().strip()
+
+    if strict:
+        false_strings = ('false')
+        true_strings = ('true')
+    else:
+        false_strings = ('no', 'false', 'f', 'n', '0')
+        true_strings = ('yes', 'true', 't', 'y', '1')
+    if s in true_strings:
         return True
-    elif s
+    elif s in false_strings:
         return False
     else:
         raise ValueError('Cannot parse bool from string {}'.format(str(s)))

@@ -1649,6 +1658,8 @@ def test_string_parsing():

     assert not parse_bool_string("false")
     assert not parse_bool_string("False")
     assert not parse_bool_string(" FALSE ")
+    assert parse_bool_string("1", strict=False)
+    assert not parse_bool_string("0", strict=False)
     assert parse_bool_string(True) is True # Test with existing bool
     assert parse_bool_string(False) is False
     try:

@@ -1657,7 +1668,7 @@ def test_string_parsing():

     except ValueError:
         pass
     try:
-        parse_bool_string("1")
+        parse_bool_string("1",strict=True)
         raise AssertionError("ValueError not raised for '1'")
     except ValueError:
         pass