megadetector 5.0.28__py3-none-any.whl → 5.0.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic. Click here for more details.
- megadetector/api/batch_processing/api_core/batch_service/score.py +4 -5
- megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +1 -1
- megadetector/api/batch_processing/api_support/summarize_daily_activity.py +1 -1
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +2 -2
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +1 -1
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +1 -1
- megadetector/api/synchronous/api_core/tests/load_test.py +2 -3
- megadetector/classification/aggregate_classifier_probs.py +3 -3
- megadetector/classification/analyze_failed_images.py +5 -5
- megadetector/classification/cache_batchapi_outputs.py +5 -5
- megadetector/classification/create_classification_dataset.py +11 -12
- megadetector/classification/crop_detections.py +10 -10
- megadetector/classification/csv_to_json.py +8 -8
- megadetector/classification/detect_and_crop.py +13 -15
- megadetector/classification/evaluate_model.py +7 -7
- megadetector/classification/identify_mislabeled_candidates.py +6 -6
- megadetector/classification/json_to_azcopy_list.py +1 -1
- megadetector/classification/json_validator.py +29 -32
- megadetector/classification/map_classification_categories.py +9 -9
- megadetector/classification/merge_classification_detection_output.py +12 -9
- megadetector/classification/prepare_classification_script.py +19 -19
- megadetector/classification/prepare_classification_script_mc.py +23 -23
- megadetector/classification/run_classifier.py +4 -4
- megadetector/classification/save_mislabeled.py +6 -6
- megadetector/classification/train_classifier.py +1 -1
- megadetector/classification/train_classifier_tf.py +9 -9
- megadetector/classification/train_utils.py +10 -10
- megadetector/data_management/annotations/annotation_constants.py +1 -1
- megadetector/data_management/camtrap_dp_to_coco.py +45 -45
- megadetector/data_management/cct_json_utils.py +101 -101
- megadetector/data_management/cct_to_md.py +49 -49
- megadetector/data_management/cct_to_wi.py +33 -33
- megadetector/data_management/coco_to_labelme.py +75 -75
- megadetector/data_management/coco_to_yolo.py +189 -189
- megadetector/data_management/databases/add_width_and_height_to_db.py +3 -2
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +38 -38
- megadetector/data_management/databases/integrity_check_json_db.py +202 -188
- megadetector/data_management/databases/subset_json_db.py +33 -33
- megadetector/data_management/generate_crops_from_cct.py +38 -38
- megadetector/data_management/get_image_sizes.py +54 -49
- megadetector/data_management/labelme_to_coco.py +130 -124
- megadetector/data_management/labelme_to_yolo.py +78 -72
- megadetector/data_management/lila/create_lila_blank_set.py +81 -83
- megadetector/data_management/lila/create_lila_test_set.py +32 -31
- megadetector/data_management/lila/create_links_to_md_results_files.py +18 -18
- megadetector/data_management/lila/download_lila_subset.py +21 -24
- megadetector/data_management/lila/generate_lila_per_image_labels.py +91 -91
- megadetector/data_management/lila/get_lila_annotation_counts.py +30 -30
- megadetector/data_management/lila/get_lila_image_counts.py +22 -22
- megadetector/data_management/lila/lila_common.py +70 -70
- megadetector/data_management/lila/test_lila_metadata_urls.py +13 -14
- megadetector/data_management/mewc_to_md.py +339 -340
- megadetector/data_management/ocr_tools.py +258 -252
- megadetector/data_management/read_exif.py +231 -224
- megadetector/data_management/remap_coco_categories.py +26 -26
- megadetector/data_management/remove_exif.py +31 -20
- megadetector/data_management/rename_images.py +187 -187
- megadetector/data_management/resize_coco_dataset.py +41 -41
- megadetector/data_management/speciesnet_to_md.py +41 -41
- megadetector/data_management/wi_download_csv_to_coco.py +55 -55
- megadetector/data_management/yolo_output_to_md_output.py +117 -120
- megadetector/data_management/yolo_to_coco.py +195 -188
- megadetector/detection/change_detection.py +831 -0
- megadetector/detection/process_video.py +340 -337
- megadetector/detection/pytorch_detector.py +304 -262
- megadetector/detection/run_detector.py +177 -164
- megadetector/detection/run_detector_batch.py +364 -363
- megadetector/detection/run_inference_with_yolov5_val.py +328 -325
- megadetector/detection/run_tiled_inference.py +256 -249
- megadetector/detection/tf_detector.py +24 -24
- megadetector/detection/video_utils.py +290 -282
- megadetector/postprocessing/add_max_conf.py +15 -11
- megadetector/postprocessing/categorize_detections_by_size.py +44 -44
- megadetector/postprocessing/classification_postprocessing.py +415 -415
- megadetector/postprocessing/combine_batch_outputs.py +20 -21
- megadetector/postprocessing/compare_batch_results.py +528 -517
- megadetector/postprocessing/convert_output_format.py +97 -97
- megadetector/postprocessing/create_crop_folder.py +219 -146
- megadetector/postprocessing/detector_calibration.py +173 -168
- megadetector/postprocessing/generate_csv_report.py +508 -499
- megadetector/postprocessing/load_api_results.py +23 -20
- megadetector/postprocessing/md_to_coco.py +129 -98
- megadetector/postprocessing/md_to_labelme.py +89 -83
- megadetector/postprocessing/md_to_wi.py +40 -40
- megadetector/postprocessing/merge_detections.py +87 -114
- megadetector/postprocessing/postprocess_batch_results.py +313 -298
- megadetector/postprocessing/remap_detection_categories.py +36 -36
- megadetector/postprocessing/render_detection_confusion_matrix.py +205 -199
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +57 -57
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +27 -28
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +702 -677
- megadetector/postprocessing/separate_detections_into_folders.py +226 -211
- megadetector/postprocessing/subset_json_detector_output.py +265 -262
- megadetector/postprocessing/top_folders_to_bottom.py +45 -45
- megadetector/postprocessing/validate_batch_results.py +70 -70
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +52 -52
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +15 -15
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +14 -14
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +66 -66
- megadetector/taxonomy_mapping/retrieve_sample_image.py +16 -16
- megadetector/taxonomy_mapping/simple_image_download.py +8 -8
- megadetector/taxonomy_mapping/species_lookup.py +33 -33
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +14 -14
- megadetector/taxonomy_mapping/taxonomy_graph.py +10 -10
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +13 -13
- megadetector/utils/azure_utils.py +22 -22
- megadetector/utils/ct_utils.py +1018 -200
- megadetector/utils/directory_listing.py +21 -77
- megadetector/utils/gpu_test.py +22 -22
- megadetector/utils/md_tests.py +541 -518
- megadetector/utils/path_utils.py +1457 -398
- megadetector/utils/process_utils.py +41 -41
- megadetector/utils/sas_blob_utils.py +53 -49
- megadetector/utils/split_locations_into_train_val.py +61 -61
- megadetector/utils/string_utils.py +147 -26
- megadetector/utils/url_utils.py +463 -173
- megadetector/utils/wi_utils.py +2629 -2526
- megadetector/utils/write_html_image_list.py +137 -137
- megadetector/visualization/plot_utils.py +21 -21
- megadetector/visualization/render_images_with_thumbnails.py +37 -73
- megadetector/visualization/visualization_utils.py +401 -397
- megadetector/visualization/visualize_db.py +197 -190
- megadetector/visualization/visualize_detector_output.py +79 -73
- {megadetector-5.0.28.dist-info → megadetector-5.0.29.dist-info}/METADATA +135 -132
- megadetector-5.0.29.dist-info/RECORD +163 -0
- {megadetector-5.0.28.dist-info → megadetector-5.0.29.dist-info}/WHEEL +1 -1
- {megadetector-5.0.28.dist-info → megadetector-5.0.29.dist-info}/licenses/LICENSE +0 -0
- {megadetector-5.0.28.dist-info → megadetector-5.0.29.dist-info}/top_level.txt +0 -0
- megadetector/data_management/importers/add_nacti_sizes.py +0 -52
- megadetector/data_management/importers/add_timestamps_to_icct.py +0 -79
- megadetector/data_management/importers/animl_results_to_md_results.py +0 -158
- megadetector/data_management/importers/auckland_doc_test_to_json.py +0 -373
- megadetector/data_management/importers/auckland_doc_to_json.py +0 -201
- megadetector/data_management/importers/awc_to_json.py +0 -191
- megadetector/data_management/importers/bellevue_to_json.py +0 -272
- megadetector/data_management/importers/cacophony-thermal-importer.py +0 -793
- megadetector/data_management/importers/carrizo_shrubfree_2018.py +0 -269
- megadetector/data_management/importers/carrizo_trail_cam_2017.py +0 -289
- megadetector/data_management/importers/cct_field_adjustments.py +0 -58
- megadetector/data_management/importers/channel_islands_to_cct.py +0 -913
- megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
- megadetector/data_management/importers/eMammal/eMammal_helpers.py +0 -249
- megadetector/data_management/importers/eMammal/make_eMammal_json.py +0 -223
- megadetector/data_management/importers/ena24_to_json.py +0 -276
- megadetector/data_management/importers/filenames_to_json.py +0 -386
- megadetector/data_management/importers/helena_to_cct.py +0 -283
- megadetector/data_management/importers/idaho-camera-traps.py +0 -1407
- megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
- megadetector/data_management/importers/import_desert_lion_conservation_camera_traps.py +0 -387
- megadetector/data_management/importers/jb_csv_to_json.py +0 -150
- megadetector/data_management/importers/mcgill_to_json.py +0 -250
- megadetector/data_management/importers/missouri_to_json.py +0 -490
- megadetector/data_management/importers/nacti_fieldname_adjustments.py +0 -79
- megadetector/data_management/importers/noaa_seals_2019.py +0 -181
- megadetector/data_management/importers/osu-small-animals-to-json.py +0 -364
- megadetector/data_management/importers/pc_to_json.py +0 -365
- megadetector/data_management/importers/plot_wni_giraffes.py +0 -123
- megadetector/data_management/importers/prepare_zsl_imerit.py +0 -131
- megadetector/data_management/importers/raic_csv_to_md_results.py +0 -416
- megadetector/data_management/importers/rspb_to_json.py +0 -356
- megadetector/data_management/importers/save_the_elephants_survey_A.py +0 -320
- megadetector/data_management/importers/save_the_elephants_survey_B.py +0 -329
- megadetector/data_management/importers/snapshot_safari_importer.py +0 -758
- megadetector/data_management/importers/snapshot_serengeti_lila.py +0 -1067
- megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
- megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
- megadetector/data_management/importers/sulross_get_exif.py +0 -65
- megadetector/data_management/importers/timelapse_csv_set_to_json.py +0 -490
- megadetector/data_management/importers/ubc_to_json.py +0 -399
- megadetector/data_management/importers/umn_to_json.py +0 -507
- megadetector/data_management/importers/wellington_to_json.py +0 -263
- megadetector/data_management/importers/wi_to_json.py +0 -442
- megadetector/data_management/importers/zamba_results_to_md_results.py +0 -180
- megadetector/data_management/lila/add_locations_to_island_camera_traps.py +0 -101
- megadetector/data_management/lila/add_locations_to_nacti.py +0 -151
- megadetector-5.0.28.dist-info/RECORD +0 -209
|
@@ -17,6 +17,7 @@ import shutil
|
|
|
17
17
|
import traceback
|
|
18
18
|
import uuid
|
|
19
19
|
import json
|
|
20
|
+
import inspect
|
|
20
21
|
|
|
21
22
|
import cv2
|
|
22
23
|
import torch
|
|
@@ -27,6 +28,7 @@ from megadetector.detection.run_detector import \
|
|
|
27
28
|
get_detector_version_from_model_file, \
|
|
28
29
|
known_models
|
|
29
30
|
from megadetector.utils.ct_utils import parse_bool_string
|
|
31
|
+
from megadetector.utils.ct_utils import to_bool
|
|
30
32
|
from megadetector.utils import ct_utils
|
|
31
33
|
|
|
32
34
|
# We support a few ways of accessing the YOLOv5 dependencies:
|
|
@@ -54,7 +56,7 @@ def _get_model_type_for_model(model_file,
|
|
|
54
56
|
verbose=False):
|
|
55
57
|
"""
|
|
56
58
|
Determine the model type (i.e., the inference library we need to use) for a .pt file.
|
|
57
|
-
|
|
59
|
+
|
|
58
60
|
Args:
|
|
59
61
|
model_file (str): the model file to read
|
|
60
62
|
prefer_model_type_source (str, optional): how should we handle the (very unlikely)
|
|
@@ -64,28 +66,28 @@ def _get_model_type_for_model(model_file,
|
|
|
64
66
|
default_model_type (str, optional): return value for the case where we can't find
|
|
65
67
|
appropriate metadata in the file or in the global table.
|
|
66
68
|
verbose (bool, optional): enable additional debug output
|
|
67
|
-
|
|
69
|
+
|
|
68
70
|
Returns:
|
|
69
71
|
str: the model type indicated for this model
|
|
70
72
|
"""
|
|
71
|
-
|
|
73
|
+
|
|
72
74
|
model_info = read_metadata_from_megadetector_model_file(model_file)
|
|
73
|
-
|
|
75
|
+
|
|
74
76
|
# Check whether the model file itself specified a model type
|
|
75
77
|
model_type_from_model_file_metadata = None
|
|
76
|
-
|
|
77
|
-
if model_info is not None and 'model_type' in model_info:
|
|
78
|
+
|
|
79
|
+
if model_info is not None and 'model_type' in model_info:
|
|
78
80
|
model_type_from_model_file_metadata = model_info['model_type']
|
|
79
81
|
if verbose:
|
|
80
82
|
print('Parsed model type {} from model {}'.format(
|
|
81
83
|
model_type_from_model_file_metadata,
|
|
82
84
|
model_file))
|
|
83
|
-
|
|
85
|
+
|
|
84
86
|
model_type_from_model_version = None
|
|
85
|
-
|
|
87
|
+
|
|
86
88
|
# Check whether this is a known model version with a specific model type
|
|
87
89
|
model_version_from_file = get_detector_version_from_model_file(model_file)
|
|
88
|
-
|
|
90
|
+
|
|
89
91
|
if model_version_from_file is not None and model_version_from_file in known_models:
|
|
90
92
|
model_info = known_models[model_version_from_file]
|
|
91
93
|
if 'model_type' in model_info:
|
|
@@ -93,15 +95,15 @@ def _get_model_type_for_model(model_file,
|
|
|
93
95
|
if verbose:
|
|
94
96
|
print('Parsed model type {} from global metadata'.format(model_type_from_model_version))
|
|
95
97
|
else:
|
|
96
|
-
model_type_from_model_version = None
|
|
97
|
-
|
|
98
|
+
model_type_from_model_version = None
|
|
99
|
+
|
|
98
100
|
if model_type_from_model_file_metadata is None and \
|
|
99
101
|
model_type_from_model_version is None:
|
|
100
102
|
if verbose:
|
|
101
103
|
print('Could not determine model type for {}, assuming {}'.format(
|
|
102
104
|
model_file,default_model_type))
|
|
103
105
|
model_type = default_model_type
|
|
104
|
-
|
|
106
|
+
|
|
105
107
|
elif model_type_from_model_file_metadata is not None and \
|
|
106
108
|
model_type_from_model_version is not None:
|
|
107
109
|
if model_type_from_model_version == model_type_from_model_file_metadata:
|
|
@@ -113,15 +115,15 @@ def _get_model_type_for_model(model_file,
|
|
|
113
115
|
model_type = model_type_from_model_file_metadata
|
|
114
116
|
else:
|
|
115
117
|
model_type = model_type_from_model_version
|
|
116
|
-
|
|
118
|
+
|
|
117
119
|
elif model_type_from_model_file_metadata is not None:
|
|
118
|
-
|
|
120
|
+
|
|
119
121
|
model_type = model_type_from_model_file_metadata
|
|
120
|
-
|
|
122
|
+
|
|
121
123
|
elif model_type_from_model_version is not None:
|
|
122
|
-
|
|
124
|
+
|
|
123
125
|
model_type = model_type_from_model_version
|
|
124
|
-
|
|
126
|
+
|
|
125
127
|
return model_type
|
|
126
128
|
|
|
127
129
|
# ...def _get_model_type_for_model(...)
|
|
@@ -134,7 +136,7 @@ def _initialize_yolo_imports_for_model(model_file,
|
|
|
134
136
|
verbose=False):
|
|
135
137
|
"""
|
|
136
138
|
Initialize the appropriate YOLO imports for a model file.
|
|
137
|
-
|
|
139
|
+
|
|
138
140
|
Args:
|
|
139
141
|
model_file (str): The model file for which we're loading support
|
|
140
142
|
prefer_model_type_source (str, optional): how should we handle the (very unlikely)
|
|
@@ -142,18 +144,17 @@ def _initialize_yolo_imports_for_model(model_file,
|
|
|
142
144
|
type table says something else. Should be "table" (trust the table) or "file"
|
|
143
145
|
(trust the file).
|
|
144
146
|
default_model_type (str, optional): return value for the case where we can't find
|
|
145
|
-
appropriate metadata in the file or in the global table.
|
|
147
|
+
appropriate metadata in the file or in the global table.
|
|
146
148
|
detector_options (dict, optional): dictionary of detector options that mean
|
|
147
149
|
different things to different models
|
|
148
150
|
verbose (bool, optional): enable additional debug output
|
|
149
|
-
|
|
151
|
+
|
|
150
152
|
Returns:
|
|
151
153
|
str: the model type for which we initialized support
|
|
152
154
|
"""
|
|
153
|
-
|
|
154
|
-
|
|
155
|
+
|
|
155
156
|
global yolo_model_type_imported
|
|
156
|
-
|
|
157
|
+
|
|
157
158
|
if detector_options is not None and 'model_type' in detector_options:
|
|
158
159
|
model_type = detector_options['model_type']
|
|
159
160
|
print('Model type {} provided in detector options'.format(model_type))
|
|
@@ -161,7 +162,7 @@ def _initialize_yolo_imports_for_model(model_file,
|
|
|
161
162
|
model_type = _get_model_type_for_model(model_file,
|
|
162
163
|
prefer_model_type_source=prefer_model_type_source,
|
|
163
164
|
default_model_type=default_model_type)
|
|
164
|
-
|
|
165
|
+
|
|
165
166
|
if yolo_model_type_imported is not None:
|
|
166
167
|
if model_type == yolo_model_type_imported:
|
|
167
168
|
print('Bypassing imports for model type {}'.format(model_type))
|
|
@@ -169,27 +170,27 @@ def _initialize_yolo_imports_for_model(model_file,
|
|
|
169
170
|
else:
|
|
170
171
|
print('Previously set up imports for model type {}, re-importing as {}'.format(
|
|
171
172
|
yolo_model_type_imported,model_type))
|
|
172
|
-
|
|
173
|
+
|
|
173
174
|
_initialize_yolo_imports(model_type,verbose=verbose)
|
|
174
|
-
|
|
175
|
+
|
|
175
176
|
return model_type
|
|
176
177
|
|
|
177
178
|
|
|
178
179
|
def _clean_yolo_imports(verbose=False):
|
|
179
180
|
"""
|
|
180
181
|
Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
|
|
181
|
-
of another YOLO library version. The reason we jump through all these hoops, rather than
|
|
182
|
+
of another YOLO library version. The reason we jump through all these hoops, rather than
|
|
182
183
|
just, e.g., handling different libraries in different modules, is that we need to make sure
|
|
183
184
|
*pickle* sees the right version of modules during module loading, including modules we don't
|
|
184
|
-
load directly (i.e., every module loaded within a YOLO library), and the only way I know to
|
|
185
|
+
load directly (i.e., every module loaded within a YOLO library), and the only way I know to
|
|
185
186
|
do that is to remove all the "wrong" versions from sys.modules and sys.path.
|
|
186
|
-
|
|
187
|
+
|
|
187
188
|
Args:
|
|
188
189
|
verbose (bool, optional): enable additional debug output
|
|
189
190
|
"""
|
|
190
|
-
|
|
191
|
+
|
|
191
192
|
modules_to_delete = []
|
|
192
|
-
for module_name in sys.modules.keys():
|
|
193
|
+
for module_name in sys.modules.keys():
|
|
193
194
|
module = sys.modules[module_name]
|
|
194
195
|
try:
|
|
195
196
|
module_file = module.__file__.replace('\\','/')
|
|
@@ -199,24 +200,24 @@ def _clean_yolo_imports(verbose=False):
|
|
|
199
200
|
for token in tokens:
|
|
200
201
|
if 'yolov5' in token or 'yolov9' in token or 'ultralytics' in token:
|
|
201
202
|
modules_to_delete.append(module_name)
|
|
202
|
-
break
|
|
203
|
+
break
|
|
203
204
|
except Exception:
|
|
204
205
|
pass
|
|
205
|
-
|
|
206
|
+
|
|
206
207
|
for module_name in modules_to_delete:
|
|
207
208
|
if module_name in sys.modules.keys():
|
|
208
209
|
module_file = module.__file__.replace('\\','/')
|
|
209
210
|
if verbose:
|
|
210
211
|
print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
|
|
211
212
|
del sys.modules[module_name]
|
|
212
|
-
|
|
213
|
+
|
|
213
214
|
paths_to_delete = []
|
|
214
|
-
|
|
215
|
+
|
|
215
216
|
for p in sys.path:
|
|
216
217
|
if p.endswith('yolov5') or p.endswith('yolov9') or p.endswith('ultralytics'):
|
|
217
218
|
print('clean_yolo_imports: removing {} from path'.format(p))
|
|
218
219
|
paths_to_delete.append(p)
|
|
219
|
-
|
|
220
|
+
|
|
220
221
|
for p in paths_to_delete:
|
|
221
222
|
sys.path.remove(p)
|
|
222
223
|
|
|
@@ -226,54 +227,69 @@ def _clean_yolo_imports(verbose=False):
|
|
|
226
227
|
def _initialize_yolo_imports(model_type='yolov5',
|
|
227
228
|
allow_fallback_import=True,
|
|
228
229
|
force_reimport=False,
|
|
229
|
-
verbose=
|
|
230
|
+
verbose=True):
|
|
230
231
|
"""
|
|
231
|
-
Imports required functions from one or more yolo libraries (yolov5, yolov9,
|
|
232
|
+
Imports required functions from one or more yolo libraries (yolov5, yolov9,
|
|
232
233
|
ultralytics, targeting support for [model_type]).
|
|
233
|
-
|
|
234
|
+
|
|
234
235
|
Args:
|
|
235
236
|
model_type (str): The model type for which we're loading support
|
|
236
|
-
allow_fallback_import (bool, optional): If we can't import from the package for
|
|
237
|
-
which we're trying to load support, fall back to "import utils". This is
|
|
237
|
+
allow_fallback_import (bool, optional): If we can't import from the package for
|
|
238
|
+
which we're trying to load support, fall back to "import utils". This is
|
|
238
239
|
typically used when the right support library is on the current PYTHONPATH.
|
|
239
|
-
force_reimport (bool, optional): import the appropriate libraries even if the
|
|
240
|
+
force_reimport (bool, optional): import the appropriate libraries even if the
|
|
240
241
|
requested model type matches the current initialization state
|
|
241
242
|
verbose (bool, optional): include additional debug output
|
|
242
|
-
|
|
243
|
+
|
|
243
244
|
Returns:
|
|
244
245
|
str: the model type for which we initialized support
|
|
245
246
|
"""
|
|
246
|
-
|
|
247
|
+
|
|
248
|
+
# When running in pytest, the megadetector 'utils' module is put in the global
|
|
249
|
+
# namespace, which creates conflicts with yolov5; remove it from the global
|
|
250
|
+
# namespsace.
|
|
251
|
+
if ('PYTEST_CURRENT_TEST' in os.environ):
|
|
252
|
+
print('*** pytest detected ***')
|
|
253
|
+
if ('utils' in sys.modules):
|
|
254
|
+
utils_module = sys.modules['utils']
|
|
255
|
+
if hasattr(utils_module, '__file__') and 'megadetector' in str(utils_module.__file__):
|
|
256
|
+
print(f"Removing conflicting utils module: {utils_module.__file__}")
|
|
257
|
+
sys.modules.pop('utils', None)
|
|
258
|
+
# Also remove any submodules
|
|
259
|
+
to_remove = [name for name in sys.modules if name.startswith('utils.')]
|
|
260
|
+
for name in to_remove:
|
|
261
|
+
sys.modules.pop(name, None)
|
|
262
|
+
|
|
247
263
|
global yolo_model_type_imported
|
|
248
|
-
|
|
264
|
+
|
|
249
265
|
if model_type is None:
|
|
250
266
|
model_type = 'yolov5'
|
|
251
|
-
|
|
267
|
+
|
|
252
268
|
# The point of this function is to make the appropriate version
|
|
253
269
|
# of the following functions available at module scope
|
|
254
270
|
global non_max_suppression
|
|
255
271
|
global xyxy2xywh
|
|
256
272
|
global letterbox
|
|
257
273
|
global scale_coords
|
|
258
|
-
|
|
274
|
+
|
|
259
275
|
if yolo_model_type_imported is not None:
|
|
260
|
-
if yolo_model_type_imported == model_type:
|
|
276
|
+
if (yolo_model_type_imported == model_type) and (not force_reimport):
|
|
261
277
|
print('Bypassing imports for YOLO model type {}'.format(model_type))
|
|
262
278
|
return
|
|
263
279
|
else:
|
|
264
280
|
_clean_yolo_imports()
|
|
265
|
-
|
|
281
|
+
|
|
266
282
|
try_yolov5_import = (model_type == 'yolov5')
|
|
267
283
|
try_yolov9_import = (model_type == 'yolov9')
|
|
268
284
|
try_ultralytics_import = (model_type == 'ultralytics')
|
|
269
|
-
|
|
285
|
+
|
|
270
286
|
utils_imported = False
|
|
271
|
-
|
|
287
|
+
|
|
272
288
|
# First try importing from the yolov5 package; this is how the pip
|
|
273
289
|
# package finds YOLOv5 utilities.
|
|
274
290
|
if try_yolov5_import and not utils_imported:
|
|
275
|
-
|
|
276
|
-
try:
|
|
291
|
+
|
|
292
|
+
try:
|
|
277
293
|
from yolov5.utils.general import non_max_suppression, xyxy2xywh # noqa
|
|
278
294
|
from yolov5.utils.augmentations import letterbox # noqa
|
|
279
295
|
try:
|
|
@@ -283,109 +299,127 @@ def _initialize_yolo_imports(model_type='yolov5',
|
|
|
283
299
|
utils_imported = True
|
|
284
300
|
if verbose:
|
|
285
301
|
print('Imported utils from YOLOv5 package')
|
|
286
|
-
|
|
287
|
-
except Exception as e: # noqa
|
|
288
|
-
|
|
302
|
+
|
|
303
|
+
except Exception as e: # noqa
|
|
289
304
|
# print('yolov5 module import failed: {}'.format(e))
|
|
290
|
-
# print(traceback.format_exc())
|
|
305
|
+
# print(traceback.format_exc())
|
|
291
306
|
pass
|
|
292
|
-
|
|
307
|
+
|
|
293
308
|
# Next try importing from the yolov9 package
|
|
294
309
|
if try_yolov9_import and not utils_imported:
|
|
295
|
-
|
|
310
|
+
|
|
296
311
|
try:
|
|
297
|
-
|
|
312
|
+
|
|
298
313
|
from yolov9.utils.general import non_max_suppression, xyxy2xywh # noqa
|
|
299
314
|
from yolov9.utils.augmentations import letterbox # noqa
|
|
300
315
|
from yolov9.utils.general import scale_boxes as scale_coords # noqa
|
|
301
316
|
utils_imported = True
|
|
302
317
|
if verbose:
|
|
303
318
|
print('Imported utils from YOLOv9 package')
|
|
304
|
-
|
|
319
|
+
|
|
305
320
|
except Exception as e: # noqa
|
|
306
|
-
|
|
321
|
+
|
|
307
322
|
# print('yolov9 module import failed: {}'.format(e))
|
|
308
323
|
# print(traceback.format_exc())
|
|
309
324
|
pass
|
|
310
|
-
|
|
311
|
-
# If we haven't succeeded yet, import from the ultralytics package
|
|
325
|
+
|
|
326
|
+
# If we haven't succeeded yet, import from the ultralytics package
|
|
312
327
|
if try_ultralytics_import and not utils_imported:
|
|
313
|
-
|
|
328
|
+
|
|
314
329
|
try:
|
|
315
|
-
|
|
330
|
+
|
|
316
331
|
import ultralytics # noqa
|
|
317
|
-
|
|
332
|
+
|
|
318
333
|
except Exception:
|
|
319
|
-
|
|
334
|
+
|
|
320
335
|
print('It looks like you are trying to run a model that requires the ultralytics package, '
|
|
321
336
|
'but the ultralytics package is not installed, but . For licensing reasons, this '
|
|
322
337
|
'is not installed by default with the MegaDetector Python package. Run '
|
|
323
338
|
'"pip install ultralytics" to install it, and try again.')
|
|
324
339
|
raise
|
|
325
|
-
|
|
340
|
+
|
|
326
341
|
try:
|
|
327
|
-
|
|
342
|
+
|
|
328
343
|
from ultralytics.utils.ops import non_max_suppression # noqa
|
|
329
344
|
from ultralytics.utils.ops import xyxy2xywh # noqa
|
|
330
|
-
|
|
345
|
+
|
|
331
346
|
# In the ultralytics package, scale_boxes and scale_coords both exist;
|
|
332
347
|
# we want scale_boxes.
|
|
333
|
-
#
|
|
348
|
+
#
|
|
334
349
|
# from ultralytics.utils.ops import scale_coords # noqa
|
|
335
350
|
from ultralytics.utils.ops import scale_boxes as scale_coords # noqa
|
|
336
351
|
from ultralytics.data.augment import LetterBox
|
|
337
|
-
|
|
338
|
-
# letterbox() became a LetterBox class in the ultralytics package. Create a
|
|
352
|
+
|
|
353
|
+
# letterbox() became a LetterBox class in the ultralytics package. Create a
|
|
339
354
|
# backwards-compatible letterbox function wrapper that wraps the class up.
|
|
340
|
-
def letterbox(img,new_shape,auto=False,scaleFill=False,
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
355
|
+
def letterbox(img,new_shape,auto=False,scaleFill=False, #noqa
|
|
356
|
+
scaleup=True,center=True,stride=32):
|
|
357
|
+
|
|
358
|
+
# Ultralytics changed the "scaleFill" parameter to "scale_fill", we want to support
|
|
359
|
+
# both conventions.
|
|
360
|
+
use_old_scalefill_arg = False
|
|
361
|
+
try:
|
|
362
|
+
sig = inspect.signature(LetterBox.__init__)
|
|
363
|
+
if 'scaleFill' in sig.parameters:
|
|
364
|
+
use_old_scalefill_arg = True
|
|
365
|
+
except Exception:
|
|
366
|
+
pass
|
|
367
|
+
|
|
368
|
+
if use_old_scalefill_arg:
|
|
369
|
+
if verbose:
|
|
370
|
+
print('Using old scaleFill calling convention')
|
|
371
|
+
letterbox_transformer = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,
|
|
372
|
+
scaleup=scaleup,center=center,stride=stride)
|
|
373
|
+
else:
|
|
374
|
+
letterbox_transformer = LetterBox(new_shape,auto=auto,scale_fill=scaleFill,
|
|
375
|
+
scaleup=scaleup,center=center,stride=stride)
|
|
376
|
+
|
|
377
|
+
letterbox_result = letterbox_transformer(image=img)
|
|
378
|
+
|
|
345
379
|
if isinstance(new_shape,int):
|
|
346
380
|
new_shape = [new_shape,new_shape]
|
|
347
|
-
|
|
381
|
+
|
|
348
382
|
# The letterboxing is done, we just need to reverse-engineer what it did
|
|
349
383
|
shape = img.shape[:2]
|
|
350
|
-
|
|
384
|
+
|
|
351
385
|
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
|
352
386
|
if not scaleup:
|
|
353
387
|
r = min(r, 1.0)
|
|
354
388
|
ratio = r, r
|
|
355
|
-
|
|
389
|
+
|
|
356
390
|
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
|
357
391
|
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
|
|
358
392
|
if auto:
|
|
359
393
|
dw, dh = np.mod(dw, stride), np.mod(dh, stride)
|
|
360
|
-
elif scaleFill:
|
|
394
|
+
elif scaleFill:
|
|
361
395
|
dw, dh = 0.0, 0.0
|
|
362
396
|
new_unpad = (new_shape[1], new_shape[0])
|
|
363
397
|
ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0])
|
|
364
|
-
|
|
398
|
+
|
|
365
399
|
dw /= 2
|
|
366
400
|
dh /= 2
|
|
367
401
|
pad = (dw,dh)
|
|
368
|
-
|
|
402
|
+
|
|
369
403
|
return [letterbox_result,ratio,pad]
|
|
370
|
-
|
|
404
|
+
|
|
371
405
|
utils_imported = True
|
|
372
406
|
if verbose:
|
|
373
407
|
print('Imported utils from ultralytics package')
|
|
374
|
-
|
|
408
|
+
|
|
375
409
|
except Exception:
|
|
376
|
-
|
|
410
|
+
|
|
377
411
|
# print('Ultralytics module import failed')
|
|
378
412
|
pass
|
|
379
|
-
|
|
413
|
+
|
|
380
414
|
# If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
|
|
381
415
|
if (not utils_imported) and allow_fallback_import:
|
|
382
|
-
|
|
416
|
+
|
|
383
417
|
try:
|
|
384
|
-
|
|
418
|
+
|
|
385
419
|
# import pre- and post-processing functions from the YOLOv5 repo
|
|
386
420
|
from utils.general import non_max_suppression, xyxy2xywh # noqa
|
|
387
421
|
from utils.augmentations import letterbox # noqa
|
|
388
|
-
|
|
422
|
+
|
|
389
423
|
# scale_coords() is scale_boxes() in some YOLOv5 versions
|
|
390
424
|
try:
|
|
391
425
|
from utils.general import scale_coords # noqa
|
|
@@ -395,58 +429,58 @@ def _initialize_yolo_imports(model_type='yolov5',
|
|
|
395
429
|
imported_file = sys.modules[scale_coords.__module__].__file__
|
|
396
430
|
if verbose:
|
|
397
431
|
print('Imported utils from {}'.format(imported_file))
|
|
398
|
-
|
|
432
|
+
|
|
399
433
|
except ModuleNotFoundError as e:
|
|
400
|
-
|
|
434
|
+
|
|
401
435
|
raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e)))
|
|
402
|
-
|
|
436
|
+
|
|
403
437
|
assert utils_imported, 'YOLO utils import error'
|
|
404
|
-
|
|
438
|
+
|
|
405
439
|
yolo_model_type_imported = model_type
|
|
406
440
|
if verbose:
|
|
407
441
|
print('Prepared YOLO imports for model type {}'.format(model_type))
|
|
408
|
-
|
|
442
|
+
|
|
409
443
|
return model_type
|
|
410
444
|
|
|
411
445
|
# ...def _initialize_yolo_imports(...)
|
|
412
|
-
|
|
446
|
+
|
|
413
447
|
|
|
414
448
|
#%% Model metadata functions
|
|
415
449
|
|
|
416
|
-
def add_metadata_to_megadetector_model_file(model_file_in,
|
|
450
|
+
def add_metadata_to_megadetector_model_file(model_file_in,
|
|
417
451
|
model_file_out,
|
|
418
|
-
metadata,
|
|
452
|
+
metadata,
|
|
419
453
|
destination_path='megadetector_info.json'):
|
|
420
454
|
"""
|
|
421
|
-
Adds a .json file to the specified MegaDetector model file containing metadata used
|
|
455
|
+
Adds a .json file to the specified MegaDetector model file containing metadata used
|
|
422
456
|
by this module. Always over-writes the output file.
|
|
423
|
-
|
|
457
|
+
|
|
424
458
|
Args:
|
|
425
459
|
model_file_in (str): The input model filename, typically .pt (.zip is also sensible)
|
|
426
460
|
model_file_out (str): The output model filename, typically .pt (.zip is also sensible).
|
|
427
461
|
May be the same as model_file_in.
|
|
428
462
|
metadata (dict): The metadata dict to add to the output model file
|
|
429
|
-
destination_path (str, optional): The relative path within the main folder of the
|
|
430
|
-
model archive where we should write the metadata. This is not relative to the root
|
|
463
|
+
destination_path (str, optional): The relative path within the main folder of the
|
|
464
|
+
model archive where we should write the metadata. This is not relative to the root
|
|
431
465
|
of the archive, it's relative to the one and only folder at the root of the archive
|
|
432
|
-
(this is a PyTorch convention).
|
|
466
|
+
(this is a PyTorch convention).
|
|
433
467
|
"""
|
|
434
|
-
|
|
468
|
+
|
|
435
469
|
tmp_base = os.path.join(tempfile.gettempdir(),'md_metadata')
|
|
436
470
|
os.makedirs(tmp_base,exist_ok=True)
|
|
437
471
|
metadata_tmp_file_relative = 'megadetector_info_' + str(uuid.uuid1()) + '.json'
|
|
438
|
-
metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)
|
|
472
|
+
metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)
|
|
439
473
|
|
|
440
474
|
with open(metadata_tmp_file_abs,'w') as f:
|
|
441
475
|
json.dump(metadata,f,indent=1)
|
|
442
|
-
|
|
476
|
+
|
|
443
477
|
# Copy the input file to the output file
|
|
444
478
|
shutil.copyfile(model_file_in,model_file_out)
|
|
445
479
|
|
|
446
480
|
# Write metadata to the output file
|
|
447
481
|
with zipfile.ZipFile(model_file_out, 'a', compression=zipfile.ZIP_DEFLATED) as zipf:
|
|
448
|
-
|
|
449
|
-
# Torch doesn't like anything in the root folder of the zipfile, so we put
|
|
482
|
+
|
|
483
|
+
# Torch doesn't like anything in the root folder of the zipfile, so we put
|
|
450
484
|
# it in the one and only folder.
|
|
451
485
|
names = zipf.namelist()
|
|
452
486
|
root_folders = set()
|
|
@@ -456,9 +490,9 @@ def add_metadata_to_megadetector_model_file(model_file_in,
|
|
|
456
490
|
assert len(root_folders) == 1,\
|
|
457
491
|
'This archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?'
|
|
458
492
|
root_folder = next(iter(root_folders))
|
|
459
|
-
|
|
460
|
-
zipf.write(metadata_tmp_file_abs,
|
|
461
|
-
root_folder + '/' + destination_path,
|
|
493
|
+
|
|
494
|
+
zipf.write(metadata_tmp_file_abs,
|
|
495
|
+
root_folder + '/' + destination_path,
|
|
462
496
|
compresslevel=9,
|
|
463
497
|
compress_type=zipfile.ZIP_DEFLATED)
|
|
464
498
|
|
|
@@ -466,7 +500,7 @@ def add_metadata_to_megadetector_model_file(model_file_in,
|
|
|
466
500
|
os.remove(metadata_tmp_file_abs)
|
|
467
501
|
except Exception as e:
|
|
468
502
|
print('Warning: error deleting file {}: {}'.format(metadata_tmp_file_abs,str(e)))
|
|
469
|
-
|
|
503
|
+
|
|
470
504
|
# ...def add_metadata_to_megadetector_model_file(...)
|
|
471
505
|
|
|
472
506
|
|
|
@@ -475,22 +509,22 @@ def read_metadata_from_megadetector_model_file(model_file,
|
|
|
475
509
|
verbose=False):
|
|
476
510
|
"""
|
|
477
511
|
Reads custom MegaDetector metadata from a modified MegaDetector model file.
|
|
478
|
-
|
|
512
|
+
|
|
479
513
|
Args:
|
|
480
514
|
model_file (str): The model filename to read, typically .pt (.zip is also sensible)
|
|
481
|
-
relative_path (str, optional): The relative path within the main folder of the model
|
|
482
|
-
archive from which we should read the metadata. This is not relative to the root
|
|
515
|
+
relative_path (str, optional): The relative path within the main folder of the model
|
|
516
|
+
archive from which we should read the metadata. This is not relative to the root
|
|
483
517
|
of the archive, it's relative to the one and only folder at the root of the archive
|
|
484
|
-
(this is a PyTorch convention).
|
|
485
|
-
|
|
518
|
+
(this is a PyTorch convention).
|
|
519
|
+
|
|
486
520
|
Returns:
|
|
487
|
-
object:
|
|
488
|
-
|
|
521
|
+
object: whatever we read from the metadata file, always a dict in practice. Returns
|
|
522
|
+
None if we failed to read the specified metadata file.
|
|
489
523
|
"""
|
|
490
|
-
|
|
524
|
+
|
|
491
525
|
with zipfile.ZipFile(model_file,'r') as zipf:
|
|
492
|
-
|
|
493
|
-
# Torch doesn't like anything in the root folder of the zipfile, so we put
|
|
526
|
+
|
|
527
|
+
# Torch doesn't like anything in the root folder of the zipfile, so we put
|
|
494
528
|
# it in the one and only folder.
|
|
495
529
|
names = zipf.namelist()
|
|
496
530
|
root_folders = set()
|
|
@@ -498,17 +532,19 @@ def read_metadata_from_megadetector_model_file(model_file,
|
|
|
498
532
|
root_folder = name.split('/')[0]
|
|
499
533
|
root_folders.add(root_folder)
|
|
500
534
|
if len(root_folders) != 1:
|
|
501
|
-
print('Warning: this archive does not have exactly one folder at the top level;
|
|
535
|
+
print('Warning: this archive does not have exactly one folder at the top level; ' + \
|
|
536
|
+
'are you sure it\'s a Torch model file?')
|
|
502
537
|
return None
|
|
503
538
|
root_folder = next(iter(root_folders))
|
|
504
|
-
|
|
539
|
+
|
|
505
540
|
metadata_file = root_folder + '/' + relative_path
|
|
506
541
|
if metadata_file not in names:
|
|
507
542
|
# This is the case for MDv5a and MDv5b
|
|
508
543
|
if verbose:
|
|
509
|
-
print('Warning: could not find metadata file {} in zip archive'.format(
|
|
544
|
+
print('Warning: could not find metadata file {} in zip archive {}'.format(
|
|
545
|
+
metadata_file,os.path.basename(model_file)))
|
|
510
546
|
return None
|
|
511
|
-
|
|
547
|
+
|
|
512
548
|
try:
|
|
513
549
|
path = zipfile.Path(zipf,metadata_file)
|
|
514
550
|
contents = path.read_text()
|
|
@@ -516,9 +552,9 @@ def read_metadata_from_megadetector_model_file(model_file,
|
|
|
516
552
|
except Exception as e:
|
|
517
553
|
print('Warning: error reading metadata from path {}: {}'.format(metadata_file,str(e)))
|
|
518
554
|
return None
|
|
519
|
-
|
|
555
|
+
|
|
520
556
|
return d
|
|
521
|
-
|
|
557
|
+
|
|
522
558
|
# ...def read_metadata_from_megadetector_model_file(...)
|
|
523
559
|
|
|
524
560
|
|
|
@@ -526,27 +562,33 @@ def read_metadata_from_megadetector_model_file(model_file,
|
|
|
526
562
|
|
|
527
563
|
default_compatibility_mode = 'classic'
|
|
528
564
|
|
|
529
|
-
# This is a useful hack when I want to verify that my test driver (md_tests.py) is
|
|
565
|
+
# This is a useful hack when I want to verify that my test driver (md_tests.py) is
|
|
530
566
|
# correctly forcing a specific compatibility mode (I use "classic-test" in that case)
|
|
531
567
|
require_non_default_compatibility_mode = False
|
|
532
568
|
|
|
533
569
|
class PTDetector:
|
|
534
|
-
|
|
570
|
+
"""
|
|
571
|
+
Class that runs a PyTorch-based MegaDetector model.
|
|
572
|
+
"""
|
|
573
|
+
|
|
535
574
|
def __init__(self, model_path, detector_options=None, verbose=False):
|
|
575
|
+
|
|
576
|
+
if verbose:
|
|
577
|
+
print('Initializing PTDetector (verbose)')
|
|
536
578
|
|
|
537
579
|
# Set up the import environment for this model, unloading previous
|
|
538
580
|
# YOLO library versions if necessary.
|
|
539
581
|
_initialize_yolo_imports_for_model(model_path,
|
|
540
582
|
detector_options=detector_options,
|
|
541
583
|
verbose=verbose)
|
|
542
|
-
|
|
584
|
+
|
|
543
585
|
# Parse options specific to this detector family
|
|
544
586
|
force_cpu = False
|
|
545
|
-
use_model_native_classes = False
|
|
587
|
+
use_model_native_classes = False
|
|
546
588
|
compatibility_mode = default_compatibility_mode
|
|
547
|
-
|
|
589
|
+
|
|
548
590
|
if detector_options is not None:
|
|
549
|
-
|
|
591
|
+
|
|
550
592
|
if 'force_cpu' in detector_options:
|
|
551
593
|
force_cpu = parse_bool_string(detector_options['force_cpu'])
|
|
552
594
|
if 'use_model_native_classes' in detector_options:
|
|
@@ -554,27 +596,27 @@ class PTDetector:
|
|
|
554
596
|
if 'compatibility_mode' in detector_options:
|
|
555
597
|
if detector_options['compatibility_mode'] is None:
|
|
556
598
|
compatibility_mode = default_compatibility_mode
|
|
557
|
-
else:
|
|
558
|
-
compatibility_mode = detector_options['compatibility_mode']
|
|
559
|
-
|
|
599
|
+
else:
|
|
600
|
+
compatibility_mode = detector_options['compatibility_mode']
|
|
601
|
+
|
|
560
602
|
if require_non_default_compatibility_mode:
|
|
561
|
-
|
|
603
|
+
|
|
562
604
|
print('### DEBUG: requiring non-default compatibility mode ###')
|
|
563
605
|
assert compatibility_mode != 'classic'
|
|
564
606
|
assert compatibility_mode != 'default'
|
|
565
|
-
|
|
607
|
+
|
|
566
608
|
preprocess_only = False
|
|
567
609
|
if (detector_options is not None) and \
|
|
568
610
|
('preprocess_only' in detector_options) and \
|
|
569
611
|
(detector_options['preprocess_only']):
|
|
570
612
|
preprocess_only = True
|
|
571
|
-
|
|
613
|
+
|
|
572
614
|
if verbose or (not preprocess_only):
|
|
573
615
|
print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))
|
|
574
|
-
|
|
616
|
+
|
|
575
617
|
model_metadata = read_metadata_from_megadetector_model_file(model_path)
|
|
576
|
-
|
|
577
|
-
#: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
|
|
618
|
+
|
|
619
|
+
#: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
|
|
578
620
|
#: aspect ratio".
|
|
579
621
|
if model_metadata is not None and 'image_size' in model_metadata:
|
|
580
622
|
self.default_image_size = model_metadata['image_size']
|
|
@@ -582,38 +624,38 @@ class PTDetector:
|
|
|
582
624
|
print('Loaded image size {} from model metadata'.format(self.default_image_size))
|
|
583
625
|
else:
|
|
584
626
|
self.default_image_size = 1280
|
|
585
|
-
|
|
627
|
+
|
|
586
628
|
#: Either a string ('cpu','cuda:0') or a torch.device()
|
|
587
629
|
self.device = 'cpu'
|
|
588
|
-
|
|
630
|
+
|
|
589
631
|
#: Have we already printed a warning about using a non-standard image size?
|
|
590
632
|
#:
|
|
591
633
|
#: :meta private:
|
|
592
634
|
self.printed_image_size_warning = False
|
|
593
|
-
|
|
635
|
+
|
|
594
636
|
#: If this is False, we assume the underlying model is producing class indices in the
|
|
595
637
|
#: set (0,1,2) (and we assert() on this), and we add 1 to get to the backwards-compatible
|
|
596
|
-
#: MD classes (1,2,3) before generating output. If this is True, we use whatever
|
|
638
|
+
#: MD classes (1,2,3) before generating output. If this is True, we use whatever
|
|
597
639
|
#: indices the model provides
|
|
598
|
-
self.use_model_native_classes = use_model_native_classes
|
|
599
|
-
|
|
640
|
+
self.use_model_native_classes = use_model_native_classes
|
|
641
|
+
|
|
600
642
|
#: This allows us to maintain backwards compatibility across a set of changes to the
|
|
601
|
-
#: way this class does inference. Currently should start with either "default" or
|
|
643
|
+
#: way this class does inference. Currently should start with either "default" or
|
|
602
644
|
#: "classic".
|
|
603
645
|
self.compatibility_mode = compatibility_mode
|
|
604
|
-
|
|
646
|
+
|
|
605
647
|
#: Stride size passed to YOLOv5's letterbox() function
|
|
606
648
|
self.letterbox_stride = 32
|
|
607
|
-
|
|
649
|
+
|
|
608
650
|
if 'classic' in self.compatibility_mode:
|
|
609
651
|
self.letterbox_stride = 64
|
|
610
|
-
|
|
652
|
+
|
|
611
653
|
#: Use half-precision inference... fixed by the model, generally don't mess with this
|
|
612
654
|
self.half_precision = False
|
|
613
|
-
|
|
655
|
+
|
|
614
656
|
if preprocess_only:
|
|
615
657
|
return
|
|
616
|
-
|
|
658
|
+
|
|
617
659
|
if not force_cpu:
|
|
618
660
|
if torch.cuda.is_available():
|
|
619
661
|
self.device = torch.device('cuda:0')
|
|
@@ -623,10 +665,10 @@ class PTDetector:
|
|
|
623
665
|
except AttributeError:
|
|
624
666
|
pass
|
|
625
667
|
try:
|
|
626
|
-
self.model = PTDetector._load_model(model_path,
|
|
627
|
-
device=self.device,
|
|
668
|
+
self.model = PTDetector._load_model(model_path,
|
|
669
|
+
device=self.device,
|
|
628
670
|
compatibility_mode=self.compatibility_mode)
|
|
629
|
-
|
|
671
|
+
|
|
630
672
|
except Exception as e:
|
|
631
673
|
# In a very esoteric scenario where an old version of YOLOv5 is used to run
|
|
632
674
|
# newer models, we run into an issue because the "Model" class became
|
|
@@ -636,21 +678,21 @@ class PTDetector:
|
|
|
636
678
|
print('Forward-compatibility issue detected, patching')
|
|
637
679
|
from models import yolo
|
|
638
680
|
yolo.DetectionModel = yolo.Model
|
|
639
|
-
self.model = PTDetector._load_model(model_path,
|
|
681
|
+
self.model = PTDetector._load_model(model_path,
|
|
640
682
|
device=self.device,
|
|
641
683
|
compatibility_mode=self.compatibility_mode,
|
|
642
|
-
verbose=verbose)
|
|
684
|
+
verbose=verbose)
|
|
643
685
|
else:
|
|
644
686
|
raise
|
|
645
687
|
if (self.device != 'cpu'):
|
|
646
688
|
if verbose:
|
|
647
689
|
print('Sending model to GPU')
|
|
648
690
|
self.model.to(self.device)
|
|
649
|
-
|
|
691
|
+
|
|
650
692
|
|
|
651
693
|
@staticmethod
|
|
652
694
|
def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):
|
|
653
|
-
|
|
695
|
+
|
|
654
696
|
if verbose:
|
|
655
697
|
print(f'Using PyTorch version {torch.__version__}')
|
|
656
698
|
|
|
@@ -661,12 +703,12 @@ class PTDetector:
|
|
|
661
703
|
# very slight changes to the output, which always make me nervous, so I'm not
|
|
662
704
|
# doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
|
|
663
705
|
if 'classic' in compatibility_mode:
|
|
664
|
-
use_map_location = (device != 'mps')
|
|
706
|
+
use_map_location = (device != 'mps')
|
|
665
707
|
else:
|
|
666
708
|
use_map_location = False
|
|
667
|
-
|
|
709
|
+
|
|
668
710
|
if use_map_location:
|
|
669
|
-
try:
|
|
711
|
+
try:
|
|
670
712
|
checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False)
|
|
671
713
|
# For a transitional period, we want to support torch 1.1x, where the weights_only
|
|
672
714
|
# parameter doesn't exist
|
|
@@ -682,31 +724,31 @@ class PTDetector:
|
|
|
682
724
|
# parameter doesn't exist
|
|
683
725
|
except Exception as e:
|
|
684
726
|
if "'weights_only' is an invalid keyword" in str(e):
|
|
685
|
-
checkpoint = torch.load(model_pt_path)
|
|
727
|
+
checkpoint = torch.load(model_pt_path)
|
|
686
728
|
else:
|
|
687
729
|
raise
|
|
688
|
-
|
|
689
|
-
# Compatibility fix that allows us to load older YOLOv5 models with
|
|
730
|
+
|
|
731
|
+
# Compatibility fix that allows us to load older YOLOv5 models with
|
|
690
732
|
# newer versions of YOLOv5/PT
|
|
691
733
|
for m in checkpoint['model'].modules():
|
|
692
734
|
t = type(m)
|
|
693
735
|
if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
|
694
736
|
m.recompute_scale_factor = None
|
|
695
|
-
|
|
696
|
-
if use_map_location:
|
|
697
|
-
model = checkpoint['model'].float().fuse().eval()
|
|
737
|
+
|
|
738
|
+
if use_map_location:
|
|
739
|
+
model = checkpoint['model'].float().fuse().eval()
|
|
698
740
|
else:
|
|
699
|
-
model = checkpoint['model'].float().fuse().eval().to(device)
|
|
700
|
-
|
|
741
|
+
model = checkpoint['model'].float().fuse().eval().to(device)
|
|
742
|
+
|
|
701
743
|
return model
|
|
702
744
|
|
|
703
745
|
# ...def _load_model(...)
|
|
704
|
-
|
|
705
746
|
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
747
|
+
|
|
748
|
+
def generate_detections_one_image(self,
|
|
749
|
+
img_original,
|
|
750
|
+
image_id='unknown',
|
|
751
|
+
detection_threshold=0.00001,
|
|
710
752
|
image_size=None,
|
|
711
753
|
skip_image_resizing=False,
|
|
712
754
|
augment=False,
|
|
@@ -716,16 +758,16 @@ class PTDetector:
|
|
|
716
758
|
Applies the detector to an image.
|
|
717
759
|
|
|
718
760
|
Args:
|
|
719
|
-
img_original (Image): the PIL Image object (or numpy array) on which we should run the
|
|
761
|
+
img_original (Image): the PIL Image object (or numpy array) on which we should run the
|
|
720
762
|
detector, with EXIF rotation already handled
|
|
721
|
-
image_id (str, optional): a path to identify the image; will be in the "file" field
|
|
763
|
+
image_id (str, optional): a path to identify the image; will be in the "file" field
|
|
722
764
|
of the output object
|
|
723
|
-
detection_threshold (float, optional): only detections above this confidence threshold
|
|
765
|
+
detection_threshold (float, optional): only detections above this confidence threshold
|
|
724
766
|
will be included in the return value
|
|
725
|
-
image_size (tuple, optional): image size to use for inference, only mess with this if
|
|
767
|
+
image_size (tuple, optional): image size to use for inference, only mess with this if
|
|
726
768
|
(a) you're using a model other than MegaDetector or (b) you know what you're getting into
|
|
727
|
-
skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
|
|
728
|
-
external resizing), only mess with this if (a) you're using a model other than MegaDetector
|
|
769
|
+
skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
|
|
770
|
+
external resizing), only mess with this if (a) you're using a model other than MegaDetector
|
|
729
771
|
or (b) you know what you're getting into
|
|
730
772
|
augment (bool, optional): enable (implementation-specific) image augmentation
|
|
731
773
|
preprocess_only (bool, optional): only run preprocessing, and return the preprocessed image
|
|
@@ -747,17 +789,17 @@ class PTDetector:
|
|
|
747
789
|
assert 'classic' in self.compatibility_mode, \
|
|
748
790
|
'Standalone preprocessing only supported in "classic" mode'
|
|
749
791
|
assert not skip_image_resizing, \
|
|
750
|
-
'skip_image_resizing and preprocess_only are exclusive'
|
|
751
|
-
|
|
792
|
+
'skip_image_resizing and preprocess_only are exclusive'
|
|
793
|
+
|
|
752
794
|
if detection_threshold is None:
|
|
753
|
-
|
|
795
|
+
|
|
754
796
|
detection_threshold = 0
|
|
755
|
-
|
|
797
|
+
|
|
756
798
|
try:
|
|
757
|
-
|
|
799
|
+
|
|
758
800
|
# If the caller wants us to skip all the resizing operations...
|
|
759
801
|
if skip_image_resizing:
|
|
760
|
-
|
|
802
|
+
|
|
761
803
|
if isinstance(img_original,dict):
|
|
762
804
|
image_info = img_original
|
|
763
805
|
img = image_info['img_processed']
|
|
@@ -768,107 +810,107 @@ class PTDetector:
|
|
|
768
810
|
img_original_pil = image_info['img_original_pil']
|
|
769
811
|
else:
|
|
770
812
|
img = img_original
|
|
771
|
-
|
|
813
|
+
|
|
772
814
|
else:
|
|
773
|
-
|
|
815
|
+
|
|
774
816
|
img_original_pil = None
|
|
775
817
|
# If we were given a PIL image
|
|
776
|
-
|
|
818
|
+
|
|
777
819
|
if not isinstance(img_original,np.ndarray):
|
|
778
|
-
img_original_pil = img_original
|
|
820
|
+
img_original_pil = img_original
|
|
779
821
|
img_original = np.asarray(img_original)
|
|
780
|
-
|
|
822
|
+
|
|
781
823
|
# PIL images are RGB already
|
|
782
824
|
# img_original = img_original[:, :, ::-1]
|
|
783
|
-
|
|
825
|
+
|
|
784
826
|
# Save the original shape for scaling boxes later
|
|
785
827
|
scaling_shape = img_original.shape
|
|
786
|
-
|
|
828
|
+
|
|
787
829
|
# If the caller is requesting a specific target size...
|
|
788
830
|
if image_size is not None:
|
|
789
|
-
|
|
831
|
+
|
|
790
832
|
assert isinstance(image_size,int)
|
|
791
|
-
|
|
833
|
+
|
|
792
834
|
if not self.printed_image_size_warning:
|
|
793
835
|
print('Using user-supplied image size {}'.format(image_size))
|
|
794
|
-
self.printed_image_size_warning = True
|
|
795
|
-
|
|
836
|
+
self.printed_image_size_warning = True
|
|
837
|
+
|
|
796
838
|
# Otherwise resize to self.default_image_size
|
|
797
839
|
else:
|
|
798
|
-
|
|
840
|
+
|
|
799
841
|
image_size = self.default_image_size
|
|
800
842
|
self.printed_image_size_warning = False
|
|
801
|
-
|
|
843
|
+
|
|
802
844
|
# ...if the caller has specified an image size
|
|
803
|
-
|
|
845
|
+
|
|
804
846
|
# In "classic mode", we only do the letterboxing resize, we don't do an
|
|
805
847
|
# additional initial resizing operation
|
|
806
848
|
if 'classic' in self.compatibility_mode:
|
|
807
|
-
|
|
849
|
+
|
|
808
850
|
resize_ratio = 1.0
|
|
809
|
-
|
|
810
|
-
# Resize the image so the long side matches the target image size. This is not
|
|
851
|
+
|
|
852
|
+
# Resize the image so the long side matches the target image size. This is not
|
|
811
853
|
# letterboxing (i.e., padding) yet, just resizing.
|
|
812
854
|
else:
|
|
813
|
-
|
|
855
|
+
|
|
814
856
|
use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)
|
|
815
|
-
|
|
857
|
+
|
|
816
858
|
h,w = img_original.shape[:2]
|
|
817
859
|
resize_ratio = image_size / max(h,w)
|
|
818
|
-
|
|
860
|
+
|
|
819
861
|
# Only resize if we have to
|
|
820
862
|
if resize_ratio != 1:
|
|
821
|
-
|
|
822
|
-
# Match what yolov5 does: use linear interpolation for upsizing;
|
|
863
|
+
|
|
864
|
+
# Match what yolov5 does: use linear interpolation for upsizing;
|
|
823
865
|
# area interpolation for downsizing
|
|
824
866
|
if resize_ratio > 1:
|
|
825
867
|
interpolation_method = cv2.INTER_LINEAR
|
|
826
868
|
else:
|
|
827
|
-
interpolation_method = cv2.INTER_AREA
|
|
828
|
-
|
|
869
|
+
interpolation_method = cv2.INTER_AREA
|
|
870
|
+
|
|
829
871
|
if use_ceil_for_resize:
|
|
830
872
|
target_w = math.ceil(w * resize_ratio)
|
|
831
873
|
target_h = math.ceil(h * resize_ratio)
|
|
832
874
|
else:
|
|
833
875
|
target_w = int(w * resize_ratio)
|
|
834
876
|
target_h = int(h * resize_ratio)
|
|
835
|
-
|
|
877
|
+
|
|
836
878
|
img_original = cv2.resize(
|
|
837
879
|
img_original, (target_w, target_h),
|
|
838
880
|
interpolation=interpolation_method)
|
|
839
881
|
|
|
840
882
|
if 'classic' in self.compatibility_mode:
|
|
841
|
-
|
|
883
|
+
|
|
842
884
|
letterbox_auto = True
|
|
843
885
|
letterbox_scaleup = True
|
|
844
886
|
target_shape = image_size
|
|
845
|
-
|
|
887
|
+
|
|
846
888
|
else:
|
|
847
|
-
|
|
889
|
+
|
|
848
890
|
letterbox_auto = False
|
|
849
891
|
letterbox_scaleup = False
|
|
850
|
-
|
|
892
|
+
|
|
851
893
|
# The padding to apply as a fraction of the stride size
|
|
852
894
|
pad = 0.5
|
|
853
|
-
|
|
895
|
+
|
|
854
896
|
model_stride = int(self.model.stride.max())
|
|
855
|
-
|
|
897
|
+
|
|
856
898
|
max_dimension = max(img_original.shape)
|
|
857
899
|
normalized_shape = [img_original.shape[0] / max_dimension,
|
|
858
900
|
img_original.shape[1] / max_dimension]
|
|
859
901
|
target_shape = np.ceil(np.array(normalized_shape) * image_size / model_stride + \
|
|
860
902
|
pad).astype(int) * model_stride
|
|
861
|
-
|
|
903
|
+
|
|
862
904
|
# Now we letterbox, which is just padding, since we've already resized.
|
|
863
|
-
img,letterbox_ratio,letterbox_pad = letterbox(img_original,
|
|
905
|
+
img,letterbox_ratio,letterbox_pad = letterbox(img_original,
|
|
864
906
|
new_shape=target_shape,
|
|
865
|
-
stride=self.letterbox_stride,
|
|
907
|
+
stride=self.letterbox_stride,
|
|
866
908
|
auto=letterbox_auto,
|
|
867
909
|
scaleFill=False,
|
|
868
910
|
scaleup=letterbox_scaleup)
|
|
869
|
-
|
|
911
|
+
|
|
870
912
|
if preprocess_only:
|
|
871
|
-
|
|
913
|
+
|
|
872
914
|
assert 'file' in result
|
|
873
915
|
result['img_processed'] = img
|
|
874
916
|
result['img_original'] = img_original
|
|
@@ -878,14 +920,14 @@ class PTDetector:
|
|
|
878
920
|
result['letterbox_ratio'] = letterbox_ratio
|
|
879
921
|
result['letterbox_pad'] = letterbox_pad
|
|
880
922
|
return result
|
|
881
|
-
|
|
923
|
+
|
|
882
924
|
# ...are we doing resizing here, or were images already resized?
|
|
883
|
-
|
|
925
|
+
|
|
884
926
|
# Convert HWC to CHW (which is what the model expects). The PIL Image is RGB already,
|
|
885
927
|
# so we don't need to mess with the color channels.
|
|
886
928
|
#
|
|
887
929
|
# TODO, this could be moved into the preprocessing loop
|
|
888
|
-
|
|
930
|
+
|
|
889
931
|
img = img.transpose((2, 0, 1)) # [::-1]
|
|
890
932
|
img = np.ascontiguousarray(img)
|
|
891
933
|
img = torch.from_numpy(img)
|
|
@@ -893,8 +935,8 @@ class PTDetector:
|
|
|
893
935
|
img = img.half() if self.half_precision else img.float()
|
|
894
936
|
img /= 255
|
|
895
937
|
|
|
896
|
-
# In practice this is always true
|
|
897
|
-
if len(img.shape) == 3:
|
|
938
|
+
# In practice this is always true
|
|
939
|
+
if len(img.shape) == 3:
|
|
898
940
|
img = torch.unsqueeze(img, 0)
|
|
899
941
|
|
|
900
942
|
# Run the model
|
|
@@ -908,19 +950,19 @@ class PTDetector:
|
|
|
908
950
|
else:
|
|
909
951
|
nms_conf_thres = detection_threshold # 0.01
|
|
910
952
|
nms_iou_thres = 0.6
|
|
911
|
-
nms_agnostic = False
|
|
953
|
+
nms_agnostic = False
|
|
912
954
|
nms_multi_label = True
|
|
913
|
-
|
|
955
|
+
|
|
914
956
|
# As of PyTorch 1.13.0.dev20220824, nms is not implemented for MPS.
|
|
915
957
|
#
|
|
916
|
-
# Send predictions back to the CPU for NMS.
|
|
958
|
+
# Send predictions back to the CPU for NMS.
|
|
917
959
|
if self.device == 'mps':
|
|
918
960
|
pred_nms = pred.cpu()
|
|
919
961
|
else:
|
|
920
962
|
pred_nms = pred
|
|
921
|
-
|
|
963
|
+
|
|
922
964
|
# NMS
|
|
923
|
-
pred = non_max_suppression(prediction=pred_nms,
|
|
965
|
+
pred = non_max_suppression(prediction=pred_nms,
|
|
924
966
|
conf_thres=nms_conf_thres,
|
|
925
967
|
iou_thres=nms_iou_thres,
|
|
926
968
|
agnostic=nms_agnostic,
|
|
@@ -930,26 +972,26 @@ class PTDetector:
|
|
|
930
972
|
gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
|
|
931
973
|
|
|
932
974
|
if 'classic' in self.compatibility_mode:
|
|
933
|
-
|
|
975
|
+
|
|
934
976
|
ratio = None
|
|
935
977
|
ratio_pad = None
|
|
936
|
-
|
|
978
|
+
|
|
937
979
|
else:
|
|
938
|
-
|
|
980
|
+
|
|
939
981
|
# letterbox_pad is a 2-tuple specifying the padding that was added on each axis.
|
|
940
982
|
#
|
|
941
983
|
# ratio is a 2-tuple specifying the scaling that was applied to each dimension.
|
|
942
984
|
#
|
|
943
985
|
# The scale_boxes function expects a 2-tuple with these things combined.
|
|
944
986
|
ratio = (img_original.shape[0]/scaling_shape[0], img_original.shape[1]/scaling_shape[1])
|
|
945
|
-
ratio_pad = (ratio, letterbox_pad)
|
|
946
|
-
|
|
987
|
+
ratio_pad = (ratio, letterbox_pad)
|
|
988
|
+
|
|
947
989
|
# This is a loop over detection batches, which will always be length 1 in our case,
|
|
948
990
|
# since we're not doing batch inference.
|
|
949
991
|
#
|
|
950
992
|
# det = pred[0]
|
|
951
993
|
#
|
|
952
|
-
# det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
|
|
994
|
+
# det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
|
|
953
995
|
# in descending order by confidence.
|
|
954
996
|
#
|
|
955
997
|
# Columns are:
|
|
@@ -959,15 +1001,15 @@ class PTDetector:
|
|
|
959
1001
|
# At this point, these are *non*-normalized values, referring to the size at which we
|
|
960
1002
|
# ran inference (img.shape).
|
|
961
1003
|
for det in pred:
|
|
962
|
-
|
|
1004
|
+
|
|
963
1005
|
if len(det) == 0:
|
|
964
1006
|
continue
|
|
965
|
-
|
|
1007
|
+
|
|
966
1008
|
# Rescale boxes from img_size to im0 size, and undo the effect of padded letterboxing
|
|
967
1009
|
if 'classic' in self.compatibility_mode:
|
|
968
|
-
|
|
1010
|
+
|
|
969
1011
|
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round()
|
|
970
|
-
|
|
1012
|
+
|
|
971
1013
|
else:
|
|
972
1014
|
# After this scaling, each element of det is a box in x0,y0,x1,y1 format, referring to the
|
|
973
1015
|
# original pixel dimension of the image, followed by the class and confidence
|
|
@@ -975,14 +1017,14 @@ class PTDetector:
|
|
|
975
1017
|
|
|
976
1018
|
# Loop over detections
|
|
977
1019
|
for *xyxy, conf, cls in reversed(det):
|
|
978
|
-
|
|
1020
|
+
|
|
979
1021
|
if conf < detection_threshold:
|
|
980
1022
|
continue
|
|
981
|
-
|
|
1023
|
+
|
|
982
1024
|
# Convert this box to normalized cx, cy, w, h (i.e., YOLO format)
|
|
983
1025
|
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
|
|
984
1026
|
|
|
985
|
-
# Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
|
|
1027
|
+
# Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
|
|
986
1028
|
# left/top/w/h (i.e., MD format)
|
|
987
1029
|
api_box = ct_utils.convert_yolo_to_xywh(xywh)
|
|
988
1030
|
|
|
@@ -991,11 +1033,11 @@ class PTDetector:
|
|
|
991
1033
|
conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)
|
|
992
1034
|
else:
|
|
993
1035
|
api_box = ct_utils.round_float_array(api_box, precision=COORD_DIGITS)
|
|
994
|
-
conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
|
|
995
|
-
|
|
1036
|
+
conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
|
|
1037
|
+
|
|
996
1038
|
if not self.use_model_native_classes:
|
|
997
|
-
# The MegaDetector output format's categories start at 1, but all YOLO-based
|
|
998
|
-
# MD models have category numbers starting at 0.
|
|
1039
|
+
# The MegaDetector output format's categories start at 1, but all YOLO-based
|
|
1040
|
+
# MD models have category numbers starting at 0.
|
|
999
1041
|
cls = int(cls.tolist()) + 1
|
|
1000
1042
|
if cls not in (1, 2, 3):
|
|
1001
1043
|
raise KeyError(f'{cls} is not a valid class.')
|
|
@@ -1008,15 +1050,15 @@ class PTDetector:
|
|
|
1008
1050
|
'bbox': api_box
|
|
1009
1051
|
})
|
|
1010
1052
|
max_conf = max(max_conf, conf)
|
|
1011
|
-
|
|
1053
|
+
|
|
1012
1054
|
# ...for each detection in this batch
|
|
1013
|
-
|
|
1055
|
+
|
|
1014
1056
|
# ...for each detection batch (always one iteration)
|
|
1015
1057
|
|
|
1016
1058
|
# ...try
|
|
1017
|
-
|
|
1059
|
+
|
|
1018
1060
|
except Exception as e:
|
|
1019
|
-
|
|
1061
|
+
|
|
1020
1062
|
result['failure'] = FAILURE_INFER
|
|
1021
1063
|
print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e)))
|
|
1022
1064
|
# traceback.print_exc(e)
|
|
@@ -1039,12 +1081,12 @@ class PTDetector:
|
|
|
1039
1081
|
if __name__ == '__main__':
|
|
1040
1082
|
|
|
1041
1083
|
pass
|
|
1042
|
-
|
|
1084
|
+
|
|
1043
1085
|
#%%
|
|
1044
|
-
|
|
1086
|
+
|
|
1045
1087
|
import os #noqa
|
|
1046
1088
|
from megadetector.visualization import visualization_utils as vis_utils
|
|
1047
|
-
|
|
1089
|
+
|
|
1048
1090
|
model_file = os.environ['MDV5A']
|
|
1049
1091
|
im_file = os.path.expanduser('~/git/MegaDetector/images/nacti.jpg')
|
|
1050
1092
|
|