megadetector 5.0.28__py3-none-any.whl → 10.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +2 -2
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +1 -1
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +1 -1
- megadetector/classification/aggregate_classifier_probs.py +3 -3
- megadetector/classification/analyze_failed_images.py +5 -5
- megadetector/classification/cache_batchapi_outputs.py +5 -5
- megadetector/classification/create_classification_dataset.py +11 -12
- megadetector/classification/crop_detections.py +10 -10
- megadetector/classification/csv_to_json.py +8 -8
- megadetector/classification/detect_and_crop.py +13 -15
- megadetector/classification/efficientnet/model.py +8 -8
- megadetector/classification/efficientnet/utils.py +6 -5
- megadetector/classification/evaluate_model.py +7 -7
- megadetector/classification/identify_mislabeled_candidates.py +6 -6
- megadetector/classification/json_to_azcopy_list.py +1 -1
- megadetector/classification/json_validator.py +29 -32
- megadetector/classification/map_classification_categories.py +9 -9
- megadetector/classification/merge_classification_detection_output.py +12 -9
- megadetector/classification/prepare_classification_script.py +19 -19
- megadetector/classification/prepare_classification_script_mc.py +26 -26
- megadetector/classification/run_classifier.py +4 -4
- megadetector/classification/save_mislabeled.py +6 -6
- megadetector/classification/train_classifier.py +1 -1
- megadetector/classification/train_classifier_tf.py +9 -9
- megadetector/classification/train_utils.py +10 -10
- megadetector/data_management/annotations/annotation_constants.py +1 -2
- megadetector/data_management/camtrap_dp_to_coco.py +79 -46
- megadetector/data_management/cct_json_utils.py +103 -103
- megadetector/data_management/cct_to_md.py +49 -49
- megadetector/data_management/cct_to_wi.py +33 -33
- megadetector/data_management/coco_to_labelme.py +75 -75
- megadetector/data_management/coco_to_yolo.py +210 -193
- megadetector/data_management/databases/add_width_and_height_to_db.py +86 -12
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +40 -40
- megadetector/data_management/databases/integrity_check_json_db.py +228 -200
- megadetector/data_management/databases/subset_json_db.py +33 -33
- megadetector/data_management/generate_crops_from_cct.py +88 -39
- megadetector/data_management/get_image_sizes.py +54 -49
- megadetector/data_management/labelme_to_coco.py +133 -125
- megadetector/data_management/labelme_to_yolo.py +159 -73
- megadetector/data_management/lila/create_lila_blank_set.py +81 -83
- megadetector/data_management/lila/create_lila_test_set.py +32 -31
- megadetector/data_management/lila/create_links_to_md_results_files.py +18 -18
- megadetector/data_management/lila/download_lila_subset.py +21 -24
- megadetector/data_management/lila/generate_lila_per_image_labels.py +365 -107
- megadetector/data_management/lila/get_lila_annotation_counts.py +35 -33
- megadetector/data_management/lila/get_lila_image_counts.py +22 -22
- megadetector/data_management/lila/lila_common.py +73 -70
- megadetector/data_management/lila/test_lila_metadata_urls.py +28 -19
- megadetector/data_management/mewc_to_md.py +344 -340
- megadetector/data_management/ocr_tools.py +262 -255
- megadetector/data_management/read_exif.py +249 -227
- megadetector/data_management/remap_coco_categories.py +90 -28
- megadetector/data_management/remove_exif.py +81 -21
- megadetector/data_management/rename_images.py +187 -187
- megadetector/data_management/resize_coco_dataset.py +588 -120
- megadetector/data_management/speciesnet_to_md.py +41 -41
- megadetector/data_management/wi_download_csv_to_coco.py +55 -55
- megadetector/data_management/yolo_output_to_md_output.py +248 -122
- megadetector/data_management/yolo_to_coco.py +333 -191
- megadetector/detection/change_detection.py +832 -0
- megadetector/detection/process_video.py +340 -337
- megadetector/detection/pytorch_detector.py +358 -278
- megadetector/detection/run_detector.py +399 -186
- megadetector/detection/run_detector_batch.py +404 -377
- megadetector/detection/run_inference_with_yolov5_val.py +340 -327
- megadetector/detection/run_tiled_inference.py +257 -249
- megadetector/detection/tf_detector.py +24 -24
- megadetector/detection/video_utils.py +332 -295
- megadetector/postprocessing/add_max_conf.py +19 -11
- megadetector/postprocessing/categorize_detections_by_size.py +45 -45
- megadetector/postprocessing/classification_postprocessing.py +468 -433
- megadetector/postprocessing/combine_batch_outputs.py +23 -23
- megadetector/postprocessing/compare_batch_results.py +590 -525
- megadetector/postprocessing/convert_output_format.py +106 -102
- megadetector/postprocessing/create_crop_folder.py +347 -147
- megadetector/postprocessing/detector_calibration.py +173 -168
- megadetector/postprocessing/generate_csv_report.py +508 -499
- megadetector/postprocessing/load_api_results.py +48 -27
- megadetector/postprocessing/md_to_coco.py +133 -102
- megadetector/postprocessing/md_to_labelme.py +107 -90
- megadetector/postprocessing/md_to_wi.py +40 -40
- megadetector/postprocessing/merge_detections.py +92 -114
- megadetector/postprocessing/postprocess_batch_results.py +319 -301
- megadetector/postprocessing/remap_detection_categories.py +91 -38
- megadetector/postprocessing/render_detection_confusion_matrix.py +214 -205
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +57 -57
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +27 -28
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +704 -679
- megadetector/postprocessing/separate_detections_into_folders.py +226 -211
- megadetector/postprocessing/subset_json_detector_output.py +265 -262
- megadetector/postprocessing/top_folders_to_bottom.py +45 -45
- megadetector/postprocessing/validate_batch_results.py +70 -70
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +52 -52
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +18 -19
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +54 -33
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +67 -67
- megadetector/taxonomy_mapping/retrieve_sample_image.py +16 -16
- megadetector/taxonomy_mapping/simple_image_download.py +8 -8
- megadetector/taxonomy_mapping/species_lookup.py +156 -74
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +14 -14
- megadetector/taxonomy_mapping/taxonomy_graph.py +10 -10
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +13 -13
- megadetector/utils/ct_utils.py +1049 -211
- megadetector/utils/directory_listing.py +21 -77
- megadetector/utils/gpu_test.py +22 -22
- megadetector/utils/md_tests.py +632 -529
- megadetector/utils/path_utils.py +1520 -431
- megadetector/utils/process_utils.py +41 -41
- megadetector/utils/split_locations_into_train_val.py +62 -62
- megadetector/utils/string_utils.py +148 -27
- megadetector/utils/url_utils.py +489 -176
- megadetector/utils/wi_utils.py +2658 -2526
- megadetector/utils/write_html_image_list.py +137 -137
- megadetector/visualization/plot_utils.py +34 -30
- megadetector/visualization/render_images_with_thumbnails.py +39 -74
- megadetector/visualization/visualization_utils.py +487 -435
- megadetector/visualization/visualize_db.py +232 -198
- megadetector/visualization/visualize_detector_output.py +82 -76
- {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/METADATA +5 -2
- megadetector-10.0.0.dist-info/RECORD +139 -0
- {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/WHEEL +1 -1
- megadetector/api/batch_processing/api_core/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/batch_service/score.py +0 -439
- megadetector/api/batch_processing/api_core/server.py +0 -294
- megadetector/api/batch_processing/api_core/server_api_config.py +0 -97
- megadetector/api/batch_processing/api_core/server_app_config.py +0 -55
- megadetector/api/batch_processing/api_core/server_batch_job_manager.py +0 -220
- megadetector/api/batch_processing/api_core/server_job_status_table.py +0 -149
- megadetector/api/batch_processing/api_core/server_orchestration.py +0 -360
- megadetector/api/batch_processing/api_core/server_utils.py +0 -88
- megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
- megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +0 -46
- megadetector/api/batch_processing/api_support/__init__.py +0 -0
- megadetector/api/batch_processing/api_support/summarize_daily_activity.py +0 -152
- megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
- megadetector/api/synchronous/__init__.py +0 -0
- megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
- megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +0 -151
- megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +0 -263
- megadetector/api/synchronous/api_core/animal_detection_api/config.py +0 -35
- megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
- megadetector/api/synchronous/api_core/tests/load_test.py +0 -110
- megadetector/data_management/importers/add_nacti_sizes.py +0 -52
- megadetector/data_management/importers/add_timestamps_to_icct.py +0 -79
- megadetector/data_management/importers/animl_results_to_md_results.py +0 -158
- megadetector/data_management/importers/auckland_doc_test_to_json.py +0 -373
- megadetector/data_management/importers/auckland_doc_to_json.py +0 -201
- megadetector/data_management/importers/awc_to_json.py +0 -191
- megadetector/data_management/importers/bellevue_to_json.py +0 -272
- megadetector/data_management/importers/cacophony-thermal-importer.py +0 -793
- megadetector/data_management/importers/carrizo_shrubfree_2018.py +0 -269
- megadetector/data_management/importers/carrizo_trail_cam_2017.py +0 -289
- megadetector/data_management/importers/cct_field_adjustments.py +0 -58
- megadetector/data_management/importers/channel_islands_to_cct.py +0 -913
- megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
- megadetector/data_management/importers/eMammal/eMammal_helpers.py +0 -249
- megadetector/data_management/importers/eMammal/make_eMammal_json.py +0 -223
- megadetector/data_management/importers/ena24_to_json.py +0 -276
- megadetector/data_management/importers/filenames_to_json.py +0 -386
- megadetector/data_management/importers/helena_to_cct.py +0 -283
- megadetector/data_management/importers/idaho-camera-traps.py +0 -1407
- megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
- megadetector/data_management/importers/import_desert_lion_conservation_camera_traps.py +0 -387
- megadetector/data_management/importers/jb_csv_to_json.py +0 -150
- megadetector/data_management/importers/mcgill_to_json.py +0 -250
- megadetector/data_management/importers/missouri_to_json.py +0 -490
- megadetector/data_management/importers/nacti_fieldname_adjustments.py +0 -79
- megadetector/data_management/importers/noaa_seals_2019.py +0 -181
- megadetector/data_management/importers/osu-small-animals-to-json.py +0 -364
- megadetector/data_management/importers/pc_to_json.py +0 -365
- megadetector/data_management/importers/plot_wni_giraffes.py +0 -123
- megadetector/data_management/importers/prepare_zsl_imerit.py +0 -131
- megadetector/data_management/importers/raic_csv_to_md_results.py +0 -416
- megadetector/data_management/importers/rspb_to_json.py +0 -356
- megadetector/data_management/importers/save_the_elephants_survey_A.py +0 -320
- megadetector/data_management/importers/save_the_elephants_survey_B.py +0 -329
- megadetector/data_management/importers/snapshot_safari_importer.py +0 -758
- megadetector/data_management/importers/snapshot_serengeti_lila.py +0 -1067
- megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
- megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
- megadetector/data_management/importers/sulross_get_exif.py +0 -65
- megadetector/data_management/importers/timelapse_csv_set_to_json.py +0 -490
- megadetector/data_management/importers/ubc_to_json.py +0 -399
- megadetector/data_management/importers/umn_to_json.py +0 -507
- megadetector/data_management/importers/wellington_to_json.py +0 -263
- megadetector/data_management/importers/wi_to_json.py +0 -442
- megadetector/data_management/importers/zamba_results_to_md_results.py +0 -180
- megadetector/data_management/lila/add_locations_to_island_camera_traps.py +0 -101
- megadetector/data_management/lila/add_locations_to_nacti.py +0 -151
- megadetector/utils/azure_utils.py +0 -178
- megadetector/utils/sas_blob_utils.py +0 -509
- megadetector-5.0.28.dist-info/RECORD +0 -209
- /megadetector/{api/batch_processing/__init__.py → __init__.py} +0 -0
- {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/licenses/LICENSE +0 -0
- {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/top_level.txt +0 -0
megadetector/detection/pytorch_detector.py

@@ -17,6 +17,7 @@ import shutil
 import traceback
 import uuid
 import json
+import inspect
 
 import cv2
 import torch
@@ -54,7 +55,7 @@ def _get_model_type_for_model(model_file,
                               verbose=False):
     """
     Determine the model type (i.e., the inference library we need to use) for a .pt file.
-
+
     Args:
         model_file (str): the model file to read
         prefer_model_type_source (str, optional): how should we handle the (very unlikely)
@@ -64,28 +65,28 @@ def _get_model_type_for_model(model_file,
         default_model_type (str, optional): return value for the case where we can't find
             appropriate metadata in the file or in the global table.
         verbose (bool, optional): enable additional debug output
-
+
     Returns:
         str: the model type indicated for this model
     """
-
+
     model_info = read_metadata_from_megadetector_model_file(model_file)
-
+
     # Check whether the model file itself specified a model type
     model_type_from_model_file_metadata = None
-
-    if model_info is not None and 'model_type' in model_info:
+
+    if model_info is not None and 'model_type' in model_info:
         model_type_from_model_file_metadata = model_info['model_type']
         if verbose:
             print('Parsed model type {} from model {}'.format(
                 model_type_from_model_file_metadata,
                 model_file))
-
+
     model_type_from_model_version = None
-
+
     # Check whether this is a known model version with a specific model type
     model_version_from_file = get_detector_version_from_model_file(model_file)
-
+
     if model_version_from_file is not None and model_version_from_file in known_models:
         model_info = known_models[model_version_from_file]
         if 'model_type' in model_info:
@@ -93,15 +94,15 @@ def _get_model_type_for_model(model_file,
             if verbose:
                 print('Parsed model type {} from global metadata'.format(model_type_from_model_version))
         else:
-            model_type_from_model_version = None
-
+            model_type_from_model_version = None
+
     if model_type_from_model_file_metadata is None and \
        model_type_from_model_version is None:
         if verbose:
             print('Could not determine model type for {}, assuming {}'.format(
                 model_file,default_model_type))
         model_type = default_model_type
-
+
     elif model_type_from_model_file_metadata is not None and \
          model_type_from_model_version is not None:
         if model_type_from_model_version == model_type_from_model_file_metadata:
@@ -113,15 +114,15 @@ def _get_model_type_for_model(model_file,
             model_type = model_type_from_model_file_metadata
         else:
             model_type = model_type_from_model_version
-
+
     elif model_type_from_model_file_metadata is not None:
-
+
         model_type = model_type_from_model_file_metadata
-
+
     elif model_type_from_model_version is not None:
-
+
         model_type = model_type_from_model_version
-
+
     return model_type
 
 # ...def _get_model_type_for_model(...)
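The hunks above are mostly whitespace cleanup around `_get_model_type_for_model()`, whose job is to reconcile a model type recorded in the .pt file's own metadata with one recorded in the global `known_models` table. A minimal standalone sketch of that precedence logic, using illustrative names rather than the module's API:

```python
def resolve_model_type(type_from_file, type_from_table,
                       prefer_source='table', default_type='yolov5'):
    """Resolve a model type from two (possibly conflicting) metadata sources."""
    if type_from_file is None and type_from_table is None:
        return default_type          # neither source knows
    if type_from_table is None:
        return type_from_file        # only the file metadata knows
    if type_from_file is None:
        return type_from_table       # only the global table knows
    if type_from_file == type_from_table:
        return type_from_file        # both sources agree
    # Conflict: honor the caller's stated preference
    return type_from_table if prefer_source == 'table' else type_from_file

assert resolve_model_type(None, None) == 'yolov5'
assert resolve_model_type('ultralytics', 'yolov9', prefer_source='file') == 'ultralytics'
```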
@@ -134,7 +135,7 @@ def _initialize_yolo_imports_for_model(model_file,
                                        verbose=False):
     """
     Initialize the appropriate YOLO imports for a model file.
-
+
     Args:
         model_file (str): The model file for which we're loading support
         prefer_model_type_source (str, optional): how should we handle the (very unlikely)
@@ -142,18 +143,17 @@ def _initialize_yolo_imports_for_model(model_file,
             type table says something else. Should be "table" (trust the table) or "file"
             (trust the file).
         default_model_type (str, optional): return value for the case where we can't find
-            appropriate metadata in the file or in the global table.
+            appropriate metadata in the file or in the global table.
         detector_options (dict, optional): dictionary of detector options that mean
             different things to different models
         verbose (bool, optional): enable additional debug output
-
+
     Returns:
         str: the model type for which we initialized support
     """
-
-
+
     global yolo_model_type_imported
-
+
     if detector_options is not None and 'model_type' in detector_options:
         model_type = detector_options['model_type']
         print('Model type {} provided in detector options'.format(model_type))
@@ -161,7 +161,7 @@ def _initialize_yolo_imports_for_model(model_file,
         model_type = _get_model_type_for_model(model_file,
                                                prefer_model_type_source=prefer_model_type_source,
                                                default_model_type=default_model_type)
-
+
     if yolo_model_type_imported is not None:
         if model_type == yolo_model_type_imported:
             print('Bypassing imports for model type {}'.format(model_type))
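The `detector_options` short-circuit above means a caller can pin the model type explicitly rather than relying on file or table metadata. A hypothetical call (note that the function is module-private):

```python
model_type = _initialize_yolo_imports_for_model(
    'md_v5a.0.0.pt',
    detector_options={'model_type': 'yolov5'})
```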
@@ -169,54 +169,92 @@ def _initialize_yolo_imports_for_model(model_file,
         else:
             print('Previously set up imports for model type {}, re-importing as {}'.format(
                 yolo_model_type_imported,model_type))
-
+
     _initialize_yolo_imports(model_type,verbose=verbose)
-
+
     return model_type
 
 
-def _clean_yolo_imports(verbose=False):
+def _clean_yolo_imports(verbose=False,aggressive_cleanup=False):
     """
     Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
-    of another YOLO library version. The reason we jump through all these hoops, rather than
+    of another YOLO library version. The reason we jump through all these hoops, rather than
     just, e.g., handling different libraries in different modules, is that we need to make sure
     *pickle* sees the right version of modules during module loading, including modules we don't
-    load directly (i.e., every module loaded within a YOLO library), and the only way I know to
+    load directly (i.e., every module loaded within a YOLO library), and the only way I know to
     do that is to remove all the "wrong" versions from sys.modules and sys.path.
-
+
     Args:
         verbose (bool, optional): enable additional debug output
+        aggressive_cleanup (bool, optional): err on the side of removing modules,
+            at least by ignoring whether they are/aren't in a site-packages folder.
+            By default, only modules in a folder that includes "site-packages" will
+            be considered for unloading.
     """
-
+
     modules_to_delete = []
-
+
+    for module_name in sys.modules.keys():
+
         module = sys.modules[module_name]
+        if not hasattr(module,'__file__') or (module.__file__ is None):
+            continue
         try:
             module_file = module.__file__.replace('\\','/')
-            if
-
-
-
-
+            if not aggressive_cleanup:
+                if 'site-packages' not in module_file:
+                    continue
+            tokens = module_file.split('/')
+
+            # For local path imports, a module filename that should be unloaded might
+            # look like:
+            #
+            # c:/git/yolov9/models/common.py
+            #
+            # For pip imports, a module filename that should be unloaded might look like:
+            #
+            # c:/users/user/miniforge3/envs/megadetector/lib/site-packages/yolov9/utils/__init__.py
+            first_token_to_check = len(tokens) - 4
+            for i_token,token in enumerate(tokens):
+                if i_token < first_token_to_check:
+                    continue
+                # Don't remove anything based on the environment name, which
+                # always follows "envs" in the path
+                if (i_token > 1) and (tokens[i_token-1] == 'envs'):
+                    continue
+                if ('yolov5' in token) or ('yolov9' in token) or ('ultralytics' in token):
+                    if verbose:
+                        print('Module {} ({}) looks deletable'.format(module_name,module_file))
                     modules_to_delete.append(module_name)
-            break
-        except Exception:
+                    break
+        except Exception as e:
+            if verbose:
+                print('Exception during module review: {}'.format(str(e)))
             pass
-
+
+    # ...for each module in the global namespace
+
     for module_name in modules_to_delete:
+
         if module_name in sys.modules.keys():
-            module_file = module.__file__.replace('\\','/')
             if verbose:
-
+                try:
+                    module = sys.modules[module_name]
+                    module_file = module.__file__.replace('\\','/')
+                    print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
+                except Exception:
+                    pass
             del sys.modules[module_name]
-
+
+    # ...for each module we want to remove from the global namespace
+
     paths_to_delete = []
-
+
     for p in sys.path:
         if p.endswith('yolov5') or p.endswith('yolov9') or p.endswith('ultralytics'):
             print('clean_yolo_imports: removing {} from path'.format(p))
             paths_to_delete.append(p)
-
+
     for p in paths_to_delete:
         sys.path.remove(p)
 
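The new cleanup path decides module-by-module whether a loaded module belongs to a YOLO library by scanning only the last few path components of its `__file__`, and skipping the conda environment name. A self-contained sketch of that test, with illustrative sample paths:

```python
def looks_like_yolo_module(module_file, aggressive_cleanup=False):
    """Return True if a module's file path suggests it belongs to a YOLO library."""
    module_file = module_file.replace('\\', '/')
    # By default, only consider modules installed in a site-packages folder
    if not aggressive_cleanup and 'site-packages' not in module_file:
        return False
    tokens = module_file.split('/')
    # Only the last few path components can name the library itself
    first_token_to_check = len(tokens) - 4
    for i_token, token in enumerate(tokens):
        if i_token < first_token_to_check:
            continue
        # Skip the environment name, which always follows "envs"
        if i_token > 1 and tokens[i_token - 1] == 'envs':
            continue
        if ('yolov5' in token) or ('yolov9' in token) or ('ultralytics' in token):
            return True
    return False

assert looks_like_yolo_module('c:/git/yolov9/models/common.py', aggressive_cleanup=True)
assert not looks_like_yolo_module(
    'c:/users/user/miniforge3/envs/yolov5-env/lib/site-packages/numpy/__init__.py')
```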
@@ -228,52 +266,67 @@ def _initialize_yolo_imports(model_type='yolov5',
                              force_reimport=False,
                              verbose=False):
     """
-    Imports required functions from one or more yolo libraries (yolov5, yolov9,
+    Imports required functions from one or more yolo libraries (yolov5, yolov9,
     ultralytics, targeting support for [model_type]).
-
+
     Args:
         model_type (str): The model type for which we're loading support
-        allow_fallback_import (bool, optional): If we can't import from the package for
-            which we're trying to load support, fall back to "import utils". This is
+        allow_fallback_import (bool, optional): If we can't import from the package for
+            which we're trying to load support, fall back to "import utils". This is
             typically used when the right support library is on the current PYTHONPATH.
-        force_reimport (bool, optional): import the appropriate libraries even if the
+        force_reimport (bool, optional): import the appropriate libraries even if the
             requested model type matches the current initialization state
         verbose (bool, optional): include additional debug output
-
+
     Returns:
         str: the model type for which we initialized support
     """
-
+
+    # When running in pytest, the megadetector 'utils' module is put in the global
+    # namespace, which creates conflicts with yolov5; remove it from the global
+    # namespsace.
+    if ('PYTEST_CURRENT_TEST' in os.environ):
+        print('*** pytest detected ***')
+        if ('utils' in sys.modules):
+            utils_module = sys.modules['utils']
+            if hasattr(utils_module, '__file__') and 'megadetector' in str(utils_module.__file__):
+                print(f"Removing conflicting utils module: {utils_module.__file__}")
+                sys.modules.pop('utils', None)
+                # Also remove any submodules
+                to_remove = [name for name in sys.modules if name.startswith('utils.')]
+                for name in to_remove:
+                    sys.modules.pop(name, None)
+
     global yolo_model_type_imported
-
+
     if model_type is None:
         model_type = 'yolov5'
-
+
     # The point of this function is to make the appropriate version
     # of the following functions available at module scope
     global non_max_suppression
     global xyxy2xywh
     global letterbox
     global scale_coords
-
+
     if yolo_model_type_imported is not None:
-        if yolo_model_type_imported == model_type:
+        if (yolo_model_type_imported == model_type) and (not force_reimport):
             print('Bypassing imports for YOLO model type {}'.format(model_type))
             return
         else:
             _clean_yolo_imports()
-
+
     try_yolov5_import = (model_type == 'yolov5')
     try_yolov9_import = (model_type == 'yolov9')
     try_ultralytics_import = (model_type == 'ultralytics')
-
+
     utils_imported = False
-
+
     # First try importing from the yolov5 package; this is how the pip
     # package finds YOLOv5 utilities.
     if try_yolov5_import and not utils_imported:
-
-        try:
+
+        try:
             from yolov5.utils.general import non_max_suppression, xyxy2xywh # noqa
             from yolov5.utils.augmentations import letterbox # noqa
             try:
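The pytest guard added above removes a conflicting top-level `utils` module, plus all of its submodules, from `sys.modules` so the YOLO libraries can import their own `utils` package cleanly. The same purge pattern as a generic sketch:

```python
import sys

def purge_module(name):
    """Drop a top-level module and its submodules so a later import resolves fresh."""
    sys.modules.pop(name, None)
    for sub in [m for m in sys.modules if m.startswith(name + '.')]:
        sys.modules.pop(sub, None)
```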
@@ -283,109 +336,127 @@ def _initialize_yolo_imports(model_type='yolov5',
             utils_imported = True
             if verbose:
                 print('Imported utils from YOLOv5 package')
-
-        except Exception as e: # noqa
-
+
+        except Exception as e: # noqa
             # print('yolov5 module import failed: {}'.format(e))
-            # print(traceback.format_exc())
+            # print(traceback.format_exc())
             pass
-
+
     # Next try importing from the yolov9 package
     if try_yolov9_import and not utils_imported:
-
+
         try:
-
+
             from yolov9.utils.general import non_max_suppression, xyxy2xywh # noqa
             from yolov9.utils.augmentations import letterbox # noqa
             from yolov9.utils.general import scale_boxes as scale_coords # noqa
             utils_imported = True
             if verbose:
                 print('Imported utils from YOLOv9 package')
-
+
         except Exception as e: # noqa
-
+
             # print('yolov9 module import failed: {}'.format(e))
             # print(traceback.format_exc())
             pass
-
-    # If we haven't succeeded yet, import from the ultralytics package
+
+    # If we haven't succeeded yet, import from the ultralytics package
     if try_ultralytics_import and not utils_imported:
-
+
         try:
-
-            import ultralytics # noqa
-
+
+            import ultralytics # type: ignore # noqa
+
         except Exception:
-
+
             print('It looks like you are trying to run a model that requires the ultralytics package, '
                   'but the ultralytics package is not installed, but . For licensing reasons, this '
                   'is not installed by default with the MegaDetector Python package. Run '
                   '"pip install ultralytics" to install it, and try again.')
             raise
-
+
         try:
-
-            from ultralytics.utils.ops import non_max_suppression # noqa
-            from ultralytics.utils.ops import xyxy2xywh # noqa
-
+
+            from ultralytics.utils.ops import non_max_suppression # type: ignore # noqa
+            from ultralytics.utils.ops import xyxy2xywh # type: ignore # noqa
+
             # In the ultralytics package, scale_boxes and scale_coords both exist;
             # we want scale_boxes.
-            #
+            #
             # from ultralytics.utils.ops import scale_coords # noqa
-            from ultralytics.utils.ops import scale_boxes as scale_coords # noqa
-            from ultralytics.data.augment import LetterBox
-
-            # letterbox() became a LetterBox class in the ultralytics package. Create a
+            from ultralytics.utils.ops import scale_boxes as scale_coords # type: ignore # noqa
+            from ultralytics.data.augment import LetterBox # type: ignore # noqa
+
+            # letterbox() became a LetterBox class in the ultralytics package. Create a
             # backwards-compatible letterbox function wrapper that wraps the class up.
-            def letterbox(img,new_shape,auto=False,scaleFill=False,
-
-
-
-
+            def letterbox(img,new_shape,auto=False,scaleFill=False, #noqa
+                          scaleup=True,center=True,stride=32):
+
+                # Ultralytics changed the "scaleFill" parameter to "scale_fill", we want to support
+                # both conventions.
+                use_old_scalefill_arg = False
+                try:
+                    sig = inspect.signature(LetterBox.__init__)
+                    if 'scaleFill' in sig.parameters:
+                        use_old_scalefill_arg = True
+                except Exception:
+                    pass
+
+                if use_old_scalefill_arg:
+                    if verbose:
+                        print('Using old scaleFill calling convention')
+                    letterbox_transformer = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,
+                                                      scaleup=scaleup,center=center,stride=stride)
+                else:
+                    letterbox_transformer = LetterBox(new_shape,auto=auto,scale_fill=scaleFill,
+                                                      scaleup=scaleup,center=center,stride=stride)
+
+                letterbox_result = letterbox_transformer(image=img)
+
                 if isinstance(new_shape,int):
                     new_shape = [new_shape,new_shape]
-
+
                 # The letterboxing is done, we just need to reverse-engineer what it did
                 shape = img.shape[:2]
-
+
                 r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
                 if not scaleup:
                     r = min(r, 1.0)
                 ratio = r, r
-
+
                 new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
                 dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
                 if auto:
                     dw, dh = np.mod(dw, stride), np.mod(dh, stride)
-                elif scaleFill:
+                elif scaleFill:
                     dw, dh = 0.0, 0.0
                     new_unpad = (new_shape[1], new_shape[0])
                     ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0])
-
+
                 dw /= 2
                 dh /= 2
                 pad = (dw,dh)
-
+
                 return [letterbox_result,ratio,pad]
-
+
             utils_imported = True
             if verbose:
                 print('Imported utils from ultralytics package')
-
+
         except Exception:
-
+
             # print('Ultralytics module import failed')
             pass
-
+
     # If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
     if (not utils_imported) and allow_fallback_import:
-
+
         try:
-
+
             # import pre- and post-processing functions from the YOLOv5 repo
             from utils.general import non_max_suppression, xyxy2xywh # noqa
             from utils.augmentations import letterbox # noqa
-
+
             # scale_coords() is scale_boxes() in some YOLOv5 versions
             try:
                 from utils.general import scale_coords # noqa
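The wrapper above bridges an ultralytics API rename by introspecting which keyword spelling the installed `LetterBox` constructor actually declares. The trick in isolation, as a sketch that assumes only that a `LetterBox`-like class is passed in:

```python
import inspect

def make_letterbox(LetterBox, new_shape, scale_fill=False):
    """Construct a LetterBox with whichever scaleFill/scale_fill keyword is supported."""
    kwargs = {'auto': False, 'scaleup': False}
    # Older ultralytics versions spell the argument "scaleFill"; newer ones
    # use "scale_fill". Pick the spelling the constructor actually declares.
    if 'scaleFill' in inspect.signature(LetterBox.__init__).parameters:
        kwargs['scaleFill'] = scale_fill
    else:
        kwargs['scale_fill'] = scale_fill
    return LetterBox(new_shape, **kwargs)
```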
@@ -395,58 +466,58 @@ def _initialize_yolo_imports(model_type='yolov5',
             imported_file = sys.modules[scale_coords.__module__].__file__
             if verbose:
                 print('Imported utils from {}'.format(imported_file))
-
+
         except ModuleNotFoundError as e:
-
+
             raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e)))
-
+
     assert utils_imported, 'YOLO utils import error'
-
+
     yolo_model_type_imported = model_type
     if verbose:
         print('Prepared YOLO imports for model type {}'.format(model_type))
-
+
     return model_type
 
 # ...def _initialize_yolo_imports(...)
-
+
 
 #%% Model metadata functions
 
-def add_metadata_to_megadetector_model_file(model_file_in,
+def add_metadata_to_megadetector_model_file(model_file_in,
                                             model_file_out,
-                                            metadata,
+                                            metadata,
                                             destination_path='megadetector_info.json'):
     """
-    Adds a .json file to the specified MegaDetector model file containing metadata used
+    Adds a .json file to the specified MegaDetector model file containing metadata used
     by this module. Always over-writes the output file.
-
+
     Args:
         model_file_in (str): The input model filename, typically .pt (.zip is also sensible)
         model_file_out (str): The output model filename, typically .pt (.zip is also sensible).
             May be the same as model_file_in.
         metadata (dict): The metadata dict to add to the output model file
-        destination_path (str, optional): The relative path within the main folder of the
-            model archive where we should write the metadata. This is not relative to the root
+        destination_path (str, optional): The relative path within the main folder of the
+            model archive where we should write the metadata. This is not relative to the root
             of the archive, it's relative to the one and only folder at the root of the archive
-            (this is a PyTorch convention).
+            (this is a PyTorch convention).
     """
-
+
     tmp_base = os.path.join(tempfile.gettempdir(),'md_metadata')
     os.makedirs(tmp_base,exist_ok=True)
     metadata_tmp_file_relative = 'megadetector_info_' + str(uuid.uuid1()) + '.json'
-    metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)
+    metadata_tmp_file_abs = os.path.join(tmp_base,metadata_tmp_file_relative)
 
     with open(metadata_tmp_file_abs,'w') as f:
         json.dump(metadata,f,indent=1)
-
+
     # Copy the input file to the output file
     shutil.copyfile(model_file_in,model_file_out)
 
     # Write metadata to the output file
     with zipfile.ZipFile(model_file_out, 'a', compression=zipfile.ZIP_DEFLATED) as zipf:
-
-        # Torch doesn't like anything in the root folder of the zipfile, so we put
+
+        # Torch doesn't like anything in the root folder of the zipfile, so we put
         # it in the one and only folder.
         names = zipf.namelist()
         root_folders = set()
@@ -456,9 +527,9 @@ def add_metadata_to_megadetector_model_file(model_file_in,
         assert len(root_folders) == 1,\
             'This archive does not have exactly one folder at the top level; are you sure it\'s a Torch model file?'
         root_folder = next(iter(root_folders))
-
-        zipf.write(metadata_tmp_file_abs,
-                   root_folder + '/' + destination_path,
+
+        zipf.write(metadata_tmp_file_abs,
+                   root_folder + '/' + destination_path,
                    compresslevel=9,
                    compress_type=zipfile.ZIP_DEFLATED)
 
@@ -466,7 +537,7 @@ def add_metadata_to_megadetector_model_file(model_file_in,
         os.remove(metadata_tmp_file_abs)
     except Exception as e:
         print('Warning: error deleting file {}: {}'.format(metadata_tmp_file_abs,str(e)))
-
+
 # ...def add_metadata_to_megadetector_model_file(...)
 
 
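Hypothetical usage of `add_metadata_to_megadetector_model_file()`; the keys shown (`model_type`, `image_size`) are the ones this module reads back elsewhere in the diff, and the file names are illustrative:

```python
add_metadata_to_megadetector_model_file(
    model_file_in='md_v5a.0.0.pt',
    model_file_out='md_v5a.0.0-with-metadata.pt',
    metadata={'model_type': 'yolov5', 'image_size': 1280})
```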
@@ -475,22 +546,23 @@ def read_metadata_from_megadetector_model_file(model_file,
                                                verbose=False):
     """
     Reads custom MegaDetector metadata from a modified MegaDetector model file.
-
+
     Args:
         model_file (str): The model filename to read, typically .pt (.zip is also sensible)
-        relative_path (str, optional): The relative path within the main folder of the model
-            archive from which we should read the metadata. This is not relative to the root
+        relative_path (str, optional): The relative path within the main folder of the model
+            archive from which we should read the metadata. This is not relative to the root
             of the archive, it's relative to the one and only folder at the root of the archive
-            (this is a PyTorch convention).
-
+            (this is a PyTorch convention).
+        verbose (str, optional): enable additional debug output
+
     Returns:
-        object:
-
+        object: whatever we read from the metadata file, always a dict in practice. Returns
+            None if we failed to read the specified metadata file.
     """
-
+
     with zipfile.ZipFile(model_file,'r') as zipf:
-
-        # Torch doesn't like anything in the root folder of the zipfile, so we put
+
+        # Torch doesn't like anything in the root folder of the zipfile, so we put
         # it in the one and only folder.
         names = zipf.namelist()
         root_folders = set()
@@ -498,17 +570,19 @@ def read_metadata_from_megadetector_model_file(model_file,
             root_folder = name.split('/')[0]
             root_folders.add(root_folder)
         if len(root_folders) != 1:
-            print('Warning: this archive does not have exactly one folder at the top level;
+            print('Warning: this archive does not have exactly one folder at the top level; ' + \
+                  'are you sure it\'s a Torch model file?')
             return None
         root_folder = next(iter(root_folders))
-
+
         metadata_file = root_folder + '/' + relative_path
         if metadata_file not in names:
             # This is the case for MDv5a and MDv5b
             if verbose:
-                print('Warning: could not find metadata file {} in zip archive'.format(
+                print('Warning: could not find metadata file {} in zip archive {}'.format(
+                    metadata_file,os.path.basename(model_file)))
             return None
-
+
         try:
             path = zipfile.Path(zipf,metadata_file)
             contents = path.read_text()
@@ -516,9 +590,9 @@ def read_metadata_from_megadetector_model_file(model_file,
         except Exception as e:
             print('Warning: error reading metadata from path {}: {}'.format(metadata_file,str(e)))
             return None
-
+
     return d
-
+
 # ...def read_metadata_from_megadetector_model_file(...)
 
 
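The reading side relies on a PyTorch .pt file being a zip archive with exactly one top-level folder, and the metadata JSON living inside that folder. A standalone sketch of that pattern, with an illustrative helper name:

```python
import json
import zipfile

def read_json_from_pt(model_file, relative_path='megadetector_info.json'):
    """Read a JSON file stored inside the single top-level folder of a .pt archive."""
    with zipfile.ZipFile(model_file, 'r') as zipf:
        root_folders = {name.split('/')[0] for name in zipf.namelist()}
        if len(root_folders) != 1:
            return None  # not a conventional torch archive
        metadata_file = next(iter(root_folders)) + '/' + relative_path
        if metadata_file not in zipf.namelist():
            return None  # e.g. MDv5a/MDv5b, which carry no embedded metadata
        return json.loads(zipfile.Path(zipf, metadata_file).read_text())
```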
@@ -526,27 +600,33 @@ def read_metadata_from_megadetector_model_file(model_file,
 
 default_compatibility_mode = 'classic'
 
-# This is a useful hack when I want to verify that my test driver (md_tests.py) is
+# This is a useful hack when I want to verify that my test driver (md_tests.py) is
 # correctly forcing a specific compatibility mode (I use "classic-test" in that case)
 require_non_default_compatibility_mode = False
 
 class PTDetector:
-
+    """
+    Class that runs a PyTorch-based MegaDetector model.
+    """
+
     def __init__(self, model_path, detector_options=None, verbose=False):
-
+
+        if verbose:
+            print('Initializing PTDetector (verbose)')
+
         # Set up the import environment for this model, unloading previous
         # YOLO library versions if necessary.
         _initialize_yolo_imports_for_model(model_path,
                                            detector_options=detector_options,
                                            verbose=verbose)
-
+
         # Parse options specific to this detector family
         force_cpu = False
-        use_model_native_classes = False
+        use_model_native_classes = False
         compatibility_mode = default_compatibility_mode
-
+
         if detector_options is not None:
-
+
             if 'force_cpu' in detector_options:
                 force_cpu = parse_bool_string(detector_options['force_cpu'])
             if 'use_model_native_classes' in detector_options:
@@ -554,66 +634,66 @@ class PTDetector:
             if 'compatibility_mode' in detector_options:
                 if detector_options['compatibility_mode'] is None:
                     compatibility_mode = default_compatibility_mode
-                else:
-                    compatibility_mode = detector_options['compatibility_mode']
-
+                else:
+                    compatibility_mode = detector_options['compatibility_mode']
+
         if require_non_default_compatibility_mode:
-
+
             print('### DEBUG: requiring non-default compatibility mode ###')
             assert compatibility_mode != 'classic'
             assert compatibility_mode != 'default'
-
+
         preprocess_only = False
         if (detector_options is not None) and \
            ('preprocess_only' in detector_options) and \
            (detector_options['preprocess_only']):
             preprocess_only = True
-
+
         if verbose or (not preprocess_only):
             print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))
-
+
         model_metadata = read_metadata_from_megadetector_model_file(model_path)
-
-        #: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
+
+        #: Image size passed to the letterbox() function; 1280 means "1280 on the long side, preserving
         #: aspect ratio".
         if model_metadata is not None and 'image_size' in model_metadata:
             self.default_image_size = model_metadata['image_size']
-
-            print('Loaded image size {} from model metadata'.format(self.default_image_size))
+            print('Loaded image size {} from model metadata'.format(self.default_image_size))
         else:
+            print('No image size available in model metadata, defaulting to 1280')
             self.default_image_size = 1280
-
+
         #: Either a string ('cpu','cuda:0') or a torch.device()
         self.device = 'cpu'
-
+
         #: Have we already printed a warning about using a non-standard image size?
         #:
         #: :meta private:
         self.printed_image_size_warning = False
-
+
         #: If this is False, we assume the underlying model is producing class indices in the
         #: set (0,1,2) (and we assert() on this), and we add 1 to get to the backwards-compatible
-        #: MD classes (1,2,3) before generating output. If this is True, we use whatever
+        #: MD classes (1,2,3) before generating output. If this is True, we use whatever
         #: indices the model provides
-        self.use_model_native_classes = use_model_native_classes
-
+        self.use_model_native_classes = use_model_native_classes
+
         #: This allows us to maintain backwards compatibility across a set of changes to the
-        #: way this class does inference. Currently should start with either "default" or
+        #: way this class does inference. Currently should start with either "default" or
         #: "classic".
         self.compatibility_mode = compatibility_mode
-
+
         #: Stride size passed to YOLOv5's letterbox() function
         self.letterbox_stride = 32
-
+
         if 'classic' in self.compatibility_mode:
             self.letterbox_stride = 64
-
+
         #: Use half-precision inference... fixed by the model, generally don't mess with this
         self.half_precision = False
-
+
         if preprocess_only:
             return
-
+
         if not force_cpu:
             if torch.cuda.is_available():
                 self.device = torch.device('cuda:0')
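Hypothetical construction of a `PTDetector` with the options parsed in `__init__` above; option values often arrive as strings when they come from a command line, which is why `force_cpu` goes through `parse_bool_string()` in the code:

```python
detector = PTDetector('md_v5a.0.0.pt',
                      detector_options={'force_cpu': 'true',
                                        'compatibility_mode': 'classic'},
                      verbose=True)
```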
@@ -623,10 +703,10 @@ class PTDetector:
             except AttributeError:
                 pass
         try:
-            self.model = PTDetector._load_model(model_path,
-                                                device=self.device,
+            self.model = PTDetector._load_model(model_path,
+                                                device=self.device,
                                                 compatibility_mode=self.compatibility_mode)
-
+
         except Exception as e:
             # In a very esoteric scenario where an old version of YOLOv5 is used to run
             # newer models, we run into an issue because the "Model" class became
@@ -636,21 +716,21 @@ class PTDetector:
                 print('Forward-compatibility issue detected, patching')
                 from models import yolo
                 yolo.DetectionModel = yolo.Model
-                self.model = PTDetector._load_model(model_path,
+                self.model = PTDetector._load_model(model_path,
                                                     device=self.device,
                                                     compatibility_mode=self.compatibility_mode,
-                                                    verbose=verbose)
+                                                    verbose=verbose)
             else:
                 raise
         if (self.device != 'cpu'):
             if verbose:
                 print('Sending model to GPU')
             self.model.to(self.device)
-
+
 
     @staticmethod
     def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):
-
+
         if verbose:
             print(f'Using PyTorch version {torch.__version__}')
 
@@ -661,12 +741,12 @@ class PTDetector:
         # very slight changes to the output, which always make me nervous, so I'm not
         # doing a wholesale swap just yet. Instead, we'll just do this on M1 hardware.
         if 'classic' in compatibility_mode:
-            use_map_location = (device != 'mps')
+            use_map_location = (device != 'mps')
         else:
             use_map_location = False
-
+
         if use_map_location:
-            try:
+            try:
                 checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False)
             # For a transitional period, we want to support torch 1.1x, where the weights_only
             # parameter doesn't exist
@@ -682,31 +762,31 @@ class PTDetector:
             # parameter doesn't exist
             except Exception as e:
                 if "'weights_only' is an invalid keyword" in str(e):
-                    checkpoint = torch.load(model_pt_path)
+                    checkpoint = torch.load(model_pt_path)
                 else:
                     raise
-
-        # Compatibility fix that allows us to load older YOLOv5 models with
+
+        # Compatibility fix that allows us to load older YOLOv5 models with
         # newer versions of YOLOv5/PT
         for m in checkpoint['model'].modules():
             t = type(m)
             if t is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
                 m.recompute_scale_factor = None
-
-        if use_map_location:
-            model = checkpoint['model'].float().fuse().eval()
+
+        if use_map_location:
+            model = checkpoint['model'].float().fuse().eval()
         else:
-            model = checkpoint['model'].float().fuse().eval().to(device)
-
+            model = checkpoint['model'].float().fuse().eval().to(device)
+
         return model
 
     # ...def _load_model(...)
-
 
-
-
-
-
+
+    def generate_detections_one_image(self,
+                                      img_original,
+                                      image_id='unknown',
+                                      detection_threshold=0.00001,
                                       image_size=None,
                                       skip_image_resizing=False,
                                       augment=False,
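A minimal sketch of the transitional `torch.load()` pattern in the hunks above: prefer `weights_only=False` (required for full checkpoint loading on torch 2.x), and retry without the keyword on torch 1.1x, where it doesn't exist:

```python
import torch

def load_checkpoint(model_pt_path, device):
    """Load a checkpoint across torch versions with and without weights_only."""
    try:
        return torch.load(model_pt_path, map_location=device, weights_only=False)
    except Exception as e:
        if "'weights_only' is an invalid keyword" in str(e):
            # torch 1.1x: the parameter doesn't exist, retry without it
            return torch.load(model_pt_path)
        raise
```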
@@ -716,16 +796,16 @@ class PTDetector:
         Applies the detector to an image.
 
         Args:
-            img_original (Image): the PIL Image object (or numpy array) on which we should run the
+            img_original (Image): the PIL Image object (or numpy array) on which we should run the
                 detector, with EXIF rotation already handled
-            image_id (str, optional): a path to identify the image; will be in the "file" field
+            image_id (str, optional): a path to identify the image; will be in the "file" field
                 of the output object
-            detection_threshold (float, optional): only detections above this confidence threshold
+            detection_threshold (float, optional): only detections above this confidence threshold
                 will be included in the return value
-            image_size (
+            image_size (int, optional): image size to use for inference, only mess with this if
                 (a) you're using a model other than MegaDetector or (b) you know what you're getting into
-            skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
-                external resizing), only mess with this if (a) you're using a model other than MegaDetector
+            skip_image_resizing (bool, optional): whether to skip internal image resizing (and rely on
+                external resizing), only mess with this if (a) you're using a model other than MegaDetector
                 or (b) you know what you're getting into
             augment (bool, optional): enable (implementation-specific) image augmentation
             preprocess_only (bool, optional): only run preprocessing, and return the preprocessed image
@@ -747,17 +827,17 @@ class PTDetector:
             assert 'classic' in self.compatibility_mode, \
                 'Standalone preprocessing only supported in "classic" mode'
             assert not skip_image_resizing, \
-                'skip_image_resizing and preprocess_only are exclusive'
-
+                'skip_image_resizing and preprocess_only are exclusive'
+
         if detection_threshold is None:
-
+
             detection_threshold = 0
-
+
         try:
-
+
             # If the caller wants us to skip all the resizing operations...
             if skip_image_resizing:
-
+
                 if isinstance(img_original,dict):
                     image_info = img_original
                     img = image_info['img_processed']
@@ -768,107 +848,107 @@ class PTDetector:
                     img_original_pil = image_info['img_original_pil']
                 else:
                     img = img_original
-
+
             else:
-
+
                 img_original_pil = None
                 # If we were given a PIL image
-
+
                 if not isinstance(img_original,np.ndarray):
-                    img_original_pil = img_original
+                    img_original_pil = img_original
                     img_original = np.asarray(img_original)
-
+
                     # PIL images are RGB already
                     # img_original = img_original[:, :, ::-1]
-
+
                 # Save the original shape for scaling boxes later
                 scaling_shape = img_original.shape
-
+
                 # If the caller is requesting a specific target size...
                 if image_size is not None:
-
+
                     assert isinstance(image_size,int)
-
+
                     if not self.printed_image_size_warning:
                         print('Using user-supplied image size {}'.format(image_size))
-                        self.printed_image_size_warning = True
-
+                        self.printed_image_size_warning = True
+
                 # Otherwise resize to self.default_image_size
                 else:
-
+
                     image_size = self.default_image_size
                     self.printed_image_size_warning = False
-
+
                 # ...if the caller has specified an image size
-
+
                 # In "classic mode", we only do the letterboxing resize, we don't do an
                 # additional initial resizing operation
                 if 'classic' in self.compatibility_mode:
-
+
                     resize_ratio = 1.0
-
-                # Resize the image so the long side matches the target image size. This is not
+
+                # Resize the image so the long side matches the target image size. This is not
                 # letterboxing (i.e., padding) yet, just resizing.
                 else:
-
+
                     use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)
-
+
                     h,w = img_original.shape[:2]
                     resize_ratio = image_size / max(h,w)
-
+
                     # Only resize if we have to
                     if resize_ratio != 1:
-
-                        # Match what yolov5 does: use linear interpolation for upsizing;
+
+                        # Match what yolov5 does: use linear interpolation for upsizing;
                         # area interpolation for downsizing
                         if resize_ratio > 1:
                             interpolation_method = cv2.INTER_LINEAR
                         else:
-                            interpolation_method = cv2.INTER_AREA
-
+                            interpolation_method = cv2.INTER_AREA
+
                         if use_ceil_for_resize:
                             target_w = math.ceil(w * resize_ratio)
                             target_h = math.ceil(h * resize_ratio)
                         else:
                             target_w = int(w * resize_ratio)
                             target_h = int(h * resize_ratio)
-
+
                         img_original = cv2.resize(
                             img_original, (target_w, target_h),
                             interpolation=interpolation_method)

                 if 'classic' in self.compatibility_mode:
-
+
                     letterbox_auto = True
                     letterbox_scaleup = True
                     target_shape = image_size
-
+
                 else:
-
+
                     letterbox_auto = False
                     letterbox_scaleup = False
-
+
                     # The padding to apply as a fraction of the stride size
                     pad = 0.5
-
+
                     model_stride = int(self.model.stride.max())
-
+
                     max_dimension = max(img_original.shape)
                     normalized_shape = [img_original.shape[0] / max_dimension,
                                         img_original.shape[1] / max_dimension]
                     target_shape = np.ceil(np.array(normalized_shape) * image_size / model_stride + \
                                            pad).astype(int) * model_stride
-
+
                 # Now we letterbox, which is just padding, since we've already resized.
-                img,letterbox_ratio,letterbox_pad = letterbox(img_original,
+                img,letterbox_ratio,letterbox_pad = letterbox(img_original,
                                                               new_shape=target_shape,
-                                                              stride=self.letterbox_stride,
+                                                              stride=self.letterbox_stride,
                                                               auto=letterbox_auto,
                                                               scaleFill=False,
                                                               scaleup=letterbox_scaleup)
-
+
             if preprocess_only:
-
+
                 assert 'file' in result
                 result['img_processed'] = img
                 result['img_original'] = img_original
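In the non-"classic" branch above, preprocessing happens in two steps: an aspect-preserving resize so the long side matches image_size, then letterbox padding to a stride-aligned target shape. The target-shape arithmetic is the subtle part; here is a self-contained re-derivation (illustrative helper, not part of the package):

import numpy as np

def letterbox_target_shape(h, w, image_size, model_stride, pad=0.5):
    # Normalize the shape so the long side is 1.0, scale to the target size,
    # then round each side up to a multiple of the model stride, with an
    # extra half-stride of padding (mirrors the computation in the diff).
    max_dimension = max(h, w)
    normalized_shape = np.array([h / max_dimension, w / max_dimension])
    return np.ceil(normalized_shape * image_size / model_stride + pad).astype(int) * model_stride

# A 1536x2048 image with a 1280-pixel target and stride 64 (illustrative values):
print(letterbox_target_shape(1536, 2048, 1280, 64))  # [1024 1344]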
@@ -878,14 +958,14 @@ class PTDetector:
                 result['letterbox_ratio'] = letterbox_ratio
                 result['letterbox_pad'] = letterbox_pad
                 return result
-
+
             # ...are we doing resizing here, or were images already resized?
-
+
             # Convert HWC to CHW (which is what the model expects). The PIL Image is RGB already,
             # so we don't need to mess with the color channels.
             #
             # TODO, this could be moved into the preprocessing loop
-
+
             img = img.transpose((2, 0, 1))  # [::-1]
             img = np.ascontiguousarray(img)
             img = torch.from_numpy(img)
@@ -893,8 +973,8 @@ class PTDetector:
             img = img.half() if self.half_precision else img.float()
             img /= 255

-            # In practice this is always true
-            if len(img.shape) == 3:
+            # In practice this is always true
+            if len(img.shape) == 3:
                 img = torch.unsqueeze(img, 0)

             # Run the model
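Together, the last two hunks implement the standard YOLOv5 input pipeline: HWC uint8 array to CHW float tensor in [0,1], plus a batch dimension. A condensed, runnable sketch (shapes are illustrative):

import numpy as np
import torch

# A letterboxed RGB image as produced above: height x width x channels, uint8
img = np.zeros((1280, 1280, 3), dtype=np.uint8)  # illustrative placeholder

img = img.transpose((2, 0, 1))       # HWC -> CHW
img = np.ascontiguousarray(img)      # transpose returns a view; make it contiguous
img = torch.from_numpy(img).float()  # or .half() when running at half precision
img /= 255                           # uint8 range [0,255] -> float range [0,1]
if len(img.shape) == 3:              # in practice always true for a single image
    img = torch.unsqueeze(img, 0)    # add a batch dimension
print(img.shape)                     # torch.Size([1, 3, 1280, 1280])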
@@ -908,19 +988,19 @@ class PTDetector:
             else:
                 nms_conf_thres = detection_threshold # 0.01
                 nms_iou_thres = 0.6
-                nms_agnostic = False
+                nms_agnostic = False
                 nms_multi_label = True
-
+
             # As of PyTorch 1.13.0.dev20220824, nms is not implemented for MPS.
             #
-            # Send predictions back to the CPU for NMS.
+            # Send predictions back to the CPU for NMS.
             if self.device == 'mps':
                 pred_nms = pred.cpu()
             else:
                 pred_nms = pred
-
+
             # NMS
-            pred = non_max_suppression(prediction=pred_nms,
+            pred = non_max_suppression(prediction=pred_nms,
                                        conf_thres=nms_conf_thres,
                                        iou_thres=nms_iou_thres,
                                        agnostic=nms_agnostic,
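non_max_suppression here is YOLOv5's implementation, which takes the raw [batch, nCandidates, 5+nClasses] prediction tensor and returns one [nBoxes, 6] tensor per image. A hedged sketch of the call pattern with the thresholds used above (the import path assumes a yolov5 checkout on the path; MegaDetector may vendor this function elsewhere):

import torch
from utils.general import non_max_suppression  # yolov5's implementation (assumed location)

# Dummy raw predictions: batch of 1, 100 candidate boxes, 3 classes.
# Columns are cx, cy, w, h, objectness, then one score per class.
pred = torch.rand(1, 100, 8)

# On MPS devices, predictions would go back to the CPU first (see the comment
# above); torch.rand already lives on the CPU here.
pred = non_max_suppression(prediction=pred,
                           conf_thres=0.01,   # nms_conf_thres in the diff
                           iou_thres=0.6,     # nms_iou_thres in the diff
                           agnostic=False,    # nms_agnostic in the diff
                           multi_label=True)  # nms_multi_label in the diff
print(pred[0].shape)  # [nBoxes, 6]: x0, y0, x1, y1, confidence, class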
@@ -930,26 +1010,26 @@ class PTDetector:
             gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]

             if 'classic' in self.compatibility_mode:
-
+
                 ratio = None
                 ratio_pad = None
-
+
             else:
-
+
                 # letterbox_pad is a 2-tuple specifying the padding that was added on each axis.
                 #
                 # ratio is a 2-tuple specifying the scaling that was applied to each dimension.
                 #
                 # The scale_boxes function expects a 2-tuple with these things combined.
                 ratio = (img_original.shape[0]/scaling_shape[0], img_original.shape[1]/scaling_shape[1])
-                ratio_pad = (ratio, letterbox_pad)
-
+                ratio_pad = (ratio, letterbox_pad)
+
             # This is a loop over detection batches, which will always be length 1 in our case,
             # since we're not doing batch inference.
             #
             # det = pred[0]
             #
-            # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
+            # det is a torch.Tensor with size [nBoxes,6]. In practice the boxes are sorted
             # in descending order by confidence.
             #
             # Columns are:
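gn above is a (width, height, width, height) tensor used to normalize xyxy boxes against the original image size, and ratio_pad packages the resize ratio together with the letterbox padding in the combined form scale_boxes expects. A small worked example of the gn indexing trick (values are illustrative):

import torch

# Original image shape as returned by numpy: (height, width, channels)
scaling_shape = (1536, 2048, 3)

# Indexing with [1, 0, 1, 0] picks out (width, height, width, height),
# matching the (x0, y0, x1, y1) box layout
gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
print(gn)  # tensor([2048, 1536, 2048, 1536])

# Dividing a pixel-space xyxy box by gn normalizes it to [0, 1]
box_xyxy = torch.tensor([512.0, 384.0, 1024.0, 768.0])
print(box_xyxy / gn)  # tensor([0.2500, 0.2500, 0.5000, 0.5000])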
@@ -959,15 +1039,15 @@ class PTDetector:
             # At this point, these are *non*-normalized values, referring to the size at which we
             # ran inference (img.shape).
             for det in pred:
-
+
                 if len(det) == 0:
                     continue
-
+
                 # Rescale boxes from img_size to im0 size, and undo the effect of padded letterboxing
                 if 'classic' in self.compatibility_mode:
-
+
                     det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_original.shape).round()
-
+
                 else:
                     # After this scaling, each element of det is a box in x0,y0,x1,y1 format, referring to the
                     # original pixel dimension of the image, followed by the class and confidence
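Undoing the letterbox amounts to subtracting the padding and dividing out the resize gain, which is what scale_coords/scale_boxes do internally. An illustrative scalar version for a single box (the helper and its arguments are hypothetical, for clarity only):

def unletterbox_box(box_xyxy, gain, pad):
    # Map an (x0, y0, x1, y1) box from letterboxed-image coordinates back to
    # original-image coordinates; 'gain' is the resize ratio, 'pad' is the
    # (pad_x, pad_y) letterbox padding. Hypothetical helper for illustration.
    pad_x, pad_y = pad
    x0, y0, x1, y1 = box_xyxy
    return ((x0 - pad_x) / gain, (y0 - pad_y) / gain,
            (x1 - pad_x) / gain, (y1 - pad_y) / gain)

# Example: a 1280x1280 letterboxed input made from a 2048x1536 original
# (gain = 1280/2048 = 0.625, vertical padding = (1280 - 1536*0.625)/2 = 160)
print(unletterbox_box((320, 480, 960, 800), gain=0.625, pad=(0, 160)))
# (512.0, 512.0, 1536.0, 1024.0)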
@@ -975,14 +1055,14 @@ class PTDetector:

                 # Loop over detections
                 for *xyxy, conf, cls in reversed(det):
-
+
                     if conf < detection_threshold:
                         continue
-
+
                     # Convert this box to normalized cx, cy, w, h (i.e., YOLO format)
                     xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()

-                    # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
+                    # Convert from normalized cx/cy/w/h (i.e., YOLO format) to normalized
                     # left/top/w/h (i.e., MD format)
                     api_box = ct_utils.convert_yolo_to_xywh(xywh)

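ct_utils.convert_yolo_to_xywh is a simple coordinate-origin shift from box-center to box-corner. A standalone equivalent for clarity (illustrative, not the package's implementation):

def yolo_to_md_box(yolo_box):
    # Convert a normalized [x_center, y_center, width, height] box (YOLO format)
    # to normalized [x_min, y_min, width, height] (MegaDetector output format).
    x_center, y_center, width, height = yolo_box
    return [x_center - width / 2.0, y_center - height / 2.0, width, height]

print(yolo_to_md_box([0.5, 0.5, 0.2, 0.4]))  # [0.4, 0.3, 0.2, 0.4]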
@@ -991,11 +1071,11 @@ class PTDetector:
                         conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)
                     else:
                         api_box = ct_utils.round_float_array(api_box, precision=COORD_DIGITS)
-                        conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
-
+                        conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
+
                     if not self.use_model_native_classes:
-                        # The MegaDetector output format's categories start at 1, but all YOLO-based
-                        # MD models have category numbers starting at 0.
+                        # The MegaDetector output format's categories start at 1, but all YOLO-based
+                        # MD models have category numbers starting at 0.
                         cls = int(cls.tolist()) + 1
                         if cls not in (1, 2, 3):
                             raise KeyError(f'{cls} is not a valid class.')
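The +1 above bridges the model's 0-based YOLO class indices and MegaDetector's 1-based output categories (1 = animal, 2 = person, 3 = vehicle). A sketch of the same mapping as a standalone helper (illustrative):

# MegaDetector output-format categories (the model's native classes are 0-based)
MD_CATEGORIES = {1: 'animal', 2: 'person', 3: 'vehicle'}

def native_class_to_md_category(cls):
    # Shift a 0-based YOLO class index to a 1-based MegaDetector category ID,
    # validating against the known category set (illustrative helper).
    cls = int(cls) + 1
    if cls not in MD_CATEGORIES:
        raise KeyError(f'{cls} is not a valid class.')
    return cls

print(native_class_to_md_category(0))  # 1 (animal)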
@@ -1008,15 +1088,15 @@ class PTDetector:
                         'bbox': api_box
                     })
                     max_conf = max(max_conf, conf)
-
+
                 # ...for each detection in this batch
-
+
             # ...for each detection batch (always one iteration)

         # ...try
-
+
         except Exception as e:
-
+
             result['failure'] = FAILURE_INFER
             print('PTDetector: image {} failed during inference: {}\n'.format(image_id, str(e)))
             # traceback.print_exc(e)
@@ -1039,12 +1119,12 @@ class PTDetector:
 if __name__ == '__main__':

     pass
-
+
     #%%
-
+
     import os #noqa
     from megadetector.visualization import visualization_utils as vis_utils
-
+
     model_file = os.environ['MDV5A']
     im_file = os.path.expanduser('~/git/MegaDetector/images/nacti.jpg')

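For context, the #%% block above is an interactive test driver; the module's `if __name__ == '__main__'` body is just `pass`. A hedged sketch of how such a driver typically continues (load_detector and load_image exist in the package, but everything beyond the lines shown in the diff is an assumption):

from megadetector.detection.run_detector import load_detector
from megadetector.visualization import visualization_utils as vis_utils
import os

model_file = os.environ['MDV5A']
im_file = os.path.expanduser('~/git/MegaDetector/images/nacti.jpg')

detector = load_detector(model_file)   # instantiates a PTDetector for a .pt file
image = vis_utils.load_image(im_file)  # PIL image, RGB
result = detector.generate_detections_one_image(image, im_file,
                                                detection_threshold=0.15)
print(result['detections'])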