megadetector 10.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (147) hide show
  1. megadetector/__init__.py +0 -0
  2. megadetector/api/__init__.py +0 -0
  3. megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
  4. megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
  5. megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
  6. megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
  7. megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
  8. megadetector/classification/__init__.py +0 -0
  9. megadetector/classification/aggregate_classifier_probs.py +108 -0
  10. megadetector/classification/analyze_failed_images.py +227 -0
  11. megadetector/classification/cache_batchapi_outputs.py +198 -0
  12. megadetector/classification/create_classification_dataset.py +626 -0
  13. megadetector/classification/crop_detections.py +516 -0
  14. megadetector/classification/csv_to_json.py +226 -0
  15. megadetector/classification/detect_and_crop.py +853 -0
  16. megadetector/classification/efficientnet/__init__.py +9 -0
  17. megadetector/classification/efficientnet/model.py +415 -0
  18. megadetector/classification/efficientnet/utils.py +608 -0
  19. megadetector/classification/evaluate_model.py +520 -0
  20. megadetector/classification/identify_mislabeled_candidates.py +152 -0
  21. megadetector/classification/json_to_azcopy_list.py +63 -0
  22. megadetector/classification/json_validator.py +696 -0
  23. megadetector/classification/map_classification_categories.py +276 -0
  24. megadetector/classification/merge_classification_detection_output.py +509 -0
  25. megadetector/classification/prepare_classification_script.py +194 -0
  26. megadetector/classification/prepare_classification_script_mc.py +228 -0
  27. megadetector/classification/run_classifier.py +287 -0
  28. megadetector/classification/save_mislabeled.py +110 -0
  29. megadetector/classification/train_classifier.py +827 -0
  30. megadetector/classification/train_classifier_tf.py +725 -0
  31. megadetector/classification/train_utils.py +323 -0
  32. megadetector/data_management/__init__.py +0 -0
  33. megadetector/data_management/animl_to_md.py +161 -0
  34. megadetector/data_management/annotations/__init__.py +0 -0
  35. megadetector/data_management/annotations/annotation_constants.py +33 -0
  36. megadetector/data_management/camtrap_dp_to_coco.py +270 -0
  37. megadetector/data_management/cct_json_utils.py +566 -0
  38. megadetector/data_management/cct_to_md.py +184 -0
  39. megadetector/data_management/cct_to_wi.py +293 -0
  40. megadetector/data_management/coco_to_labelme.py +284 -0
  41. megadetector/data_management/coco_to_yolo.py +702 -0
  42. megadetector/data_management/databases/__init__.py +0 -0
  43. megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
  44. megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
  45. megadetector/data_management/databases/integrity_check_json_db.py +528 -0
  46. megadetector/data_management/databases/subset_json_db.py +195 -0
  47. megadetector/data_management/generate_crops_from_cct.py +200 -0
  48. megadetector/data_management/get_image_sizes.py +164 -0
  49. megadetector/data_management/labelme_to_coco.py +559 -0
  50. megadetector/data_management/labelme_to_yolo.py +349 -0
  51. megadetector/data_management/lila/__init__.py +0 -0
  52. megadetector/data_management/lila/create_lila_blank_set.py +556 -0
  53. megadetector/data_management/lila/create_lila_test_set.py +187 -0
  54. megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
  55. megadetector/data_management/lila/download_lila_subset.py +182 -0
  56. megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
  57. megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
  58. megadetector/data_management/lila/get_lila_image_counts.py +112 -0
  59. megadetector/data_management/lila/lila_common.py +319 -0
  60. megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
  61. megadetector/data_management/mewc_to_md.py +344 -0
  62. megadetector/data_management/ocr_tools.py +873 -0
  63. megadetector/data_management/read_exif.py +964 -0
  64. megadetector/data_management/remap_coco_categories.py +195 -0
  65. megadetector/data_management/remove_exif.py +156 -0
  66. megadetector/data_management/rename_images.py +194 -0
  67. megadetector/data_management/resize_coco_dataset.py +663 -0
  68. megadetector/data_management/speciesnet_to_md.py +41 -0
  69. megadetector/data_management/wi_download_csv_to_coco.py +247 -0
  70. megadetector/data_management/yolo_output_to_md_output.py +594 -0
  71. megadetector/data_management/yolo_to_coco.py +876 -0
  72. megadetector/data_management/zamba_to_md.py +188 -0
  73. megadetector/detection/__init__.py +0 -0
  74. megadetector/detection/change_detection.py +840 -0
  75. megadetector/detection/process_video.py +479 -0
  76. megadetector/detection/pytorch_detector.py +1451 -0
  77. megadetector/detection/run_detector.py +1267 -0
  78. megadetector/detection/run_detector_batch.py +2159 -0
  79. megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
  80. megadetector/detection/run_md_and_speciesnet.py +1494 -0
  81. megadetector/detection/run_tiled_inference.py +1038 -0
  82. megadetector/detection/tf_detector.py +209 -0
  83. megadetector/detection/video_utils.py +1379 -0
  84. megadetector/postprocessing/__init__.py +0 -0
  85. megadetector/postprocessing/add_max_conf.py +72 -0
  86. megadetector/postprocessing/categorize_detections_by_size.py +166 -0
  87. megadetector/postprocessing/classification_postprocessing.py +1752 -0
  88. megadetector/postprocessing/combine_batch_outputs.py +249 -0
  89. megadetector/postprocessing/compare_batch_results.py +2110 -0
  90. megadetector/postprocessing/convert_output_format.py +403 -0
  91. megadetector/postprocessing/create_crop_folder.py +629 -0
  92. megadetector/postprocessing/detector_calibration.py +570 -0
  93. megadetector/postprocessing/generate_csv_report.py +522 -0
  94. megadetector/postprocessing/load_api_results.py +223 -0
  95. megadetector/postprocessing/md_to_coco.py +428 -0
  96. megadetector/postprocessing/md_to_labelme.py +351 -0
  97. megadetector/postprocessing/md_to_wi.py +41 -0
  98. megadetector/postprocessing/merge_detections.py +392 -0
  99. megadetector/postprocessing/postprocess_batch_results.py +2077 -0
  100. megadetector/postprocessing/remap_detection_categories.py +226 -0
  101. megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
  102. megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
  103. megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
  104. megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
  105. megadetector/postprocessing/separate_detections_into_folders.py +795 -0
  106. megadetector/postprocessing/subset_json_detector_output.py +964 -0
  107. megadetector/postprocessing/top_folders_to_bottom.py +238 -0
  108. megadetector/postprocessing/validate_batch_results.py +332 -0
  109. megadetector/taxonomy_mapping/__init__.py +0 -0
  110. megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
  111. megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
  112. megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
  113. megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
  114. megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
  115. megadetector/taxonomy_mapping/simple_image_download.py +224 -0
  116. megadetector/taxonomy_mapping/species_lookup.py +1008 -0
  117. megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
  118. megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
  119. megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
  120. megadetector/tests/__init__.py +0 -0
  121. megadetector/tests/test_nms_synthetic.py +335 -0
  122. megadetector/utils/__init__.py +0 -0
  123. megadetector/utils/ct_utils.py +1857 -0
  124. megadetector/utils/directory_listing.py +199 -0
  125. megadetector/utils/extract_frames_from_video.py +307 -0
  126. megadetector/utils/gpu_test.py +125 -0
  127. megadetector/utils/md_tests.py +2072 -0
  128. megadetector/utils/path_utils.py +2832 -0
  129. megadetector/utils/process_utils.py +172 -0
  130. megadetector/utils/split_locations_into_train_val.py +237 -0
  131. megadetector/utils/string_utils.py +234 -0
  132. megadetector/utils/url_utils.py +825 -0
  133. megadetector/utils/wi_platform_utils.py +968 -0
  134. megadetector/utils/wi_taxonomy_utils.py +1759 -0
  135. megadetector/utils/write_html_image_list.py +239 -0
  136. megadetector/visualization/__init__.py +0 -0
  137. megadetector/visualization/plot_utils.py +309 -0
  138. megadetector/visualization/render_images_with_thumbnails.py +243 -0
  139. megadetector/visualization/visualization_utils.py +1940 -0
  140. megadetector/visualization/visualize_db.py +630 -0
  141. megadetector/visualization/visualize_detector_output.py +479 -0
  142. megadetector/visualization/visualize_video_output.py +705 -0
  143. megadetector-10.0.13.dist-info/METADATA +134 -0
  144. megadetector-10.0.13.dist-info/RECORD +147 -0
  145. megadetector-10.0.13.dist-info/WHEEL +5 -0
  146. megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
  147. megadetector-10.0.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,335 @@
1
+ """
2
+
3
+ Test script for validating NMS functionality with synthetic data.
4
+
5
+ This script creates synthetic detection scenarios where we know exactly which
6
+ boxes should be suppressed by NMS, allowing us to verify the correctness of
7
+ the NMS implementation.
8
+
9
+ This is an AI-generated test module.
10
+
11
+ """
12
+
13
+
14
+ #%% Imports
15
+
16
+ import torch
17
+
18
+ from megadetector.detection.pytorch_detector import nms
19
+
20
+
21
+ #%% Support functions
22
+
23
def calculate_iou_boxes(box1, box2):
    """
    Compute the intersection-over-union of two axis-aligned boxes.

    Args:
        box1: torch.Tensor or list holding [x1, y1, x2, y2] corner coordinates
        box2: torch.Tensor or list holding [x1, y1, x2, y2] corner coordinates

    Returns:
        float: IoU value between 0 and 1 (0.0 when the boxes do not overlap)
    """

    # Lists are promoted to float tensors; any other input is used as supplied
    first, second = (
        torch.tensor(box, dtype=torch.float) if isinstance(box, list) else box
        for box in (box1, box2)
    )

    # Corners of the overlapping rectangle (if there is one)
    left = max(first[0], second[0])
    top = max(first[1], second[1])
    right = min(first[2], second[2])
    bottom = min(first[3], second[3])

    # Boxes that are disjoint, or that merely share an edge, have zero overlap
    if right <= left or bottom <= top:
        return 0.0

    overlap = (right - left) * (bottom - top)

    first_area = (first[2] - first[0]) * (first[3] - first[1])
    second_area = (second[2] - second[0]) * (second[3] - second[1])
    combined = first_area + second_area - overlap

    # Guard against division by zero for degenerate inputs
    if combined > 0:
        return float(overlap / combined)
    return 0.0
57
+
58
+
59
def create_synthetic_predictions():
    """
    Create synthetic model predictions for testing NMS.

    Every populated row is laid out as
    [x_center, y_center, width, height, objectness, class0_conf, class1_conf, class2_conf],
    with coordinates in pixels of a notional 640x640 image.  Unused anchor rows are
    left as all zeros.

    Returns:
        torch.Tensor: Synthetic predictions in the format expected by the NMS function
            Shape: [batch_size=1, num_anchors=20, num_classes + 5 = 8]

    Raises:
        ValueError: if more synthetic boxes are defined than there are anchors

    Test scenarios (IoU values below are computed from the box geometry):
        1.  Two overlapping same-class boxes (IoU ~= 0.78) - higher confidence should win
        1b. Two nearly identical same-class boxes (IoU ~= 0.88) - higher confidence should win
        2.  Two disjoint same-class boxes (IoU = 0) - both should be kept
        3.  Two boxes of different classes in the same location - both should be kept
        4.  Three cascading boxes (adjacent IoU ~= 0.62, first-to-last IoU ~= 0.39) - the
            highest-confidence box must be kept and the middle box suppressed; the last
            box overlaps the first by less than 0.5, so it may legitimately survive
    """

    num_anchors = 20
    num_classes = 3

    synthetic_boxes = []

    # Scenario 1: Two overlapping boxes (IoU = 75*75 / (2*80*80 - 75*75) ~= 0.784)
    # Box A: center=(100, 100), size=80x80, high confidence for class 0
    # Box B: center=(105, 105), size=80x80, lower confidence for class 0
    # Expected: Box A kept, Box B suppressed
    synthetic_boxes.append([100, 100, 80, 80, 0.9, 0.8, 0.1, 0.1])  # Box A - should be kept
    synthetic_boxes.append([105, 105, 80, 80, 0.9, 0.5, 0.1, 0.1])  # Box B - should be suppressed

    # Scenario 1b: Two nearly identical boxes (IoU = 58*58 / (2*60*60 - 58*58) ~= 0.877)
    # Box A2: center=(200, 100), size=60x60, high confidence for class 0
    # Box B2: center=(202, 102), size=60x60, lower confidence for class 0
    # Expected: Box A2 kept, Box B2 suppressed
    synthetic_boxes.append([200, 100, 60, 60, 0.9, 0.9, 0.05, 0.05])  # Box A2 - should be kept
    synthetic_boxes.append([202, 102, 60, 60, 0.9, 0.7, 0.1, 0.1])    # Box B2 - should be suppressed

    # Scenario 2: Two disjoint boxes (x-extents 270..330 and 350..410, so IoU = 0)
    # Box C: center=(300, 100), size=60x60, class 0
    # Box D: center=(380, 100), size=60x60, class 0
    # Expected: Both kept
    synthetic_boxes.append([300, 100, 60, 60, 0.9, 0.7, 0.1, 0.1])  # Box C - should be kept
    synthetic_boxes.append([380, 100, 60, 60, 0.9, 0.6, 0.1, 0.1])  # Box D - should be kept

    # Scenario 3: Same location, different classes
    # Box E: center=(100, 300), size=70x70, class 0
    # Box F: center=(100, 300), size=70x70, class 1
    # Expected: Both kept (boxes of different classes are not expected to suppress each other)
    synthetic_boxes.append([100, 300, 70, 70, 0.9, 0.7, 0.1, 0.1])  # Box E - class 0, should be kept
    synthetic_boxes.append([100, 300, 70, 70, 0.9, 0.1, 0.7, 0.1])  # Box F - class 1, should be kept

    # Scenario 4: Three cascading overlapping boxes
    # Box G: center=(500, 300), size=80x80, highest confidence
    # Box H: center=(510, 310), size=80x80, medium confidence
    # Box I: center=(520, 320), size=80x80, lowest confidence
    #
    # IoU(G,H) = IoU(H,I) = 70*70 / (2*80*80 - 70*70) ~= 0.620
    # IoU(G,I) = 60*60 / (2*80*80 - 60*60) ~= 0.391
    #
    # Expected: Box G kept, Box H suppressed by G.  Box I only overlaps G by ~0.39, so
    # with an IoU threshold of 0.5 it is not suppressed by G; assuming a greedy NMS in
    # which an already-suppressed box (H) cannot suppress others, Box I survives too.
    synthetic_boxes.append([500, 300, 80, 80, 0.95, 0.9, 0.05, 0.05])   # Box G - highest conf, should be kept
    synthetic_boxes.append([510, 310, 80, 80, 0.9, 0.7, 0.1, 0.1])      # Box H - should be suppressed
    synthetic_boxes.append([520, 320, 80, 80, 0.85, 0.6, 0.15, 0.15])   # Box I - IoU with G is below 0.5

    # A low-confidence box that should be filtered out before NMS
    synthetic_boxes.append([200, 500, 50, 50, 0.1, 0.05, 0.02, 0.03])  # Too low confidence

    # Fail loudly rather than silently dropping scenarios if boxes are added later
    # without growing the anchor count
    if len(synthetic_boxes) > num_anchors:
        raise ValueError(
            'Defined {} synthetic boxes, but only {} anchors are available'.format(
                len(synthetic_boxes), num_anchors))

    # Pad with all-zero rows up to num_anchors; batch_size=1
    predictions = torch.zeros(1, num_anchors, num_classes + 5)
    predictions[0, :len(synthetic_boxes), :] = \
        torch.tensor(synthetic_boxes, dtype=predictions.dtype)

    return predictions
132
+
133
+
134
+ #%% Main test function
135
+
136
def _detections_in_region(detections, x_range, y_range, class_id=None):
    """
    Collect the detections whose box center falls inside a rectangular region.

    Args:
        detections: iterable of rows laid out as [x1, y1, x2, y2, confidence, class_id]
        x_range: (min, max) inclusive bounds on the box center's x coordinate
        y_range: (min, max) inclusive bounds on the box center's y coordinate
        class_id: if not None, only detections of this class are returned

    Returns:
        list of tuple: (detection_index, center_x, center_y, confidence, class_id) for
        every matching detection, in the order they appear in [detections]
    """

    matches = []
    for i_det, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = (float(v) for v in det)
        det_class = int(cls)
        if (class_id is not None) and (det_class != class_id):
            continue
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        if (x_range[0] <= center_x <= x_range[1]) and (y_range[0] <= center_y <= y_range[1]):
            matches.append((i_det, center_x, center_y, conf, det_class))
    return matches


def _max_pairwise_iou(detections, region_boxes):
    """
    Print the IoU of every pair of boxes in [region_boxes] and return the largest one.

    Args:
        detections: indexable collection of rows whose first four values are x1, y1, x2, y2
        region_boxes: tuples as returned by _detections_in_region (index is element 0)

    Returns:
        float: the maximum pairwise IoU, or 0.0 if there are fewer than two boxes
    """

    max_iou = 0.0
    for i in range(len(region_boxes)):
        for j in range(i + 1, len(region_boxes)):
            det_i = detections[region_boxes[i][0]]
            det_j = detections[region_boxes[j][0]]
            iou = calculate_iou_boxes(det_i[:4], det_j[:4])
            print(f"  IoU between box {i} and box {j}: {iou:.3f}")
            max_iou = max(max_iou, iou)
    return max_iou


def test_nms_functionality():
    """
    Test the NMS function with synthetic data to verify correct behavior.

    Runs nms() on the output of create_synthetic_predictions() and checks every
    scenario described there, then runs a focused check on two identical boxes.
    Expected confidences assume the score reported by nms() is
    objectness * class confidence.

    Raises:
        AssertionError: if any scenario yields an unexpected set of detections
    """

    conf_threshold = 0.3
    iou_threshold = 0.5

    print("Testing NMS functionality with synthetic data...")

    # Generate synthetic predictions
    predictions = create_synthetic_predictions()
    print(f"Created synthetic predictions with shape: {predictions.shape}")

    results = nms(predictions, conf_thres=conf_threshold, iou_thres=iou_threshold, max_det=300)

    print(f"NMS returned {len(results)} batch results")
    detections = results[0]  # Get results for first (and only) image
    print(f"Number of detections after NMS: {detections.shape[0]}")

    if detections.shape[0] == 0:
        raise AssertionError("NMS returned no detections")

    print("\nDetections after NMS:")
    print("Format: [x1, y1, x2, y2, confidence, class_id]")
    for i_det, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = (float(v) for v in det)
        print(f"Detection {i_det}: center=({(x1 + x2) / 2:.1f}, {(y1 + y2) / 2:.1f}), "
              f"size={x2 - x1:.1f}x{y2 - y1:.1f}, conf={conf:.3f}, class={int(cls)}")

    ## Scenarios 1 and 1b: the high-confidence box of each overlapping pair must win

    # (name, center x range, center y range, minimum confidence of the surviving box)
    #
    # Expected survivors: 0.9 * 0.8 = 0.72 for scenario 1, 0.9 * 0.9 = 0.81 for scenario 1b
    high_overlap_scenarios = [
        ("1", (80, 130), (80, 130), 0.7),
        ("1b", (180, 220), (80, 120), 0.8),
    ]

    # Each scenario is checked on its own, so a surplus in one can't mask a miss in the other
    for scenario_name, x_range, y_range, min_conf in high_overlap_scenarios:
        kept = _detections_in_region(detections, x_range, y_range, class_id=0)
        if len(kept) != 1:
            raise AssertionError(
                "Error: expected exactly 1 detection in scenario {}, got {}".format(
                    scenario_name, len(kept)))
        kept_conf = kept[0][3]
        if kept_conf < min_conf:
            raise AssertionError(
                "Error: wrong box kept in scenario {}. Expected conf > {}, got {}".format(
                    scenario_name, min_conf, kept_conf))

    print("Scenarios 1 & 1b passed: High-confidence boxes kept, low-confidence overlapping boxes suppressed")

    # Sanity-check the test setup itself: the input pairs must actually overlap by more
    # than the IoU threshold, otherwise the suppression above proves nothing.
    box_a = [100 - 40, 100 - 40, 100 + 40, 100 + 40]  # Convert center+size to corners
    box_b = [105 - 40, 105 - 40, 105 + 40, 105 + 40]
    iou_1 = calculate_iou_boxes(box_a, box_b)

    box_a2 = [200 - 30, 100 - 30, 200 + 30, 100 + 30]
    box_b2 = [202 - 30, 102 - 30, 202 + 30, 102 + 30]
    iou_1b = calculate_iou_boxes(box_a2, box_b2)

    print(f"  Theoretical IoU for scenario 1 boxes: {iou_1:.3f}")
    print(f"  Theoretical IoU for scenario 1b boxes: {iou_1b:.3f}")

    if iou_1 <= iou_threshold:
        raise AssertionError(
            f"Error: scenario 1 IoU {iou_1:.3f} is too low - test setup is invalid!")
    if iou_1b <= iou_threshold:
        raise AssertionError(
            f"Error: scenario 1b IoU {iou_1b:.3f} is too low - test setup is invalid!")
    print("  High IoU confirmed - suppression was correct")

    ## Scenario 2: both non-overlapping boxes should be kept

    scenario2_boxes = _detections_in_region(detections, (270, 410), (70, 130), class_id=0)
    if len(scenario2_boxes) != 2:
        raise AssertionError(
            f"Error: expected 2 detections in scenario 2 area, got {len(scenario2_boxes)}")
    print("Scenario 2 passed: Both non-overlapping boxes kept")

    ## Scenario 3: co-located boxes of different classes should both be kept

    scenario3_boxes = _detections_in_region(detections, (65, 135), (265, 335))
    classes_found = set(box[4] for box in scenario3_boxes)
    if (len(scenario3_boxes) != 2) or (len(classes_found) != 2):
        raise AssertionError(
            "Error: expected 2 detections of different classes, got {} detections of classes {}".format(
                len(scenario3_boxes), classes_found))
    print("Scenario 3 passed: Both different-class boxes kept")

    ## Scenario 4: cascading overlapping boxes

    scenario4_boxes = _detections_in_region(detections, (460, 560), (260, 360), class_id=0)

    print(f"\nScenario 4 analysis: Found {len(scenario4_boxes)} boxes in cascading area:")
    for i_box, (_, cx, cy, conf, _) in enumerate(scenario4_boxes):
        print(f"  Box {i_box}: center=({cx:.1f}, {cy:.1f}), conf={conf:.3f}")

    # The highest-confidence box can never be suppressed, so an empty region is a failure
    if len(scenario4_boxes) == 0:
        raise AssertionError("ERROR: Scenario 4 failed - no detections found in cascading area")

    # The top box should score 0.95 * 0.9 = 0.855; the runner-up would be 0.9 * 0.7 = 0.63
    best_conf = max(box[3] for box in scenario4_boxes)
    if best_conf < 0.8:
        raise AssertionError(
            f"ERROR: Scenario 4 failed - highest-confidence box missing (best conf {best_conf:.3f})")

    if len(scenario4_boxes) == 1:
        print("Scenario 4 passed: Only highest confidence box kept")
    else:
        # Multiple survivors are only acceptable if none of them overlap above the threshold
        max_iou = _max_pairwise_iou(detections, scenario4_boxes)
        if max_iou < iou_threshold:
            print("Scenario 4 passed: Multiple boxes kept due to low IoU (< 0.5)")
        else:
            raise AssertionError(
                f"ERROR: Scenario 4 failed - boxes with IoU {max_iou:.3f} > 0.5 were not suppressed!")

    ## Focused check: two identical boxes must collapse to the higher-confidence one

    print("\n=== COMPREHENSIVE IoU VALIDATION TEST ===")

    identical_box_a = [100, 100, 50, 50, 0.9, 0.9, 0.05, 0.05]  # High confidence
    identical_box_b = [100, 100, 50, 50, 0.9, 0.7, 0.1, 0.1]    # Lower confidence

    test_predictions = torch.zeros(1, 5, 8)  # Small batch for focused test
    test_predictions[0, 0, :] = torch.tensor(identical_box_a)
    test_predictions[0, 1, :] = torch.tensor(identical_box_b)

    test_results = nms(test_predictions, conf_thres=conf_threshold, iou_thres=iou_threshold, max_det=300)
    test_detections = test_results[0]

    print(f"Identical boxes test: Input 2 identical boxes, got {test_detections.shape[0]} detections")

    if test_detections.shape[0] != 1:
        raise AssertionError(
            f"Error: Two identical boxes should result in 1 detection, got {test_detections.shape[0]}")

    kept_conf = test_detections[0, 4].item()
    expected_conf = 0.9 * 0.9  # objectness * class_conf
    if abs(kept_conf - expected_conf) > 0.01:
        raise AssertionError(
            f"ERROR: Wrong box kept. Expected conf ≈ {expected_conf:.3f}, got {kept_conf:.3f}")
    print("Identical boxes test passed: Higher confidence box kept")

    print("\nNMS tests passed")
File without changes