megadetector 10.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic. Click here for more details.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +702 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +528 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +187 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +663 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +876 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2159 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1494 -0
- megadetector/detection/run_tiled_inference.py +1038 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1752 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2077 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +224 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2832 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1759 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1940 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +479 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.13.dist-info/METADATA +134 -0
- megadetector-10.0.13.dist-info/RECORD +147 -0
- megadetector-10.0.13.dist-info/WHEEL +5 -0
- megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.13.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
Test script for validating NMS functionality with synthetic data.
|
|
4
|
+
|
|
5
|
+
This script creates synthetic detection scenarios where we know exactly which
|
|
6
|
+
boxes should be suppressed by NMS, allowing us to verify the correctness of
|
|
7
|
+
the NMS implementation.
|
|
8
|
+
|
|
9
|
+
This is an AI-generated test module.
|
|
10
|
+
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
#%% Imports
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from megadetector.detection.pytorch_detector import nms
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
#%% Support functions
|
|
22
|
+
|
|
23
|
+
def calculate_iou_boxes(box1, box2):
    """
    Compute the intersection-over-union of two axis-aligned boxes.

    Both boxes use corner format [x1, y1, x2, y2]. Python lists are converted to
    float tensors first; any other input (e.g. a tensor slice) is indexed as-is.

    Args:
        box1: torch.Tensor or list of [x1, y1, x2, y2]
        box2: torch.Tensor or list of [x1, y1, x2, y2]

    Returns:
        float: IoU value between 0 and 1; 0.0 when the boxes do not overlap or
        the union area is not positive
    """

    def _as_box(box):
        # Only lists are converted; tensors and other sequences pass through untouched
        return torch.tensor(box, dtype=torch.float) if isinstance(box, list) else box

    a = _as_box(box1)
    b = _as_box(box2)

    # Corners of the overlap rectangle
    left = max(a[0], b[0])
    top = max(a[1], b[1])
    right = min(a[2], b[2])
    bottom = min(a[3], b[3])

    # No overlap (touching edges count as no overlap)
    if right <= left or bottom <= top:
        return 0.0

    overlap = (right - left) * (bottom - top)

    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    combined = area_a + area_b - overlap

    # Written as "not > 0" so a NaN union also yields 0.0
    if not combined > 0:
        return 0.0
    return float(overlap / combined)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def create_synthetic_predictions():
    """
    Create synthetic model predictions for testing NMS.

    Boxes are laid out for a 640x640 image with 3 classes. Each row is
    [x_center, y_center, width, height, objectness, class0_conf, class1_conf, class2_conf];
    the tests treat a box's score as objectness * class confidence.

    Returns:
        torch.Tensor: Synthetic float32 predictions in the format expected by the NMS function
        Shape: [batch_size=1, num_anchors, num_classes + 5], where num_anchors is 20
        (or the number of synthetic boxes, if that is larger). Rows past the last
        synthetic box are all zeros.

    Test scenarios (IoU values are exact for the box geometry specified below):
    1. Two highly overlapping same-class pairs (IoU ~0.78 and ~0.88) with different
       confidences - the higher confidence box of each pair should win
    2. Two non-overlapping boxes of the same class (IoU = 0) - both should be kept
    3. Two boxes of different classes in the same location - both should be kept,
       since suppression is expected to happen per class
    4. Three diagonally cascading boxes - the highest confidence box should win; the
       lowest confidence box overlaps it only at IoU ~0.39, below the 0.5 threshold
    5. A low-confidence box that should be removed by the confidence threshold
    """

    synthetic_boxes = []

    # Scenario 1: Two highly overlapping boxes (IoU = 5625 / 7175 ~= 0.78)
    # Box A: center=(100, 100), size=80x80, high confidence for class 0 (0.9 * 0.8 = 0.72)
    # Box B: center=(105, 105), size=80x80, lower confidence for class 0 (0.9 * 0.5 = 0.45)
    # Expected: Box A kept, Box B suppressed
    synthetic_boxes.append([100, 100, 80, 80, 0.9, 0.8, 0.1, 0.1])  # Box A - should be kept
    synthetic_boxes.append([105, 105, 80, 80, 0.9, 0.5, 0.1, 0.1])  # Box B - should be suppressed

    # Scenario 1b: Two nearly identical boxes (IoU = 3364 / 3836 ~= 0.88)
    # Box A2: center=(200, 100), size=60x60, high confidence for class 0 (0.9 * 0.9 = 0.81)
    # Box B2: center=(202, 102), size=60x60, lower confidence for class 0 (0.9 * 0.7 = 0.63)
    # Expected: Box A2 kept, Box B2 suppressed
    synthetic_boxes.append([200, 100, 60, 60, 0.9, 0.9, 0.05, 0.05])  # Box A2 - should be kept
    synthetic_boxes.append([202, 102, 60, 60, 0.9, 0.7, 0.1, 0.1])  # Box B2 - should be suppressed

    # Scenario 2: Two non-overlapping boxes (x spans 270-330 and 350-410, so IoU = 0)
    # Box C: center=(300, 100), size=60x60, class 0
    # Box D: center=(380, 100), size=60x60, class 0
    # Expected: Both kept
    synthetic_boxes.append([300, 100, 60, 60, 0.9, 0.7, 0.1, 0.1])  # Box C - should be kept
    synthetic_boxes.append([380, 100, 60, 60, 0.9, 0.6, 0.1, 0.1])  # Box D - should be kept

    # Scenario 3: Same location, different classes (IoU = 1)
    # Box E: center=(100, 300), size=70x70, class 0
    # Box F: center=(100, 300), size=70x70, class 1
    # Expected: Both kept, because boxes are only suppressed by boxes of the same class
    synthetic_boxes.append([100, 300, 70, 70, 0.9, 0.7, 0.1, 0.1])  # Box E - class 0, should be kept
    synthetic_boxes.append([100, 300, 70, 70, 0.9, 0.1, 0.7, 0.1])  # Box F - class 1, should be kept

    # Scenario 4: Three cascading overlapping boxes, each offset 10 px from the previous
    # Box G: center=(500, 300), size=80x80, highest confidence (0.95 * 0.9 = 0.855)
    # Box H: center=(510, 310), size=80x80, medium confidence (0.9 * 0.7 = 0.63)
    # Box I: center=(520, 320), size=80x80, lowest confidence (0.85 * 0.6 = 0.51)
    # IoU(G, H) = IoU(H, I) = 4900 / 7900 ~= 0.62, but IoU(G, I) = 3600 / 9200 ~= 0.39
    # Expected: G kept and H suppressed by G. I is below the 0.5 threshold relative to G,
    # so greedy NMS keeps it as well (H, already suppressed, cannot remove it).
    synthetic_boxes.append([500, 300, 80, 80, 0.95, 0.9, 0.05, 0.05])  # Box G - highest conf, should be kept
    synthetic_boxes.append([510, 310, 80, 80, 0.9, 0.7, 0.1, 0.1])  # Box H - should be suppressed
    synthetic_boxes.append([520, 320, 80, 80, 0.85, 0.6, 0.15, 0.15])  # Box I - IoU with G below threshold

    # Add a low-confidence box (0.1 * 0.05 = 0.005) that should be filtered out before NMS
    synthetic_boxes.append([200, 500, 50, 50, 0.1, 0.05, 0.02, 0.03])  # Too low confidence

    # Convert to the tensor format expected by the NMS function. Pad to at least 20
    # anchors, but grow rather than silently dropping boxes if more are added above.
    num_classes = 3
    num_anchors = max(20, len(synthetic_boxes))

    predictions = torch.zeros(1, num_anchors, num_classes + 5)  # batch_size=1
    predictions[0, :len(synthetic_boxes), :] = torch.tensor(synthetic_boxes, dtype=torch.float)

    return predictions
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
#%% Main test function
|
|
135
|
+
|
|
136
|
+
def _detections_in_region(detections, x_range, y_range, class_id=None):
    """
    Collect the detections whose box center lies inside a rectangular region.

    Args:
        detections: iterable of NMS output rows, each [x1, y1, x2, y2, confidence, class_id]
        x_range: (min, max) inclusive bounds for the box center's x coordinate
        y_range: (min, max) inclusive bounds for the box center's y coordinate
        class_id: if not None, only detections whose class equals this value are returned

    Returns:
        list of (detection_index, center_x, center_y, confidence, class_id) tuples,
        in detection order
    """

    matches = []
    for i, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        in_region = (x_range[0] <= center_x <= x_range[1] and
                     y_range[0] <= center_y <= y_range[1])
        if in_region and (class_id is None or int(cls) == class_id):
            matches.append((i, center_x, center_y, conf, int(cls)))
    return matches


def test_nms_functionality():
    """
    Test the NMS function with synthetic data to verify correct behavior.

    Runs nms() on the predictions from create_synthetic_predictions(), locates the
    surviving detections for each scenario by their box centers, and checks that
    overlapping same-class boxes were suppressed while non-overlapping and
    different-class boxes were kept. Finishes with a focused two-identical-boxes
    case. Raises AssertionError on the first failed check.
    """

    conf_thres = 0.3
    iou_thres = 0.5

    print("Testing NMS functionality with synthetic data...")

    # Generate synthetic predictions
    predictions = create_synthetic_predictions()
    print(f"Created synthetic predictions with shape: {predictions.shape}")

    # Run NMS with IoU threshold = 0.5 and confidence threshold = 0.3
    results = nms(predictions, conf_thres=conf_thres, iou_thres=iou_thres, max_det=300)

    print(f"NMS returned {len(results)} batch results")
    detections = results[0]  # Get results for first (and only) image
    print(f"Number of detections after NMS: {detections.shape[0]}")

    # Explicit raise (like every other check here) so it is not stripped under python -O
    if detections.shape[0] == 0:
        print("Error: NMS returned no detections")
        raise AssertionError()

    print("\nDetections after NMS:")
    print("Format: [x1, y1, x2, y2, confidence, class_id]")
    for i, det in enumerate(detections):
        x1, y1, x2, y2, conf, cls = det
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        width = x2 - x1
        height = y2 - y1
        print(f"Detection {i}: center=({center_x:.1f}, {center_y:.1f}), "
              f"size={width:.1f}x{height:.1f}, conf={conf:.3f}, class={int(cls)}")

    # Scenarios 1 & 1b: each highly-overlapping same-class pair should collapse to
    # exactly one detection, and it should be the higher-confidence box.
    scenario1_boxes = _detections_in_region(detections, (80, 130), (80, 130), class_id=0)
    scenario1b_boxes = _detections_in_region(detections, (180, 220), (80, 120), class_id=0)

    # Check each count separately; checking only the total would accept e.g. both
    # scenario 1 boxes surviving while scenario 1b lost both of its boxes.
    if len(scenario1_boxes) != 1 or len(scenario1b_boxes) != 1:
        print("Error: expected 2 detections in high-overlap scenarios (1 each), got {}".format(
            len(scenario1_boxes) + len(scenario1b_boxes)
        ))
        print(f" Scenario 1: {len(scenario1_boxes)} boxes")
        print(f" Scenario 1b: {len(scenario1b_boxes)} boxes")
        raise AssertionError()

    # Should be the high-confidence box (0.8 * 0.9 = 0.72)
    if scenario1_boxes[0][3] < 0.7:
        print("Error: wrong box kept in scenario 1. Expected conf > 0.7, got {}".format(
            scenario1_boxes[0][3]
        ))
        raise AssertionError()

    # Should be the high-confidence box (0.9 * 0.9 = 0.81)
    if scenario1b_boxes[0][3] < 0.8:
        print("Error: wrong box kept in scenario 1b. Expected conf > 0.8, got {}".format(
            scenario1b_boxes[0][3]
        ))
        raise AssertionError()

    print("Scenarios 1 & 1b passed: High-confidence boxes kept, low-confidence overlapping boxes suppressed")

    # Validate the test setup: the pairs that were suppressed really exceed the IoU threshold
    # Scenario 1: Box A (100,100,80x80) vs Box B (105,105,80x80)
    box_a = [100-40, 100-40, 100+40, 100+40]  # Convert center+size to corners
    box_b = [105-40, 105-40, 105+40, 105+40]
    iou_1 = calculate_iou_boxes(box_a, box_b)

    # Scenario 1b: Box A2 (200,100,60x60) vs Box B2 (202,102,60x60)
    box_a2 = [200-30, 100-30, 200+30, 100+30]
    box_b2 = [202-30, 102-30, 202+30, 102+30]
    iou_1b = calculate_iou_boxes(box_a2, box_b2)

    print(f" Theoretical IoU for scenario 1 boxes: {iou_1:.3f}")
    print(f" Theoretical IoU for scenario 1b boxes: {iou_1b:.3f}")

    if iou_1 <= iou_thres:
        print(f"Error: scenario 1 IoU {iou_1:.3f} is too low - test setup is invalid!")
        raise AssertionError()
    if iou_1b <= iou_thres:
        print(f"Error: scenario 1b IoU {iou_1b:.3f} is too low - test setup is invalid!")
        raise AssertionError()
    print(" High IoU confirmed - suppression was correct")

    # Verify scenario 2 - both non-overlapping boxes should be kept
    scenario2_boxes = _detections_in_region(detections, (270, 410), (70, 130), class_id=0)
    if len(scenario2_boxes) != 2:
        print(f"Error: expected 2 detections in scenario 2 area, got {len(scenario2_boxes)}")
        raise AssertionError()
    print("Scenario 2 passed: Both non-overlapping boxes kept")

    # Verify scenario 3 - different classes should both be kept (no class filter here)
    scenario3_boxes = _detections_in_region(detections, (65, 135), (265, 335))
    classes_found = set(box[4] for box in scenario3_boxes)
    if (len(scenario3_boxes) != 2) or (len(classes_found) != 2):
        print("Error: expected 2 detections of different classes , got {} detections of classes {}".format(
            len(scenario3_boxes), classes_found
        ))
        raise AssertionError()
    print("Scenario 3 passed: Both different-class boxes kept")

    # Verify scenario 4 - cascading overlapping boxes; any survivors must not overlap
    # each other above the IoU threshold, and the highest-confidence box must survive
    scenario4_boxes = _detections_in_region(detections, (460, 560), (260, 360), class_id=0)

    print(f"\nScenario 4 analysis: Found {len(scenario4_boxes)} boxes in cascading area:")
    for i, (_, cx, cy, conf, _) in enumerate(scenario4_boxes):
        print(f" Box {i}: center=({cx:.1f}, {cy:.1f}), conf={conf:.3f}")

    # Without this guard, an empty cluster would fall through to the IoU comparison
    # below and fail with an unrelated error instead of a test failure
    if not scenario4_boxes:
        print("Error: expected at least 1 detection in scenario 4 area, got 0")
        raise AssertionError()

    # Box G (0.95 * 0.9 = 0.855) has the highest score in this cluster and must be kept
    best_conf = max(float(box[3]) for box in scenario4_boxes)
    if best_conf < 0.85:
        print(f"Error: highest-confidence scenario 4 box was not kept (best conf {best_conf:.3f})")
        raise AssertionError()

    # Check IoU between remaining boxes to ensure proper suppression
    max_iou = 0.0
    for i, box_i in enumerate(scenario4_boxes):
        for j, box_j in enumerate(scenario4_boxes[i + 1:], start=i + 1):
            det1 = detections[box_i[0]]
            det2 = detections[box_j[0]]
            iou = calculate_iou_boxes(det1[:4], det2[:4])
            print(f" IoU between box {i} and box {j}: {iou:.3f}")
            max_iou = max(max_iou, iou)

    if len(scenario4_boxes) == 1:
        print("Scenario 4 passed: Only highest confidence box kept")
    elif max_iou < iou_thres:
        # Multiple survivors are only OK if they don't overlap above the threshold
        print("Scenario 4 passed: Multiple boxes kept due to low IoU (< 0.5)")
    else:
        print(f"ERROR: Scenario 4 failed - boxes with IoU {max_iou:.3f} > 0.5 were not suppressed!")
        raise AssertionError()

    # Focused case: two boxes with identical geometry must collapse to one
    print("\n=== COMPREHENSIVE IoU VALIDATION TEST ===")

    # Create two identical boxes that should definitely be suppressed
    identical_box_a = [100, 100, 50, 50, 0.9, 0.9, 0.05, 0.05]  # High confidence
    identical_box_b = [100, 100, 50, 50, 0.9, 0.7, 0.1, 0.1]  # Lower confidence

    test_predictions = torch.zeros(1, 5, 8)  # Small batch for focused test
    test_predictions[0, 0, :] = torch.tensor(identical_box_a)
    test_predictions[0, 1, :] = torch.tensor(identical_box_b)

    # Run NMS on this simple case
    test_results = nms(test_predictions, conf_thres=conf_thres, iou_thres=iou_thres, max_det=300)
    test_detections = test_results[0]

    print(f"Identical boxes test: Input 2 identical boxes, got {test_detections.shape[0]} detections")

    if test_detections.shape[0] != 1:
        print(f"Error Two identical boxes should result in 1 detection, got {test_detections.shape[0]}")
        raise AssertionError()

    # Verify it kept the higher confidence box
    kept_conf = test_detections[0, 4].item()
    expected_conf = 0.9 * 0.9  # objectness * class_conf
    if abs(kept_conf - expected_conf) > 0.01:
        print(f"ERROR: Wrong box kept. Expected conf ≈ {expected_conf:.3f}, got {kept_conf:.3f}")
        raise AssertionError()
    print("Identical boxes test passed: Higher confidence box kept")

    print("\nNMS tests passed")
|
|
File without changes
|