megadetector 5.0.10__py3-none-any.whl → 5.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic. Click here for more details.
- {megadetector-5.0.10.dist-info → megadetector-5.0.11.dist-info}/LICENSE +0 -0
- {megadetector-5.0.10.dist-info → megadetector-5.0.11.dist-info}/METADATA +12 -11
- megadetector-5.0.11.dist-info/RECORD +5 -0
- megadetector-5.0.11.dist-info/top_level.txt +1 -0
- api/__init__.py +0 -0
- api/batch_processing/__init__.py +0 -0
- api/batch_processing/api_core/__init__.py +0 -0
- api/batch_processing/api_core/batch_service/__init__.py +0 -0
- api/batch_processing/api_core/batch_service/score.py +0 -439
- api/batch_processing/api_core/server.py +0 -294
- api/batch_processing/api_core/server_api_config.py +0 -98
- api/batch_processing/api_core/server_app_config.py +0 -55
- api/batch_processing/api_core/server_batch_job_manager.py +0 -220
- api/batch_processing/api_core/server_job_status_table.py +0 -152
- api/batch_processing/api_core/server_orchestration.py +0 -360
- api/batch_processing/api_core/server_utils.py +0 -92
- api/batch_processing/api_core_support/__init__.py +0 -0
- api/batch_processing/api_core_support/aggregate_results_manually.py +0 -46
- api/batch_processing/api_support/__init__.py +0 -0
- api/batch_processing/api_support/summarize_daily_activity.py +0 -152
- api/batch_processing/data_preparation/__init__.py +0 -0
- api/batch_processing/data_preparation/manage_local_batch.py +0 -2391
- api/batch_processing/data_preparation/manage_video_batch.py +0 -327
- api/batch_processing/integration/digiKam/setup.py +0 -6
- api/batch_processing/integration/digiKam/xmp_integration.py +0 -465
- api/batch_processing/integration/eMammal/test_scripts/config_template.py +0 -5
- api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +0 -126
- api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +0 -55
- api/batch_processing/postprocessing/__init__.py +0 -0
- api/batch_processing/postprocessing/add_max_conf.py +0 -64
- api/batch_processing/postprocessing/categorize_detections_by_size.py +0 -163
- api/batch_processing/postprocessing/combine_api_outputs.py +0 -249
- api/batch_processing/postprocessing/compare_batch_results.py +0 -958
- api/batch_processing/postprocessing/convert_output_format.py +0 -397
- api/batch_processing/postprocessing/load_api_results.py +0 -195
- api/batch_processing/postprocessing/md_to_coco.py +0 -310
- api/batch_processing/postprocessing/md_to_labelme.py +0 -330
- api/batch_processing/postprocessing/merge_detections.py +0 -401
- api/batch_processing/postprocessing/postprocess_batch_results.py +0 -1904
- api/batch_processing/postprocessing/remap_detection_categories.py +0 -170
- api/batch_processing/postprocessing/render_detection_confusion_matrix.py +0 -661
- api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +0 -211
- api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +0 -82
- api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +0 -1631
- api/batch_processing/postprocessing/separate_detections_into_folders.py +0 -731
- api/batch_processing/postprocessing/subset_json_detector_output.py +0 -696
- api/batch_processing/postprocessing/top_folders_to_bottom.py +0 -223
- api/synchronous/__init__.py +0 -0
- api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
- api/synchronous/api_core/animal_detection_api/api_backend.py +0 -152
- api/synchronous/api_core/animal_detection_api/api_frontend.py +0 -266
- api/synchronous/api_core/animal_detection_api/config.py +0 -35
- api/synchronous/api_core/animal_detection_api/data_management/annotations/annotation_constants.py +0 -47
- api/synchronous/api_core/animal_detection_api/detection/detector_training/copy_checkpoints.py +0 -43
- api/synchronous/api_core/animal_detection_api/detection/detector_training/model_main_tf2.py +0 -114
- api/synchronous/api_core/animal_detection_api/detection/process_video.py +0 -543
- api/synchronous/api_core/animal_detection_api/detection/pytorch_detector.py +0 -304
- api/synchronous/api_core/animal_detection_api/detection/run_detector.py +0 -627
- api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +0 -1029
- api/synchronous/api_core/animal_detection_api/detection/run_inference_with_yolov5_val.py +0 -581
- api/synchronous/api_core/animal_detection_api/detection/run_tiled_inference.py +0 -754
- api/synchronous/api_core/animal_detection_api/detection/tf_detector.py +0 -165
- api/synchronous/api_core/animal_detection_api/detection/video_utils.py +0 -495
- api/synchronous/api_core/animal_detection_api/md_utils/azure_utils.py +0 -174
- api/synchronous/api_core/animal_detection_api/md_utils/ct_utils.py +0 -262
- api/synchronous/api_core/animal_detection_api/md_utils/directory_listing.py +0 -251
- api/synchronous/api_core/animal_detection_api/md_utils/matlab_porting_tools.py +0 -97
- api/synchronous/api_core/animal_detection_api/md_utils/path_utils.py +0 -416
- api/synchronous/api_core/animal_detection_api/md_utils/process_utils.py +0 -110
- api/synchronous/api_core/animal_detection_api/md_utils/sas_blob_utils.py +0 -509
- api/synchronous/api_core/animal_detection_api/md_utils/string_utils.py +0 -59
- api/synchronous/api_core/animal_detection_api/md_utils/url_utils.py +0 -144
- api/synchronous/api_core/animal_detection_api/md_utils/write_html_image_list.py +0 -226
- api/synchronous/api_core/animal_detection_api/md_visualization/visualization_utils.py +0 -841
- api/synchronous/api_core/tests/__init__.py +0 -0
- api/synchronous/api_core/tests/load_test.py +0 -110
- classification/__init__.py +0 -0
- classification/aggregate_classifier_probs.py +0 -108
- classification/analyze_failed_images.py +0 -227
- classification/cache_batchapi_outputs.py +0 -198
- classification/create_classification_dataset.py +0 -627
- classification/crop_detections.py +0 -516
- classification/csv_to_json.py +0 -226
- classification/detect_and_crop.py +0 -855
- classification/efficientnet/__init__.py +0 -9
- classification/efficientnet/model.py +0 -415
- classification/efficientnet/utils.py +0 -610
- classification/evaluate_model.py +0 -520
- classification/identify_mislabeled_candidates.py +0 -152
- classification/json_to_azcopy_list.py +0 -63
- classification/json_validator.py +0 -695
- classification/map_classification_categories.py +0 -276
- classification/merge_classification_detection_output.py +0 -506
- classification/prepare_classification_script.py +0 -194
- classification/prepare_classification_script_mc.py +0 -228
- classification/run_classifier.py +0 -286
- classification/save_mislabeled.py +0 -110
- classification/train_classifier.py +0 -825
- classification/train_classifier_tf.py +0 -724
- classification/train_utils.py +0 -322
- data_management/__init__.py +0 -0
- data_management/annotations/__init__.py +0 -0
- data_management/annotations/annotation_constants.py +0 -34
- data_management/camtrap_dp_to_coco.py +0 -238
- data_management/cct_json_utils.py +0 -395
- data_management/cct_to_md.py +0 -176
- data_management/cct_to_wi.py +0 -289
- data_management/coco_to_labelme.py +0 -272
- data_management/coco_to_yolo.py +0 -662
- data_management/databases/__init__.py +0 -0
- data_management/databases/add_width_and_height_to_db.py +0 -33
- data_management/databases/combine_coco_camera_traps_files.py +0 -206
- data_management/databases/integrity_check_json_db.py +0 -477
- data_management/databases/subset_json_db.py +0 -115
- data_management/generate_crops_from_cct.py +0 -149
- data_management/get_image_sizes.py +0 -188
- data_management/importers/add_nacti_sizes.py +0 -52
- data_management/importers/add_timestamps_to_icct.py +0 -79
- data_management/importers/animl_results_to_md_results.py +0 -158
- data_management/importers/auckland_doc_test_to_json.py +0 -372
- data_management/importers/auckland_doc_to_json.py +0 -200
- data_management/importers/awc_to_json.py +0 -189
- data_management/importers/bellevue_to_json.py +0 -273
- data_management/importers/cacophony-thermal-importer.py +0 -796
- data_management/importers/carrizo_shrubfree_2018.py +0 -268
- data_management/importers/carrizo_trail_cam_2017.py +0 -287
- data_management/importers/cct_field_adjustments.py +0 -57
- data_management/importers/channel_islands_to_cct.py +0 -913
- data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
- data_management/importers/eMammal/eMammal_helpers.py +0 -249
- data_management/importers/eMammal/make_eMammal_json.py +0 -223
- data_management/importers/ena24_to_json.py +0 -275
- data_management/importers/filenames_to_json.py +0 -385
- data_management/importers/helena_to_cct.py +0 -282
- data_management/importers/idaho-camera-traps.py +0 -1407
- data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
- data_management/importers/jb_csv_to_json.py +0 -150
- data_management/importers/mcgill_to_json.py +0 -250
- data_management/importers/missouri_to_json.py +0 -489
- data_management/importers/nacti_fieldname_adjustments.py +0 -79
- data_management/importers/noaa_seals_2019.py +0 -181
- data_management/importers/pc_to_json.py +0 -365
- data_management/importers/plot_wni_giraffes.py +0 -123
- data_management/importers/prepare-noaa-fish-data-for-lila.py +0 -359
- data_management/importers/prepare_zsl_imerit.py +0 -131
- data_management/importers/rspb_to_json.py +0 -356
- data_management/importers/save_the_elephants_survey_A.py +0 -320
- data_management/importers/save_the_elephants_survey_B.py +0 -332
- data_management/importers/snapshot_safari_importer.py +0 -758
- data_management/importers/snapshot_safari_importer_reprise.py +0 -665
- data_management/importers/snapshot_serengeti_lila.py +0 -1067
- data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
- data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
- data_management/importers/sulross_get_exif.py +0 -65
- data_management/importers/timelapse_csv_set_to_json.py +0 -490
- data_management/importers/ubc_to_json.py +0 -399
- data_management/importers/umn_to_json.py +0 -507
- data_management/importers/wellington_to_json.py +0 -263
- data_management/importers/wi_to_json.py +0 -441
- data_management/importers/zamba_results_to_md_results.py +0 -181
- data_management/labelme_to_coco.py +0 -548
- data_management/labelme_to_yolo.py +0 -272
- data_management/lila/__init__.py +0 -0
- data_management/lila/add_locations_to_island_camera_traps.py +0 -97
- data_management/lila/add_locations_to_nacti.py +0 -147
- data_management/lila/create_lila_blank_set.py +0 -557
- data_management/lila/create_lila_test_set.py +0 -151
- data_management/lila/create_links_to_md_results_files.py +0 -106
- data_management/lila/download_lila_subset.py +0 -177
- data_management/lila/generate_lila_per_image_labels.py +0 -515
- data_management/lila/get_lila_annotation_counts.py +0 -170
- data_management/lila/get_lila_image_counts.py +0 -111
- data_management/lila/lila_common.py +0 -300
- data_management/lila/test_lila_metadata_urls.py +0 -132
- data_management/ocr_tools.py +0 -874
- data_management/read_exif.py +0 -681
- data_management/remap_coco_categories.py +0 -84
- data_management/remove_exif.py +0 -66
- data_management/resize_coco_dataset.py +0 -189
- data_management/wi_download_csv_to_coco.py +0 -246
- data_management/yolo_output_to_md_output.py +0 -441
- data_management/yolo_to_coco.py +0 -676
- detection/__init__.py +0 -0
- detection/detector_training/__init__.py +0 -0
- detection/detector_training/model_main_tf2.py +0 -114
- detection/process_video.py +0 -703
- detection/pytorch_detector.py +0 -337
- detection/run_detector.py +0 -779
- detection/run_detector_batch.py +0 -1219
- detection/run_inference_with_yolov5_val.py +0 -917
- detection/run_tiled_inference.py +0 -935
- detection/tf_detector.py +0 -188
- detection/video_utils.py +0 -606
- docs/source/conf.py +0 -43
- md_utils/__init__.py +0 -0
- md_utils/azure_utils.py +0 -174
- md_utils/ct_utils.py +0 -612
- md_utils/directory_listing.py +0 -246
- md_utils/md_tests.py +0 -968
- md_utils/path_utils.py +0 -1044
- md_utils/process_utils.py +0 -157
- md_utils/sas_blob_utils.py +0 -509
- md_utils/split_locations_into_train_val.py +0 -228
- md_utils/string_utils.py +0 -92
- md_utils/url_utils.py +0 -323
- md_utils/write_html_image_list.py +0 -225
- md_visualization/__init__.py +0 -0
- md_visualization/plot_utils.py +0 -293
- md_visualization/render_images_with_thumbnails.py +0 -275
- md_visualization/visualization_utils.py +0 -1537
- md_visualization/visualize_db.py +0 -551
- md_visualization/visualize_detector_output.py +0 -406
- megadetector-5.0.10.dist-info/RECORD +0 -224
- megadetector-5.0.10.dist-info/top_level.txt +0 -8
- taxonomy_mapping/__init__.py +0 -0
- taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +0 -491
- taxonomy_mapping/map_new_lila_datasets.py +0 -154
- taxonomy_mapping/prepare_lila_taxonomy_release.py +0 -142
- taxonomy_mapping/preview_lila_taxonomy.py +0 -591
- taxonomy_mapping/retrieve_sample_image.py +0 -71
- taxonomy_mapping/simple_image_download.py +0 -218
- taxonomy_mapping/species_lookup.py +0 -834
- taxonomy_mapping/taxonomy_csv_checker.py +0 -159
- taxonomy_mapping/taxonomy_graph.py +0 -346
- taxonomy_mapping/validate_lila_category_mappings.py +0 -83
- {megadetector-5.0.10.dist-info → megadetector-5.0.11.dist-info}/WHEEL +0 -0
api/synchronous/api_core/animal_detection_api/data_management/annotations/annotation_constants.py
DELETED
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
########
|
|
2
|
-
#
|
|
3
|
-
# annotation_constants.py
|
|
4
|
-
#
|
|
5
|
-
# Shared constants used to interpret annotation output
|
|
6
|
-
#
|
|
7
|
-
# Categories assigned to bounding boxes. Used throughout our repo; do not change unless
|
|
8
|
-
# you are Dan or Siyu. In fact, do not change unless you are both Dan *and* Siyu.
|
|
9
|
-
#
|
|
10
|
-
# We use integer indices here; this is different from the API output .json file,
|
|
11
|
-
# where indices are string integers.
|
|
12
|
-
#
|
|
13
|
-
########
|
|
14
|
-
|
|
15
|
-
# Number of detector categories not counting "empty" (i.e. the non-empty
# entries of detector_bbox_categories: animal, person, vehicle).
NUM_DETECTOR_CATEGORIES = 3 # this is for choosing colors, so ignoring the "empty" class
|
|
16
|
-
|
|
17
|
-
# Label mapping used for our incoming iMerit annotations. It is only used to
# parse those incoming annotations; in our database the string name is used
# instead, to avoid confusion.
#
# Note that this scheme includes 'group' (a group of animals), so 'vehicle'
# is id 4 here but id 3 in detector_bbox_categories.
annotation_bbox_categories = [
    {'id': 0, 'name': 'empty'},
    {'id': 1, 'name': 'animal'},
    {'id': 2, 'name': 'person'},
    {'id': 3, 'name': 'group'},  # group of animals
    {'id': 4, 'name': 'vehicle'}
]

# Lookup tables in both directions (int id <-> name). Built with dict
# comprehensions so no loop variable is left behind in the module namespace
# (and exported by "from annotation_constants import *").
annotation_bbox_category_id_to_name = {
    category['id']: category['name'] for category in annotation_bbox_categories}
annotation_bbox_category_name_to_id = {
    category['name']: category['id'] for category in annotation_bbox_categories}
|
|
33
|
-
|
|
34
|
-
# Categories emitted by MegaDetector. Ids are ints here; in the API output
# .json file the same indices appear as string integers.
detector_bbox_categories = [
    {'id': 0, 'name': 'empty'},
    {'id': 1, 'name': 'animal'},
    {'id': 2, 'name': 'person'},
    {'id': 3, 'name': 'vehicle'}
]

# Lookup tables in both directions (int id <-> name), built without leaking a
# loop variable into the module namespace.
detector_bbox_category_id_to_name = {
    category['id']: category['name'] for category in detector_bbox_categories}
detector_bbox_category_name_to_id = {
    category['name']: category['id'] for category in detector_bbox_categories}
|
api/synchronous/api_core/animal_detection_api/detection/detector_training/copy_checkpoints.py
DELETED
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
########
|
|
2
|
-
#
|
|
3
|
-
# copy_checkpoints.py
|
|
4
|
-
#
|
|
5
|
-
# Run this script with specified source_dir and target_dir while the model is training to make a copy
|
|
6
|
-
# of every checkpoint (by default checkpoints are kept once an hour, and this is difficult to adjust)
|
|
7
|
-
#
|
|
8
|
-
########
|
|
9
|
-
|
|
10
|
-
#%% Imports and constants
|
|
11
|
-
|
|
12
|
-
import time
|
|
13
|
-
import os
|
|
14
|
-
import shutil
|
|
15
|
-
|
|
16
|
-
# Polling interval, in minutes, between scans of the training folder
check_every_n_minutes = 10

# Folder being written by the training job, and the folder that accumulates
# a copy of every checkpoint seen there
source_dir, target_dir = (
    '/datadrive/megadetectorv3/experiments/190425',
    '/datadrive/megadetectorv3/experiments/0425_checkpoints',
)

os.makedirs(name=target_dir, exist_ok=True)
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
#%% Main loop
|
|
25
|
-
|
|
26
|
-
def _copy_new_checkpoint_files(src, dst):
    """
    Copy model/graph files from src into dst when dst lacks them or holds a
    copy whose size differs from the source.

    Re-copying on a size mismatch means a file that was copied while training
    was still writing it gets refreshed on a later round, instead of being
    kept truncated forever. Each copy goes to a temporary name and is then
    renamed, so dst never exposes a half-written file under its final name.
    Files that disappear between listing and copying (e.g. older checkpoints
    removed by the training job) are skipped instead of crashing the watcher.

    Args:
        src (str): folder being written by the training job
        dst (str): folder that accumulates copies

    Returns:
        list: names of the files copied during this call
    """

    copied = []

    for fn in os.listdir(src):

        # Only checkpoint ("model*") and graph ("graph*") files; do not copy
        # event or evaluation results
        if not (fn.startswith('model') or fn.startswith('graph')):
            continue

        source_path = os.path.join(src, fn)
        target_path = os.path.join(dst, fn)

        try:
            source_size = os.path.getsize(source_path)
            if os.path.exists(target_path) and \
                    os.path.getsize(target_path) == source_size:
                continue
            tmp_path = target_path + '.partial'
            shutil.copy(source_path, tmp_path)
            os.replace(tmp_path, target_path)
        except FileNotFoundError:
            # The source file vanished while we were working on it; skip it
            # this round rather than killing the long-running loop
            continue

        copied.append(fn)
        print('Copied {}.'.format(fn))

    return copied


num_checks = 0

while True:

    num_checks += 1
    print('Checking round {}.'.format(num_checks))

    _copy_new_checkpoint_files(source_dir, target_dir)

    print('End of round {}.'.format(num_checks))

    time.sleep(check_every_n_minutes * 60)
|
|
@@ -1,114 +0,0 @@
|
|
|
1
|
-
# Lint as: python3
|
|
2
|
-
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
# ==============================================================================
|
|
16
|
-
|
|
17
|
-
r"""Creates and runs TF2 object detection models.
|
|
18
|
-
For local training/evaluation run:
|
|
19
|
-
PIPELINE_CONFIG_PATH=path/to/pipeline.config
|
|
20
|
-
MODEL_DIR=/tmp/model_outputs
|
|
21
|
-
NUM_TRAIN_STEPS=10000
|
|
22
|
-
SAMPLE_1_OF_N_EVAL_EXAMPLES=1
|
|
23
|
-
python model_main_tf2.py \
|
|
24
|
-
--model_dir=$MODEL_DIR --num_train_steps=$NUM_TRAIN_STEPS \
|
|
25
|
-
--sample_1_of_n_eval_examples=$SAMPLE_1_OF_N_EVAL_EXAMPLES \
|
|
26
|
-
--pipeline_config_path=$PIPELINE_CONFIG_PATH \
|
|
27
|
-
--alsologtostderr
|
|
28
|
-
"""
|
|
29
|
-
from absl import flags
|
|
30
|
-
import tensorflow.compat.v2 as tf
|
|
31
|
-
from object_detection import model_lib_v2
|
|
32
|
-
|
|
33
|
-
# Command-line flags for this script, registered with absl.flags and read
# through FLAGS in main().

flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')
flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
flags.DEFINE_bool('eval_on_train_data', False, 'Enable evaluating on train '
                  'data (only supported in distributed training).')
flags.DEFINE_integer('sample_1_of_n_eval_examples', None, 'Will sample one of '
                     'every n eval input examples, where n is provided.')
# The help text refers to the flag actually defined above
# (eval_on_train_data).
flags.DEFINE_integer('sample_1_of_n_eval_on_train_examples', 5, 'Will sample '
                     'one of every n train input examples for evaluation, '
                     'where n is provided. This is only used if '
                     '`eval_on_train_data` is True.')
flags.DEFINE_string(
    'model_dir', None, 'Path to output model directory '
    'where event and checkpoint files will be written.')
flags.DEFINE_string(
    'checkpoint_dir', None, 'Path to directory holding a checkpoint. If '
    '`checkpoint_dir` is provided, this binary operates in eval-only mode, '
    'writing resulting metrics to `model_dir`.')

# Note the trailing space in the first fragment: adjacent string literals are
# concatenated with no separator.
flags.DEFINE_integer('eval_timeout', 3600, 'Number of seconds to wait for an '
                     'evaluation checkpoint before exiting.')

flags.DEFINE_bool('use_tpu', False, 'Whether the job is executing on a TPU.')
flags.DEFINE_string(
    'tpu_name',
    default=None,
    help='Name of the Cloud TPU for Cluster Resolvers.')
flags.DEFINE_integer(
    'num_workers', 1, 'When num_workers > 1, training uses '
    'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
    'MirroredStrategy.')
flags.DEFINE_integer(
    'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
flags.DEFINE_boolean('record_summaries', True,
                     ('Whether or not to record summaries defined by the model'
                      ' or the training pipeline. This does not impact the'
                      ' summaries of the loss values which are always'
                      ' recorded.'))

FLAGS = flags.FLAGS
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def main(unused_argv):
  """Trains a TF2 object detection model, or runs continuous evaluation.

  If --checkpoint_dir is set, runs in eval-only mode via
  model_lib_v2.eval_continuously, writing metrics to --model_dir. Otherwise
  trains via model_lib_v2.train_loop inside a tf.distribute strategy chosen
  from --use_tpu / --num_workers.

  Args:
    unused_argv: positional command-line arguments left after flag parsing;
      ignored.

  Raises:
    ValueError: if --model_dir or --pipeline_config_path is not set.
  """
  # absl validates required flags when the command line is parsed, so
  # marking them required here (inside main, after parsing) would have no
  # effect. They are marked before app.run() below; this explicit check also
  # covers callers that reach main() through some other entry point.
  for required_flag in ('model_dir', 'pipeline_config_path'):
    if getattr(FLAGS, required_flag) is None:
      raise ValueError('Flag --{} must be specified.'.format(required_flag))

  tf.config.set_soft_device_placement(True)

  if FLAGS.checkpoint_dir:
    model_lib_v2.eval_continuously(
        pipeline_config_path=FLAGS.pipeline_config_path,
        model_dir=FLAGS.model_dir,
        train_steps=FLAGS.num_train_steps,
        sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
        sample_1_of_n_eval_on_train_examples=(
            FLAGS.sample_1_of_n_eval_on_train_examples),
        checkpoint_dir=FLAGS.checkpoint_dir,
        wait_interval=300, timeout=FLAGS.eval_timeout)
  else:
    if FLAGS.use_tpu:
      # TPU is automatically inferred if tpu_name is None and
      # we are running under cloud ai-platform.
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name)
      tf.config.experimental_connect_to_cluster(resolver)
      tf.tpu.experimental.initialize_tpu_system(resolver)
      strategy = tf.distribute.experimental.TPUStrategy(resolver)
    elif FLAGS.num_workers > 1:
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    else:
      # `tf` is already tensorflow.compat.v2, so this is the v2 strategy.
      strategy = tf.distribute.MirroredStrategy()

    with strategy.scope():
      model_lib_v2.train_loop(
          pipeline_config_path=FLAGS.pipeline_config_path,
          model_dir=FLAGS.model_dir,
          train_steps=FLAGS.num_train_steps,
          use_tpu=FLAGS.use_tpu,
          checkpoint_every_n=FLAGS.checkpoint_every_n,
          record_summaries=FLAGS.record_summaries)


if __name__ == '__main__':
  # Required-flag markers must be registered before app.run() parses the
  # command line; otherwise absl never checks them.
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  tf.compat.v1.app.run()
|