megadetector-5.0.28-py3-none-any.whl → megadetector-10.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of megadetector was flagged by the registry's scanner.
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +2 -2
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +1 -1
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +1 -1
- megadetector/classification/aggregate_classifier_probs.py +3 -3
- megadetector/classification/analyze_failed_images.py +5 -5
- megadetector/classification/cache_batchapi_outputs.py +5 -5
- megadetector/classification/create_classification_dataset.py +11 -12
- megadetector/classification/crop_detections.py +10 -10
- megadetector/classification/csv_to_json.py +8 -8
- megadetector/classification/detect_and_crop.py +13 -15
- megadetector/classification/efficientnet/model.py +8 -8
- megadetector/classification/efficientnet/utils.py +6 -5
- megadetector/classification/evaluate_model.py +7 -7
- megadetector/classification/identify_mislabeled_candidates.py +6 -6
- megadetector/classification/json_to_azcopy_list.py +1 -1
- megadetector/classification/json_validator.py +29 -32
- megadetector/classification/map_classification_categories.py +9 -9
- megadetector/classification/merge_classification_detection_output.py +12 -9
- megadetector/classification/prepare_classification_script.py +19 -19
- megadetector/classification/prepare_classification_script_mc.py +26 -26
- megadetector/classification/run_classifier.py +4 -4
- megadetector/classification/save_mislabeled.py +6 -6
- megadetector/classification/train_classifier.py +1 -1
- megadetector/classification/train_classifier_tf.py +9 -9
- megadetector/classification/train_utils.py +10 -10
- megadetector/data_management/annotations/annotation_constants.py +1 -2
- megadetector/data_management/camtrap_dp_to_coco.py +79 -46
- megadetector/data_management/cct_json_utils.py +103 -103
- megadetector/data_management/cct_to_md.py +49 -49
- megadetector/data_management/cct_to_wi.py +33 -33
- megadetector/data_management/coco_to_labelme.py +75 -75
- megadetector/data_management/coco_to_yolo.py +210 -193
- megadetector/data_management/databases/add_width_and_height_to_db.py +86 -12
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +40 -40
- megadetector/data_management/databases/integrity_check_json_db.py +228 -200
- megadetector/data_management/databases/subset_json_db.py +33 -33
- megadetector/data_management/generate_crops_from_cct.py +88 -39
- megadetector/data_management/get_image_sizes.py +54 -49
- megadetector/data_management/labelme_to_coco.py +133 -125
- megadetector/data_management/labelme_to_yolo.py +159 -73
- megadetector/data_management/lila/create_lila_blank_set.py +81 -83
- megadetector/data_management/lila/create_lila_test_set.py +32 -31
- megadetector/data_management/lila/create_links_to_md_results_files.py +18 -18
- megadetector/data_management/lila/download_lila_subset.py +21 -24
- megadetector/data_management/lila/generate_lila_per_image_labels.py +365 -107
- megadetector/data_management/lila/get_lila_annotation_counts.py +35 -33
- megadetector/data_management/lila/get_lila_image_counts.py +22 -22
- megadetector/data_management/lila/lila_common.py +73 -70
- megadetector/data_management/lila/test_lila_metadata_urls.py +28 -19
- megadetector/data_management/mewc_to_md.py +344 -340
- megadetector/data_management/ocr_tools.py +262 -255
- megadetector/data_management/read_exif.py +249 -227
- megadetector/data_management/remap_coco_categories.py +90 -28
- megadetector/data_management/remove_exif.py +81 -21
- megadetector/data_management/rename_images.py +187 -187
- megadetector/data_management/resize_coco_dataset.py +588 -120
- megadetector/data_management/speciesnet_to_md.py +41 -41
- megadetector/data_management/wi_download_csv_to_coco.py +55 -55
- megadetector/data_management/yolo_output_to_md_output.py +248 -122
- megadetector/data_management/yolo_to_coco.py +333 -191
- megadetector/detection/change_detection.py +832 -0
- megadetector/detection/process_video.py +340 -337
- megadetector/detection/pytorch_detector.py +358 -278
- megadetector/detection/run_detector.py +399 -186
- megadetector/detection/run_detector_batch.py +404 -377
- megadetector/detection/run_inference_with_yolov5_val.py +340 -327
- megadetector/detection/run_tiled_inference.py +257 -249
- megadetector/detection/tf_detector.py +24 -24
- megadetector/detection/video_utils.py +332 -295
- megadetector/postprocessing/add_max_conf.py +19 -11
- megadetector/postprocessing/categorize_detections_by_size.py +45 -45
- megadetector/postprocessing/classification_postprocessing.py +468 -433
- megadetector/postprocessing/combine_batch_outputs.py +23 -23
- megadetector/postprocessing/compare_batch_results.py +590 -525
- megadetector/postprocessing/convert_output_format.py +106 -102
- megadetector/postprocessing/create_crop_folder.py +347 -147
- megadetector/postprocessing/detector_calibration.py +173 -168
- megadetector/postprocessing/generate_csv_report.py +508 -499
- megadetector/postprocessing/load_api_results.py +48 -27
- megadetector/postprocessing/md_to_coco.py +133 -102
- megadetector/postprocessing/md_to_labelme.py +107 -90
- megadetector/postprocessing/md_to_wi.py +40 -40
- megadetector/postprocessing/merge_detections.py +92 -114
- megadetector/postprocessing/postprocess_batch_results.py +319 -301
- megadetector/postprocessing/remap_detection_categories.py +91 -38
- megadetector/postprocessing/render_detection_confusion_matrix.py +214 -205
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +57 -57
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +27 -28
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +704 -679
- megadetector/postprocessing/separate_detections_into_folders.py +226 -211
- megadetector/postprocessing/subset_json_detector_output.py +265 -262
- megadetector/postprocessing/top_folders_to_bottom.py +45 -45
- megadetector/postprocessing/validate_batch_results.py +70 -70
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +52 -52
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +18 -19
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +54 -33
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +67 -67
- megadetector/taxonomy_mapping/retrieve_sample_image.py +16 -16
- megadetector/taxonomy_mapping/simple_image_download.py +8 -8
- megadetector/taxonomy_mapping/species_lookup.py +156 -74
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +14 -14
- megadetector/taxonomy_mapping/taxonomy_graph.py +10 -10
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +13 -13
- megadetector/utils/ct_utils.py +1049 -211
- megadetector/utils/directory_listing.py +21 -77
- megadetector/utils/gpu_test.py +22 -22
- megadetector/utils/md_tests.py +632 -529
- megadetector/utils/path_utils.py +1520 -431
- megadetector/utils/process_utils.py +41 -41
- megadetector/utils/split_locations_into_train_val.py +62 -62
- megadetector/utils/string_utils.py +148 -27
- megadetector/utils/url_utils.py +489 -176
- megadetector/utils/wi_utils.py +2658 -2526
- megadetector/utils/write_html_image_list.py +137 -137
- megadetector/visualization/plot_utils.py +34 -30
- megadetector/visualization/render_images_with_thumbnails.py +39 -74
- megadetector/visualization/visualization_utils.py +487 -435
- megadetector/visualization/visualize_db.py +232 -198
- megadetector/visualization/visualize_detector_output.py +82 -76
- {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/METADATA +5 -2
- megadetector-10.0.0.dist-info/RECORD +139 -0
- {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/WHEEL +1 -1
- megadetector/api/batch_processing/api_core/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/batch_service/score.py +0 -439
- megadetector/api/batch_processing/api_core/server.py +0 -294
- megadetector/api/batch_processing/api_core/server_api_config.py +0 -97
- megadetector/api/batch_processing/api_core/server_app_config.py +0 -55
- megadetector/api/batch_processing/api_core/server_batch_job_manager.py +0 -220
- megadetector/api/batch_processing/api_core/server_job_status_table.py +0 -149
- megadetector/api/batch_processing/api_core/server_orchestration.py +0 -360
- megadetector/api/batch_processing/api_core/server_utils.py +0 -88
- megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
- megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +0 -46
- megadetector/api/batch_processing/api_support/__init__.py +0 -0
- megadetector/api/batch_processing/api_support/summarize_daily_activity.py +0 -152
- megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
- megadetector/api/synchronous/__init__.py +0 -0
- megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
- megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +0 -151
- megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +0 -263
- megadetector/api/synchronous/api_core/animal_detection_api/config.py +0 -35
- megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
- megadetector/api/synchronous/api_core/tests/load_test.py +0 -110
- megadetector/data_management/importers/add_nacti_sizes.py +0 -52
- megadetector/data_management/importers/add_timestamps_to_icct.py +0 -79
- megadetector/data_management/importers/animl_results_to_md_results.py +0 -158
- megadetector/data_management/importers/auckland_doc_test_to_json.py +0 -373
- megadetector/data_management/importers/auckland_doc_to_json.py +0 -201
- megadetector/data_management/importers/awc_to_json.py +0 -191
- megadetector/data_management/importers/bellevue_to_json.py +0 -272
- megadetector/data_management/importers/cacophony-thermal-importer.py +0 -793
- megadetector/data_management/importers/carrizo_shrubfree_2018.py +0 -269
- megadetector/data_management/importers/carrizo_trail_cam_2017.py +0 -289
- megadetector/data_management/importers/cct_field_adjustments.py +0 -58
- megadetector/data_management/importers/channel_islands_to_cct.py +0 -913
- megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
- megadetector/data_management/importers/eMammal/eMammal_helpers.py +0 -249
- megadetector/data_management/importers/eMammal/make_eMammal_json.py +0 -223
- megadetector/data_management/importers/ena24_to_json.py +0 -276
- megadetector/data_management/importers/filenames_to_json.py +0 -386
- megadetector/data_management/importers/helena_to_cct.py +0 -283
- megadetector/data_management/importers/idaho-camera-traps.py +0 -1407
- megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
- megadetector/data_management/importers/import_desert_lion_conservation_camera_traps.py +0 -387
- megadetector/data_management/importers/jb_csv_to_json.py +0 -150
- megadetector/data_management/importers/mcgill_to_json.py +0 -250
- megadetector/data_management/importers/missouri_to_json.py +0 -490
- megadetector/data_management/importers/nacti_fieldname_adjustments.py +0 -79
- megadetector/data_management/importers/noaa_seals_2019.py +0 -181
- megadetector/data_management/importers/osu-small-animals-to-json.py +0 -364
- megadetector/data_management/importers/pc_to_json.py +0 -365
- megadetector/data_management/importers/plot_wni_giraffes.py +0 -123
- megadetector/data_management/importers/prepare_zsl_imerit.py +0 -131
- megadetector/data_management/importers/raic_csv_to_md_results.py +0 -416
- megadetector/data_management/importers/rspb_to_json.py +0 -356
- megadetector/data_management/importers/save_the_elephants_survey_A.py +0 -320
- megadetector/data_management/importers/save_the_elephants_survey_B.py +0 -329
- megadetector/data_management/importers/snapshot_safari_importer.py +0 -758
- megadetector/data_management/importers/snapshot_serengeti_lila.py +0 -1067
- megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
- megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
- megadetector/data_management/importers/sulross_get_exif.py +0 -65
- megadetector/data_management/importers/timelapse_csv_set_to_json.py +0 -490
- megadetector/data_management/importers/ubc_to_json.py +0 -399
- megadetector/data_management/importers/umn_to_json.py +0 -507
- megadetector/data_management/importers/wellington_to_json.py +0 -263
- megadetector/data_management/importers/wi_to_json.py +0 -442
- megadetector/data_management/importers/zamba_results_to_md_results.py +0 -180
- megadetector/data_management/lila/add_locations_to_island_camera_traps.py +0 -101
- megadetector/data_management/lila/add_locations_to_nacti.py +0 -151
- megadetector/utils/azure_utils.py +0 -178
- megadetector/utils/sas_blob_utils.py +0 -509
- megadetector-5.0.28.dist-info/RECORD +0 -209
- /megadetector/{api/batch_processing/__init__.py → __init__.py} +0 -0
- {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/licenses/LICENSE +0 -0
- {megadetector-5.0.28.dist-info → megadetector-10.0.0.dist-info}/top_level.txt +0 -0

megadetector/api/batch_processing/integration/digiKam/xmp_integration.py

@@ -17,7 +17,7 @@ import inspect
 import os
 import sys
 import json
-import pyexiv2
+import pyexiv2 # type: ignore
 import ntpath
 import threading
 import traceback

@@ -432,7 +432,7 @@ def args_to_object(args,obj):
            setattr(obj, n, v)


-def main():
+def main(): # noqa

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', help = 'Path to the MegaDetector .json file', default=None)

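Both xmp_integration.py changes are one-line lint and type-checker suppressions rather than behavioral changes: "# type: ignore" tells mypy to skip a line (typically an import of a package that ships no type stubs), and "# noqa" tells flake8/ruff to skip its checks for that line. A hypothetical, self-contained illustration (not code from the package):

    import os  # noqa: F401  # suppresses flake8/ruff "imported but unused"

    # For an untyped third-party module such as pyexiv2, the mypy
    # equivalent is a per-line ignore on the import:
    #     import pyexiv2  # type: ignore

    def main():  # noqa  # blanket per-line suppression, as in this release
        print('lint suppressions do not change runtime behavior')

    main()
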
megadetector/classification/aggregate_classifier_probs.py

@@ -44,7 +44,7 @@ def main(classifier_results_csv_path: str,
    Because the output CSV is often very large, we process it in chunks of 1000
    rows at a time.
    """
-
+
    chunked_df_iterator = pd.read_csv(
        classifier_results_csv_path, chunksize=1000, float_precision='high',
        index_col='path')

@@ -80,7 +80,7 @@ def main(classifier_results_csv_path: str,
 #%% Command-line driver

 def _parse_args() -> argparse.Namespace:
-
+
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Aggregate classifier probabilities to target classes.')

@@ -100,7 +100,7 @@ def _parse_args() -> argparse.Namespace:


 if __name__ == '__main__':
-
+
    args = _parse_args()
    main(classifier_results_csv_path=args.classifier_results_csv,
         target_mapping_json_path=args.target_mapping,

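The first hunk's context shows the chunked-read pattern this script relies on: with chunksize set, pd.read_csv returns an iterator of DataFrames, so a very large results file never has to fit in memory at once. A minimal sketch, assuming a hypothetical classifier_results.csv with a 'path' column:

    import pandas as pd

    total_rows = 0
    chunked_df_iterator = pd.read_csv('classifier_results.csv',
                                      chunksize=1000,
                                      float_precision='high',
                                      index_col='path')
    for chunk in chunked_df_iterator:
        total_rows += len(chunk)  # each chunk is a DataFrame of <= 1000 rows
    print(total_rows)
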
megadetector/classification/analyze_failed_images.py

@@ -62,7 +62,7 @@ def check_image_condition(img_path: str,
        'bad': image exists, but cannot be opened even when setting
            ImageFile.LOAD_TRUNCATED_IMAGES=True
    """
-
+
    if (account is None) or (container is None) or (datasets_table is not None):
        assert account is None
        assert container is None

@@ -133,7 +133,7 @@ def analyze_images(url_or_path: str, json_keys: Optional[Sequence[str]] = None,
        sas_token: str, optional SAS token (without leading '?') if the
            container is not publicly accessible
    """
-
+
    datasets_table = None
    if (account is None) or (container is None):
        assert account is None

@@ -190,8 +190,8 @@ def analyze_images(url_or_path: str, json_keys: Optional[Sequence[str]] = None,

 #%% Command-line driver

-def _parse_args() -> argparse.Namespace:
-
+def _parse_args() -> argparse.Namespace:
+
    parser = argparse.ArgumentParser(
        description='Analyze a list of images that failed to download or crop.')
    parser.add_argument(

@@ -220,7 +220,7 @@ def _parse_args() -> argparse.Namespace:


 if __name__ == '__main__':
-
+
    args = _parse_args()
    analyze_images(url_or_path=args.failed_images, json_keys=args.json_keys,
                   account=args.account, container=args.container,

megadetector/classification/cache_batchapi_outputs.py

@@ -66,6 +66,7 @@ from api.batch_processing.data_preparation.prepare_api_submission import (
    TaskStatus, Task)
 from api.batch_processing.postprocessing.combine_api_outputs import (
    combine_api_output_dictionaries)
+from megadetector.utils import ct_utils


 #%% Support functions

@@ -84,7 +85,7 @@ def cache_json(json_path: str,
        detector_output_cache_base_dir: str
        detector_version: str
    """
-
+
    with open(json_path, 'r') as f:
        js = json.load(f)

@@ -138,7 +139,7 @@ def cache_detections(detections: Mapping[str, Any], dataset: str,

    Returns: str, message
    """
-
+
    # combine detections with cache
    dataset_cache_path = os.path.join(
        detector_output_cache_dir, f'{dataset}.json')

@@ -155,8 +156,7 @@ def cache_detections(detections: Mapping[str, Any], dataset: str,
                 f'{dataset_cache_path}')

    # write combined detections back out to cache
-    with open(dataset_cache_path, 'w') as f:
-        json.dump(merged_dataset_cache, f, indent=1)
+    ct_utils.write_json(dataset_cache_path, merged_dataset_cache)
    return msg

@@ -188,7 +188,7 @@ def _parse_args() -> argparse.Namespace:


 if __name__ == '__main__':
-
+
    args = _parse_args()
    cache_json(
        json_path=args.json_file,

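The open-plus-json.dump replacement above is the recurring refactor of this release: JSON writes are centralized in ct_utils.write_json. The real implementation lives in megadetector/utils/ct_utils.py; inferring only from the call sites in this diff (a path, an object, and an optional indent where the old calls used indent=1), a plausible sketch looks like:

    import json

    def write_json(path, content, indent=1):
        """Hypothetical reconstruction of ct_utils.write_json, inferred from
        its call sites in this diff; indent=1 matches the json.dump calls it
        replaces, and indent=None reproduces compact output."""
        with open(path, 'w') as f:
            json.dump(content, f, indent=indent)
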
megadetector/classification/create_classification_dataset.py

@@ -37,7 +37,7 @@ avoiding overlapping locations between the train/val/test splits.
 This script outputs 3 files to <output_dir>:

 1) classification_ds.csv, contains columns:
-
+
    - 'path': str, path to cropped images
    - 'dataset': str, name of dataset
    - 'location': str, location that image was taken, as saved in MegaDB

@@ -75,6 +75,7 @@ import pandas as pd
 from tqdm import tqdm

 from megadetector.classification import detect_and_crop
+from megadetector.utils import ct_utils


 #%% Example usage

@@ -108,7 +109,7 @@ def main(output_dir: str,
         test_frac: Optional[float],
         splits_method: Optional[str],
         label_spec_json_path: Optional[str]) -> None:
-
+
    # input validation
    assert set(mode) <= {'csv', 'splits'}
    if label_spec_json_path is not None:

@@ -160,9 +161,8 @@ def main(output_dir: str,
        labels = labels.map(lambda x: x.split(',')).explode()
        # look into sklearn.preprocessing.MultiLabelBinarizer
        label_names = sorted(labels.unique())
-
-        with open(os.path.join(output_dir, LABEL_INDEX_FILENAME), 'w') as f:
-            json.dump(dict(enumerate(label_names)), f, indent=1)
+        # Note: JSON always saves keys as strings!
+        ct_utils.write_json(os.path.join(output_dir, LABEL_INDEX_FILENAME), dict(enumerate(label_names)))

    if 'splits' in mode:
        assert splits_method is not None

@@ -181,8 +181,7 @@ def main(output_dir: str,
        split_to_locs = create_splits_smallest_label_first(
            df, val_frac, test_frac, test_split=test_set_locs,
            label_spec_json_path=label_spec_json_path)
-        with open(os.path.join(output_dir, SPLITS_FILENAME), 'w') as f:
-            json.dump(split_to_locs, f, indent=1)
+        ct_utils.write_json(os.path.join(output_dir, SPLITS_FILENAME), split_to_locs)


 #%% Support functions

@@ -236,7 +235,7 @@ def create_classification_csv(
        'missing crops': list of tuple (img_path, i), where i is the
            i-th crop index
    """
-
+
    assert 0 <= confidence_threshold <= 1

    columns = [

@@ -359,7 +358,7 @@ def create_splits_random(df: pd.DataFrame, val_frac: float,
    Returns: dict, keys are ['train', 'val', 'test'], values are lists of locs,
        where each loc is a tuple (dataset, location)
    """
-
+
    if test_split is not None:
        assert test_frac == 0
    train_frac = 1. - val_frac - test_frac

@@ -445,7 +444,7 @@ def create_splits_smallest_label_first(
    Returns: dict, keys are ['train', 'val', 'test'], values are lists of locs,
        where each loc is a tuple (dataset, location)
    """
-
+
    # label => list of datasets to prioritize for test and validation sets
    prioritize = {}
    if label_spec_json_path is not None:

@@ -525,7 +524,7 @@ def sort_locs_by_size(loc_to_size: MutableMapping[tuple[str, str], int],
    Returns: list of (dataset, location) tuples, ordered from smallest size to
        largest. Locations from prioritized datasets come first.
    """
-
+
    result = []
    if prioritize is not None:
        # modify loc_to_size in place, so copy its keys before iterating

@@ -610,7 +609,7 @@ def _parse_args() -> argparse.Namespace:


 if __name__ == '__main__':
-
+
    args = _parse_args()
    main(output_dir=args.output_dir,
         mode=args.mode,

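The "# Note: JSON always saves keys as strings!" comment added above is easy to verify: the label index is built with dict(enumerate(...)), whose integer keys become strings on a JSON round trip, so readers of the saved label-index file must not expect int keys back:

    import json

    label_index = dict(enumerate(['cat', 'dog']))    # {0: 'cat', 1: 'dog'}
    round_trip = json.loads(json.dumps(label_index))
    assert list(round_trip.keys()) == ['0', '1']     # keys are now strings
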
megadetector/classification/crop_detections.py

@@ -33,7 +33,7 @@ bounding box width or height. In the case that the square crop boundaries exceed
 the original image size, the crop is padded with 0s.

 This script outputs a log file to:
-
+
    <output_dir>/crop_detections_log_{timestamp}.json

 ...which contains images that failed to download and crop properly.

@@ -107,7 +107,7 @@ def main(detections_json_path: str,
        threads: int, number of threads to use for downloading images
        logdir: str, path to directory to save log file
    """
-
+
    # error checking
    assert 0 <= confidence_threshold <= 1, \
        'Invalid confidence threshold {}'.format(confidence_threshold)

@@ -149,7 +149,7 @@ def main(detections_json_path: str,
            for d in info_dict['detections']:
                if d['category'] not in detection_categories:
                    print('Warning: ignoring detection with category {} for image {}'.format(
-                        d['category'],img_path))
+                        d['category'],img_path))
                    # This will be removed later when we filter for animals
                    d['category'] = 'unsupported'
                else:

@@ -235,7 +235,7 @@ def download_and_crop(
        total_downloads: int, number of images downloaded
        total_new_crops: int, number of new crops saved to cropped_images_dir
    """
-
+
    # True for ground truth, False for MegaDetector
    # always save as .jpg for consistency
    crop_path_template = {

@@ -297,7 +297,7 @@ def load_local_image(img_path: str | BinaryIO) -> Optional[Image.Image]:
    """
    Attempts to load an image from a local path.
    """
-
+
    try:
        with Image.open(img_path) as img:
            img.load()

@@ -347,7 +347,7 @@ def load_and_crop(img_path: str,
        did_download: bool, whether image was downloaded from Azure Blob Storage
        num_new_crops: int, number of new crops successfully saved
    """
-
+
    did_download = False
    num_new_crops = 0

@@ -393,7 +393,7 @@ def load_and_crop(img_path: str,

    assert img is not None, 'image "{}" failed to load or download properly'.format(
        debug_path)
-
+
    if img.mode != 'RGB':
        img = img.convert(mode='RGB')  # always save as RGB for consistency

@@ -418,7 +418,7 @@ def save_crop(img: Image.Image, bbox_norm: Sequence[float], square_crop: bool,

    Returns: bool, True if a crop was saved, False otherwise
    """
-
+
    img_w, img_h = img.size
    xmin = int(bbox_norm[0] * img_w)
    ymin = int(bbox_norm[1] * img_h)

@@ -456,7 +456,7 @@ def save_crop(img: Image.Image, bbox_norm: Sequence[float], square_crop: bool,
 #%% Command-line driver

 def _parse_args() -> argparse.Namespace:
-
+
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Crop detections from MegaDetector.')

@@ -501,7 +501,7 @@ def _parse_args() -> argparse.Namespace:


 if __name__ == '__main__':
-
+
    args = _parse_args()
    main(detections_json_path=args.detections_json,
         cropped_images_dir=args.cropped_images_dir,

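The save_crop context shows the coordinate math used throughout these scripts: MegaDetector boxes are [x_min, y_min, width, height], normalized to [0, 1] relative to the image's upper-left corner, and are scaled to pixels by multiplying by the image dimensions. A worked example with hypothetical values:

    img_w, img_h = 1920, 1080
    bbox_norm = [0.25, 0.5, 0.1, 0.2]   # hypothetical detection
    xmin = int(bbox_norm[0] * img_w)    # 480
    ymin = int(bbox_norm[1] * img_h)    # 540
    box_w = int(bbox_norm[2] * img_w)   # 192
    box_h = int(bbox_norm[3] * img_h)   # 216
    print(xmin, ymin, box_w, box_h)
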
megadetector/classification/csv_to_json.py

@@ -40,7 +40,7 @@ Example CSV input:

 Example JSON output:

-"
+"
 {
    "cervid": {
        "dataset_labels": {

@@ -107,17 +107,17 @@ import json
 from typing import Any

 import pandas as pd
+from megadetector.utils import ct_utils


 #%% Main function

-def main():
+def main(): # noqa
    args = _parse_args()
    js = csv_to_jsondict(args.input_csv_file)
    for label in js:
        js[label] = order_spec_dict(js[label])
-    with open(args.output_json_path, 'w') as f:
-        json.dump(js, f, indent=args.json_indent)
+    ct_utils.write_json(args.output_json_path, js, indent=args.json_indent)


 #%% Support functions

@@ -126,7 +126,7 @@ def parse_csv_row(obj: dict[str, Any], rowtype: str, content: str) -> None:
    """
    Parses a row in the CSV.
    """
-
+
    if rowtype == 'row':
        if 'dataset_labels' not in obj:
            obj['dataset_labels'] = defaultdict(list)

@@ -169,7 +169,7 @@ def csv_to_jsondict(csv_path: str) -> dict[str, dict[str, Any]]:
    """
    Converts CSV to json-style dictionary.
    """
-
+
    df = pd.read_csv(csv_path, comment='#', skip_blank_lines=True)
    assert (df.columns == ['output_label', 'type', 'content']).all()

@@ -193,7 +193,7 @@ def order_spec_dict(spec_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Returns spec_dict with keys in a specific order.
    """
-
+
    if 'exclude' in spec_dict:
        spec_dict['exclude'] = order_spec_dict(spec_dict['exclude'])
    ordered_spec_dict: dict[str, Any] = {}

@@ -206,7 +206,7 @@ def order_spec_dict(spec_dict: dict[str, Any]) -> dict[str, Any]:
 #%% Command-line driver

 def _parse_args() -> argparse.Namespace:
-
+
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Converts CSV to JSON format for label specification.')

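The csv_to_jsondict hunk shows that the label-spec CSV has fixed columns ['output_label', 'type', 'content'] and may contain comment and blank lines, which pandas is told to skip. A runnable sketch with hypothetical inline row values:

    import io
    import textwrap
    import pandas as pd

    csv_text = textwrap.dedent("""\
        output_label,type,content
        # comment lines and blank lines are both skipped
        cervid,row,deer

        cervid,row,elk
        """)
    df = pd.read_csv(io.StringIO(csv_text), comment='#', skip_blank_lines=True)
    assert (df.columns == ['output_label', 'type', 'content']).all()
    print(df)
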
megadetector/classification/detect_and_crop.py

@@ -118,6 +118,7 @@ from megadetector.classification.crop_detections import load_and_crop
 from megadetector.data_management.megadb import megadb_utils
 from megadetector.utils import path_utils
 from megadetector.utils import sas_blob_utils
+from megadetector.utils import ct_utils


 #%% Example usage

@@ -264,8 +265,7 @@ def main(queried_images_json_path: str,
    # save log of bad images
    date = datetime.now().strftime('%Y%m%d_%H%M%S')  # e.g., '20200722_110816'
    log_path = os.path.join(output_dir, f'detect_and_crop_log_{date}.json')
-    with open(log_path, 'w') as f:
-        json.dump(log, f, indent=1)
+    ct_utils.write_json(log_path, log)


 #%% Support functions

@@ -291,7 +291,7 @@ def load_detection_cache(detector_output_cache_dir: str,
            if no cached detections were found for the given dataset ds.
        detection_categories: dict, maps str category ID to str category name
    """
-
+
    # cache of Detector outputs: dataset name => {img_path => detection_dict}
    detection_cache = {}
    detection_categories: dict[str, str] = {}

@@ -339,7 +339,7 @@ def filter_detected_images(
        detection_categories: dict, maps str category ID to str category name,
            empty dict if no cached detections are found
    """
-
+
    datasets = set(img_path[:img_path.find('/')]
                   for img_path in potential_images_to_detect)
    detection_cache, detection_categories = load_detection_cache(

@@ -364,7 +364,7 @@ def split_images_list_by_dataset(images_to_detect: Iterable[str]

    Returns: dict, maps dataset name to a list of image paths
    """
-
+
    images_by_dataset: dict[str, list[str]] = {}
    for img_path in images_to_detect:
        dataset = img_path[:img_path.find('/')]

@@ -402,7 +402,7 @@ def submit_batch_detection_api(images_to_detect: Iterable[str],

    Returns: dict, maps str dataset name to list of Task objects
    """
-
+
    filtered_images_to_detect = [
        x for x in images_to_detect if path_utils.is_image_file(x)]
    not_images = set(images_to_detect) - set(filtered_images_to_detect)

@@ -450,8 +450,7 @@ def submit_batch_detection_api(images_to_detect: Iterable[str],
        for dataset in tasks_by_dataset
        for task in tasks_by_dataset[dataset]
    ]
-    with open(resume_file_path, 'w') as f:
-        json.dump(resume_json, f, indent=1)
+    ct_utils.write_json(resume_file_path, resume_json)
    return tasks_by_dataset

@@ -478,7 +477,7 @@ def submit_batch_detection_api_by_dataset(

    Returns: list of Task objects
    """
-
+
    os.makedirs(task_lists_dir, exist_ok=True)

    date = datetime.now().strftime('%Y%m%d_%H%M%S')  # e.g., '20200722_110816'

@@ -522,7 +521,7 @@ def resume_tasks(resume_file_path: str, batch_detection_api_url: str

    Returns: dict, maps str dataset name to list of Task objects
    """
-
+
    with open(resume_file_path, 'r') as f:
        resume_json = json.load(f)

@@ -559,7 +558,7 @@ def wait_for_tasks(tasks_by_dataset: Mapping[str, Iterable[Task]],
            saved to <output_dir>/batchapi_response/{task_id}.json
        poll_interval: int, # of seconds between pinging the task status API
    """
-
+
    remaining_tasks: list[tuple[str, Task]] = [
        (dataset, task) for dataset, tasks in tasks_by_dataset.items()
        for task in tasks]

@@ -597,8 +596,7 @@ def wait_for_tasks(tasks_by_dataset: Mapping[str, Iterable[Task]],
            if not os.path.exists(save_dir):
                tqdm.write(f'Creating API output dir: {save_dir}')
                os.makedirs(save_dir)
-            with open(os.path.join(save_dir, f'{task.id}.json'), 'w') as f:
-                json.dump(task.response, f, indent=1)
+            ct_utils.write_json(os.path.join(save_dir, f'{task.id}.json'), task.response)
            message = task.response['Status']['message']
            num_failed_shards = message['num_failed_shards']
            if num_failed_shards != 0:

@@ -676,7 +674,7 @@ def download_and_crop(
    Returns: list of str, images with bounding boxes that failed to download or
        crop properly
    """
-
+
    # error checking before we download and crop any images
    valid_img_paths = set(queried_images_json.keys())
    if images_missing_detections is not None:

@@ -838,7 +836,7 @@ def _parse_args() -> argparse.Namespace:


 if __name__ == '__main__':
-
+
    args = _parse_args()
    main(queried_images_json_path=args.queried_images_json,
         output_dir=args.output_dir,

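Several of the detect_and_crop hunks group MegaDB-style image paths by the dataset name that precedes the first '/'. An equivalent minimal sketch of split_images_list_by_dataset's grouping logic (paths are hypothetical):

    images_to_detect = ['ds1/cam01/img001.jpg',
                        'ds1/cam02/img002.jpg',
                        'ds2/site_a/img003.jpg']

    images_by_dataset = {}
    for img_path in images_to_detect:
        dataset = img_path[:img_path.find('/')]
        images_by_dataset.setdefault(dataset, []).append(img_path)

    assert set(images_by_dataset) == {'ds1', 'ds2'}
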
megadetector/classification/efficientnet/model.py

@@ -93,7 +93,7 @@ class MBConvBlock(nn.Module):

        Args:
            inputs (tensor): Input tensor.
-            drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).
+            drop_connect_rate (bool, optional): Drop connect rate (float, between 0 and 1).

        Returns:
            Output of this block after processing.

@@ -135,7 +135,7 @@ class MBConvBlock(nn.Module):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
-            memory_efficient (bool): Whether to use memory-efficient version of swish.
+            memory_efficient (bool, optional): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()

@@ -221,7 +221,7 @@ class EfficientNet(nn.Module):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
-            memory_efficient (bool): Whether to use memory-efficient version of swish.
+            memory_efficient (bool, optional): Whether to use memory-efficient version of swish.

        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()

@@ -323,7 +323,7 @@ class EfficientNet(nn.Module):

        Args:
            model_name (str): Name for efficientnet.
-            in_channels (int): Input data's channel number.
+            in_channels (int, optional): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:

@@ -349,14 +349,14 @@ class EfficientNet(nn.Module):

        Args:
            model_name (str): Name for efficientnet.
-            weights_path (None or str):
+            weights_path (None or str, optional):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
-            advprop (bool):
+            advprop (bool, optional):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
-            in_channels (int): Input data's channel number.
-            num_classes (int):
+            in_channels (int, optional): Input data's channel number.
+            num_classes (int, optional):
                Number of categories for classification.
                It controls the output size for final linear layer.
            override_params (other key word params):

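The set_swish docstrings above toggle between a standard and a memory-efficient swish, where swish(x) = x * sigmoid(x). The memory-efficient variant is the well-known EfficientNet-PyTorch pattern of a custom autograd Function that stores only the input tensor and recomputes the sigmoid in backward; a sketch under that assumption (the package's own classes live elsewhere in this module), requiring PyTorch:

    import torch
    from torch import nn

    class SwishImplementation(torch.autograd.Function):
        @staticmethod
        def forward(ctx, i):
            ctx.save_for_backward(i)      # keep only the input tensor
            return i * torch.sigmoid(i)

        @staticmethod
        def backward(ctx, grad_output):
            (i,) = ctx.saved_tensors
            sig = torch.sigmoid(i)
            # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
            return grad_output * (sig * (1 + i * (1 - sig)))

    class MemoryEfficientSwish(nn.Module):
        def forward(self, x):
            return SwishImplementation.apply(x)
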
megadetector/classification/efficientnet/utils.py

@@ -194,7 +194,7 @@ def get_same_padding_conv2d(image_size=None):
    Static padding is necessary for ONNX exporting of models.

    Args:
-        image_size (int or tuple): Size of the image.
+        image_size (int or tuple, optional): Size of the image.

    Returns:
        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.

@@ -274,7 +274,7 @@ def get_same_padding_maxPool2d(image_size=None):
    Static padding is necessary for ONNX exporting of models.

    Args:
-        image_size (int or tuple): Size of the image.
+        image_size (int or tuple, optional): Size of the image.

    Returns:
        MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.

@@ -579,11 +579,12 @@ def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True,
    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
-        weights_path (None or str):
+        weights_path (None or str, optional):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
-        load_fc (bool): Whether to load pretrained weights for fc layer at the end
-        advprop (bool): Whether to load pretrained weights
+        load_fc (bool, optional): Whether to load pretrained weights for fc layer at the end
+            of the model.
+        advprop (bool, optional): Whether to load pretrained weights
            trained with advprop (valid when weights_path is None).
    """
    if isinstance(weights_path, str):

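get_same_padding_conv2d and get_same_padding_maxPool2d choose between dynamic padding (computed per input) and static padding (fixed for a known image size, which ONNX export requires). Both compute TF-style "same" padding; a standalone sketch of the per-dimension arithmetic, under the assumption that it follows the standard TensorFlow convention:

    import math

    def same_pad_1d(in_size, kernel, stride, dilation=1):
        """Total padding needed so that out_size == ceil(in_size / stride)."""
        out_size = math.ceil(in_size / stride)
        pad = max((out_size - 1) * stride + (kernel - 1) * dilation + 1 - in_size, 0)
        return pad  # split as (pad // 2, pad - pad // 2) around the input

    print(same_pad_1d(in_size=224, kernel=3, stride=2))  # 1
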
megadetector/classification/evaluate_model.py

@@ -52,6 +52,7 @@ import torchvision
 import tqdm

 from megadetector.classification import efficientnet, train_classifier
+from megadetector.utils import ct_utils


 #%% Example usage

@@ -75,7 +76,7 @@ def check_override(params: Mapping[str, Any], key: str,
    """
    Return desired value, with optional override.
    """
-
+
    if override is None:
        return params[key]
    saved = params.get(key, None)

@@ -102,7 +103,7 @@ def trace_model(model_name: str, ckpt_path: str, num_classes: int,
        '/path/to/ckpt_16.pt', then the returned path is
        '/path/to/ckpt_16_compiled.pt'.
    """
-
+
    root, ext = os.path.splitext(ckpt_path)
    compiled_path = root + '_compiled' + ext
    if os.path.exists(compiled_path):

@@ -135,7 +136,7 @@ def calc_per_label_stats(cm: np.ndarray, label_names: Sequence[str]
        recall values are in [0, 1], or np.nan if that label had 0 ground-truth
        observations
    """
-
+
    tp = np.diag(cm)  # true positives

    predicted_positives = cm.sum(axis=0, dtype=np.float64)  # tp + fp

@@ -186,7 +187,7 @@ def test_epoch(model: torch.nn.Module,
        cm: np.ndarray, confusion matrix C such that C[i,j] is the # of
            observations known to be in group i and predicted to be in group j
    """
-
+
    # set dropout and BN layers to eval mode
    model.eval()

@@ -395,8 +396,7 @@ def main(params_json_path: str, ckpt_path: str, output_dir: str,
    assert target_names == set(label_names) | {'other'}
    label_names.append('other')

-    with open(os.path.join(output_dir, 'label_index.json'), 'w') as f:
-        json.dump(dict(enumerate(label_names)), f)
+    ct_utils.write_json(os.path.join(output_dir, 'label_index.json'), dict(enumerate(label_names)), indent=None)

    with open(label_index_json_path, 'r') as f:
        idx_to_label = json.load(f)

@@ -510,7 +510,7 @@ def _parse_args() -> argparse.Namespace:


 if __name__ == '__main__':
-
+
    args = _parse_args()
    main(params_json_path=args.params_json, ckpt_path=args.ckpt_path,
         output_dir=args.output_dir, splits=args.splits,

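The calc_per_label_stats context pins down the confusion-matrix convention used for evaluation: C[i, j] counts observations known to be in group i and predicted to be in group j, so true positives sit on the diagonal, column sums give predicted positives, and row sums give ground-truth positives. A small worked example with a hypothetical 2x2 matrix:

    import numpy as np

    cm = np.array([[5, 2],
                   [1, 4]])
    tp = np.diag(cm)                                        # true positives
    predicted_positives = cm.sum(axis=0, dtype=np.float64)  # tp + fp, per column
    actual_positives = cm.sum(axis=1, dtype=np.float64)     # tp + fn, per row
    precision = tp / predicted_positives                    # [5/6, 4/6]
    recall = tp / actual_positives                          # [5/7, 4/5]
    print(precision, recall)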