megadetector 5.0.11__py3-none-any.whl → 5.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of megadetector might be problematic. Click here for more details.
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/batch_service/__init__.py +0 -0
- megadetector/api/batch_processing/api_core/batch_service/score.py +439 -0
- megadetector/api/batch_processing/api_core/server.py +294 -0
- megadetector/api/batch_processing/api_core/server_api_config.py +98 -0
- megadetector/api/batch_processing/api_core/server_app_config.py +55 -0
- megadetector/api/batch_processing/api_core/server_batch_job_manager.py +220 -0
- megadetector/api/batch_processing/api_core/server_job_status_table.py +152 -0
- megadetector/api/batch_processing/api_core/server_orchestration.py +360 -0
- megadetector/api/batch_processing/api_core/server_utils.py +92 -0
- megadetector/api/batch_processing/api_core_support/__init__.py +0 -0
- megadetector/api/batch_processing/api_core_support/aggregate_results_manually.py +46 -0
- megadetector/api/batch_processing/api_support/__init__.py +0 -0
- megadetector/api/batch_processing/api_support/summarize_daily_activity.py +152 -0
- megadetector/api/batch_processing/data_preparation/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +126 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/api/synchronous/__init__.py +0 -0
- megadetector/api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
- megadetector/api/synchronous/api_core/animal_detection_api/api_backend.py +152 -0
- megadetector/api/synchronous/api_core/animal_detection_api/api_frontend.py +266 -0
- megadetector/api/synchronous/api_core/animal_detection_api/config.py +35 -0
- megadetector/api/synchronous/api_core/tests/__init__.py +0 -0
- megadetector/api/synchronous/api_core/tests/load_test.py +110 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +627 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +855 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +610 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +699 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +506 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +34 -0
- megadetector/data_management/camtrap_dp_to_coco.py +239 -0
- megadetector/data_management/cct_json_utils.py +395 -0
- megadetector/data_management/cct_to_md.py +176 -0
- megadetector/data_management/cct_to_wi.py +289 -0
- megadetector/data_management/coco_to_labelme.py +272 -0
- megadetector/data_management/coco_to_yolo.py +662 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +33 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +206 -0
- megadetector/data_management/databases/integrity_check_json_db.py +477 -0
- megadetector/data_management/databases/subset_json_db.py +115 -0
- megadetector/data_management/generate_crops_from_cct.py +149 -0
- megadetector/data_management/get_image_sizes.py +189 -0
- megadetector/data_management/importers/add_nacti_sizes.py +52 -0
- megadetector/data_management/importers/add_timestamps_to_icct.py +79 -0
- megadetector/data_management/importers/animl_results_to_md_results.py +158 -0
- megadetector/data_management/importers/auckland_doc_test_to_json.py +373 -0
- megadetector/data_management/importers/auckland_doc_to_json.py +201 -0
- megadetector/data_management/importers/awc_to_json.py +191 -0
- megadetector/data_management/importers/bellevue_to_json.py +273 -0
- megadetector/data_management/importers/cacophony-thermal-importer.py +796 -0
- megadetector/data_management/importers/carrizo_shrubfree_2018.py +269 -0
- megadetector/data_management/importers/carrizo_trail_cam_2017.py +289 -0
- megadetector/data_management/importers/cct_field_adjustments.py +58 -0
- megadetector/data_management/importers/channel_islands_to_cct.py +913 -0
- megadetector/data_management/importers/eMammal/copy_and_unzip_emammal.py +180 -0
- megadetector/data_management/importers/eMammal/eMammal_helpers.py +249 -0
- megadetector/data_management/importers/eMammal/make_eMammal_json.py +223 -0
- megadetector/data_management/importers/ena24_to_json.py +276 -0
- megadetector/data_management/importers/filenames_to_json.py +386 -0
- megadetector/data_management/importers/helena_to_cct.py +283 -0
- megadetector/data_management/importers/idaho-camera-traps.py +1407 -0
- megadetector/data_management/importers/idfg_iwildcam_lila_prep.py +294 -0
- megadetector/data_management/importers/jb_csv_to_json.py +150 -0
- megadetector/data_management/importers/mcgill_to_json.py +250 -0
- megadetector/data_management/importers/missouri_to_json.py +490 -0
- megadetector/data_management/importers/nacti_fieldname_adjustments.py +79 -0
- megadetector/data_management/importers/noaa_seals_2019.py +181 -0
- megadetector/data_management/importers/pc_to_json.py +365 -0
- megadetector/data_management/importers/plot_wni_giraffes.py +123 -0
- megadetector/data_management/importers/prepare-noaa-fish-data-for-lila.py +359 -0
- megadetector/data_management/importers/prepare_zsl_imerit.py +131 -0
- megadetector/data_management/importers/rspb_to_json.py +356 -0
- megadetector/data_management/importers/save_the_elephants_survey_A.py +320 -0
- megadetector/data_management/importers/save_the_elephants_survey_B.py +329 -0
- megadetector/data_management/importers/snapshot_safari_importer.py +758 -0
- megadetector/data_management/importers/snapshot_safari_importer_reprise.py +665 -0
- megadetector/data_management/importers/snapshot_serengeti_lila.py +1067 -0
- megadetector/data_management/importers/snapshotserengeti/make_full_SS_json.py +150 -0
- megadetector/data_management/importers/snapshotserengeti/make_per_season_SS_json.py +153 -0
- megadetector/data_management/importers/sulross_get_exif.py +65 -0
- megadetector/data_management/importers/timelapse_csv_set_to_json.py +490 -0
- megadetector/data_management/importers/ubc_to_json.py +399 -0
- megadetector/data_management/importers/umn_to_json.py +507 -0
- megadetector/data_management/importers/wellington_to_json.py +263 -0
- megadetector/data_management/importers/wi_to_json.py +442 -0
- megadetector/data_management/importers/zamba_results_to_md_results.py +181 -0
- megadetector/data_management/labelme_to_coco.py +547 -0
- megadetector/data_management/labelme_to_yolo.py +272 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/add_locations_to_island_camera_traps.py +97 -0
- megadetector/data_management/lila/add_locations_to_nacti.py +147 -0
- megadetector/data_management/lila/create_lila_blank_set.py +558 -0
- megadetector/data_management/lila/create_lila_test_set.py +152 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +178 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +516 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +170 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +300 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +132 -0
- megadetector/data_management/ocr_tools.py +874 -0
- megadetector/data_management/read_exif.py +681 -0
- megadetector/data_management/remap_coco_categories.py +84 -0
- megadetector/data_management/remove_exif.py +66 -0
- megadetector/data_management/resize_coco_dataset.py +189 -0
- megadetector/data_management/wi_download_csv_to_coco.py +246 -0
- megadetector/data_management/yolo_output_to_md_output.py +441 -0
- megadetector/data_management/yolo_to_coco.py +676 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/detector_training/__init__.py +0 -0
- megadetector/detection/detector_training/model_main_tf2.py +114 -0
- megadetector/detection/process_video.py +702 -0
- megadetector/detection/pytorch_detector.py +341 -0
- megadetector/detection/run_detector.py +779 -0
- megadetector/detection/run_detector_batch.py +1219 -0
- megadetector/detection/run_inference_with_yolov5_val.py +917 -0
- megadetector/detection/run_tiled_inference.py +934 -0
- megadetector/detection/tf_detector.py +189 -0
- megadetector/detection/video_utils.py +606 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +64 -0
- megadetector/postprocessing/categorize_detections_by_size.py +163 -0
- megadetector/postprocessing/combine_api_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +958 -0
- megadetector/postprocessing/convert_output_format.py +396 -0
- megadetector/postprocessing/load_api_results.py +195 -0
- megadetector/postprocessing/md_to_coco.py +310 -0
- megadetector/postprocessing/md_to_labelme.py +330 -0
- megadetector/postprocessing/merge_detections.py +401 -0
- megadetector/postprocessing/postprocess_batch_results.py +1902 -0
- megadetector/postprocessing/remap_detection_categories.py +170 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +660 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +211 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +83 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1631 -0
- megadetector/postprocessing/separate_detections_into_folders.py +730 -0
- megadetector/postprocessing/subset_json_detector_output.py +696 -0
- megadetector/postprocessing/top_folders_to_bottom.py +223 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +150 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +142 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +590 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +219 -0
- megadetector/taxonomy_mapping/species_lookup.py +834 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/azure_utils.py +178 -0
- megadetector/utils/ct_utils.py +612 -0
- megadetector/utils/directory_listing.py +246 -0
- megadetector/utils/md_tests.py +968 -0
- megadetector/utils/path_utils.py +1044 -0
- megadetector/utils/process_utils.py +157 -0
- megadetector/utils/sas_blob_utils.py +509 -0
- megadetector/utils/split_locations_into_train_val.py +228 -0
- megadetector/utils/string_utils.py +92 -0
- megadetector/utils/url_utils.py +323 -0
- megadetector/utils/write_html_image_list.py +225 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +293 -0
- megadetector/visualization/render_images_with_thumbnails.py +275 -0
- megadetector/visualization/visualization_utils.py +1536 -0
- megadetector/visualization/visualize_db.py +550 -0
- megadetector/visualization/visualize_detector_output.py +405 -0
- {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/METADATA +1 -1
- megadetector-5.0.12.dist-info/RECORD +199 -0
- megadetector-5.0.12.dist-info/top_level.txt +1 -0
- megadetector-5.0.11.dist-info/RECORD +0 -5
- megadetector-5.0.11.dist-info/top_level.txt +0 -1
- {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/LICENSE +0 -0
- {megadetector-5.0.11.dist-info → megadetector-5.0.12.dist-info}/WHEEL +0 -0
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,439 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import json
|
|
3
|
+
import math
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from io import BytesIO
|
|
8
|
+
from typing import Union
|
|
9
|
+
|
|
10
|
+
from PIL import Image
|
|
11
|
+
import numpy as np
|
|
12
|
+
import requests
|
|
13
|
+
import tensorflow as tf
|
|
14
|
+
from azure.storage.blob import ContainerClient
|
|
15
|
+
|
|
16
|
+
# Progress is reported every PRINT_EVERY scored images (see BatchScorer.score_images)
PRINT_EVERY = 500

# Log the TensorFlow runtime at import time, to aid diagnosing node/GPU setup issues
print('score.py, tensorflow version:', tf.__version__)
print('score.py, tf.test.is_gpu_available:', tf.test.is_gpu_available())
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
#%% Helper functions *copied* from ct_utils.py and visualization/visualization_utils.py
|
|
23
|
+
|
|
24
|
+
# Rotation angles in degrees (passed to PIL's Image.rotate by open_image), keyed by
# the value of the EXIF Orientation tag (274); any other orientation value is left as-is.
IMAGE_ROTATIONS = {3: 180, 6: 270, 8: 90}
|
|
29
|
+
|
|
30
|
+
def truncate_float(x, precision=3):
    """
    Truncates a float scalar to the given number of significant digits.

    Truncation is toward zero, so the magnitude of the result never exceeds the
    magnitude of the input, for positive and negative values alike.
    For example: truncate_float(0.0003214884) --> 0.000321 and
    truncate_float(-0.0003214884) --> -0.000321.
    This function is primarily used to achieve a certain float representation
    before exporting to JSON.

    Args:
        x (float): scalar to truncate
        precision (int): the number of significant digits to preserve; must be
            greater than or equal to 1

    Returns:
        x truncated to `precision` significant digits (a float), or the int 0
        when np.isclose(x, 0) holds (default numpy tolerances)
    """

    assert precision > 0

    if np.isclose(x, 0):
        return 0

    # Determine the factor which shifts the decimal point of x to just behind
    # the last significant digit we want to keep
    factor = math.pow(10, precision - 1 - math.floor(math.log10(abs(x))))

    # Shift the decimal point by multiplying with the factor, drop the remaining
    # digits, and shift back. math.trunc (rather than math.floor) rounds toward
    # zero, so negative values are truncated instead of growing in magnitude;
    # for non-negative values the two are identical.
    return math.trunc(x * factor) / factor
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def open_image(input_file: Union[str, BytesIO]) -> Image.Image:
    """
    Opens an image in binary format using PIL.Image and converts it to RGB mode.

    Image.open() is lazy: unless a mode conversion or rotation is applied here,
    pixel data is not decoded until the first operation that needs it (for
    example, resizing or load()), so decoding errors can show up later. Use
    load_image() to force the load.

    If the image has an EXIF Orientation tag (0x112 / 274) whose value is a key
    of IMAGE_ROTATIONS, the returned image is rotated accordingly.
    https://gist.github.com/dangtrinhnt/a577ece4cbe5364aad28
    https://www.media.mit.edu/pia/Research/deepview/exif.html

    Args:
        input_file: str or BytesIO, either a path or http(s) URL to an image file
            (anything that PIL can open), or an image as a stream of bytes

    Returns:
        a PIL image object in RGB mode

    Raises:
        AttributeError: if the image uses a mode other than RGBA, RGB or L
        requests.HTTPError: if downloading a URL returns an HTTP error status
        Exception: any other error raised by requests or PIL while opening
    """

    if isinstance(input_file, str) and input_file.startswith(('http://', 'https://')):
        # Download exactly once; any failure is logged before being re-raised
        try:
            response = requests.get(input_file)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        except Exception as e:
            print(f'Error opening image {input_file}: {e}')
            raise
    else:
        image = Image.open(input_file)

    if image.mode not in ('RGBA', 'RGB', 'L'):
        raise AttributeError(f'Image {input_file} uses unsupported mode {image.mode}')

    # Read the EXIF orientation *before* converting the mode: convert() returns a
    # plain Image.Image copy without the format-specific _getexif() method, so
    # RGBA/L images would otherwise never be rotated.
    orientation = None
    try:
        exif = image._getexif()
        orientation = exif.get(274, None)  # 274 is the key for the Orientation field
    except Exception:
        # Best-effort: formats without _getexif(), or images without EXIF data
        # (_getexif() returns None), are simply not rotated
        pass

    if image.mode == 'RGBA' or image.mode == 'L':
        # PIL.Image.convert() returns a converted copy of this image
        image = image.convert(mode='RGB')

    if orientation is not None and orientation in IMAGE_ROTATIONS:
        image = image.rotate(IMAGE_ROTATIONS[orientation], expand=True)  # returns a rotated copy

    return image
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def load_image(input_file: Union[str, BytesIO]) -> Image.Image:
    """
    Opens the image at input_file and forces its pixel data into memory.

    open_image() relies on the lazy Image.open(), so without an explicit load()
    errors would only surface downstream; calling load() here surfaces them now.

    Args:
        input_file: str or BytesIO, either a path to an image file (anything
            that PIL can open), or an image as a stream of bytes

    Returns:
        PIL.Image.Image, in RGB mode
    """
    loaded = open_image(input_file)
    loaded.load()
    return loaded
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
#%% TFDetector class, adapted from the class in detection/tf_detector.py,
# so we do not have to import the packages required by run_detector.py.
# TF1-style graph APIs are resolved through tf.compat.v1 when it exists, so the
# class also runs under TensorFlow 2.x.

class TFDetector:
    """
    A detector model loaded at the time of initialization. It is intended to be used with
    MegaDetector (TF). The inference batch size is set to 1; code needs to be modified
    to support larger batch sizes, including resizing appropriately.
    """

    # Number of significant digits (kept via truncate_float) for confidence values
    # and bbox coordinates
    CONF_DIGITS = 3
    COORD_DIGITS = 4

    # MegaDetector was trained with batch size of 1, and the resizing function is a part
    # of the inference graph
    BATCH_SIZE = 1

    # An enumeration of failure reasons
    FAILURE_TF_INFER = 'Failure TF inference'
    FAILURE_IMAGE_OPEN = 'Failure image access'

    DEFAULT_RENDERING_CONFIDENCE_THRESHOLD = 0.85  # to render bounding boxes
    DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD = 0.1  # to include in the output json file

    DEFAULT_DETECTOR_LABEL_MAP = {
        '1': 'animal',
        '2': 'person',
        '3': 'vehicle'  # available in megadetector v4+
    }

    NUM_DETECTOR_CATEGORIES = 4  # animal, person, group, vehicle - for color assignment

    def __init__(self, model_path):
        """Loads the model from model_path and starts a TF session with this graph.
        Obtains input and output tensor handles."""
        detection_graph = TFDetector.__load_model(model_path)
        self.tf_session = TFDetector._tf_v1().Session(graph=detection_graph)

        self.image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        self.box_tensor = detection_graph.get_tensor_by_name('detection_boxes:0')
        self.score_tensor = detection_graph.get_tensor_by_name('detection_scores:0')
        self.class_tensor = detection_graph.get_tensor_by_name('detection_classes:0')

    @staticmethod
    def _tf_v1():
        """Returns the namespace providing the TF1-style graph APIs used here
        (Graph, GraphDef, gfile, import_graph_def, Session).

        tf.compat.v1 is used when the installed TensorFlow provides it (TF 2.x
        and later 1.x releases); otherwise the top-level tf module is used, as
        in the original TF1-only code.
        """
        compat = getattr(tf, 'compat', None)
        v1 = getattr(compat, 'v1', None)
        return v1 if v1 is not None else tf

    @staticmethod
    def round_and_make_float(d, precision=4):
        """Casts d to a Python float and truncates it (via truncate_float) to
        `precision` significant digits."""
        return truncate_float(float(d), precision=precision)

    @staticmethod
    def __convert_coords(tf_coords):
        """Converts coordinates from the model's output format [y1, x1, y2, x2] to the
        format used by our API and MegaDB: [x1, y1, width, height]. All coordinates
        (including model outputs) are normalized in the range [0, 1].

        Args:
            tf_coords: np.array of predicted bounding box coordinates from the TF detector,
                has format [y1, x1, y2, x2]

        Returns: list of Python float, predicted bounding box coordinates [x1, y1, width, height]
        """
        # change from [y1, x1, y2, x2] to [x1, y1, width, height]
        width = tf_coords[3] - tf_coords[1]
        height = tf_coords[2] - tf_coords[0]

        new = [tf_coords[1], tf_coords[0], width, height]  # must be a list instead of np.array

        # convert numpy floats to Python floats
        for i, d in enumerate(new):
            new[i] = TFDetector.round_and_make_float(d, precision=TFDetector.COORD_DIGITS)
        return new

    @staticmethod
    def convert_to_tf_coords(array):
        """From [x1, y1, width, height] to [y1, x1, y2, x2], where x1 is x_min, x2 is x_max

        This is an extraneous step as the model outputs [y1, x1, y2, x2] but were converted to the API
        output format - only to keep the interface of the sync API.
        """
        x1 = array[0]
        y1 = array[1]
        width = array[2]
        height = array[3]
        x2 = x1 + width
        y2 = y1 + height
        return [y1, x1, y2, x2]

    @staticmethod
    def __load_model(model_path):
        """Loads a detection model (i.e., creates a graph) from a .pb file.

        Args:
            model_path: .pb file of the model.

        Returns: the loaded graph.
        """
        print('TFDetector: Loading graph...')
        tf_v1 = TFDetector._tf_v1()
        detection_graph = tf_v1.Graph()
        with detection_graph.as_default():
            od_graph_def = tf_v1.GraphDef()
            with tf_v1.gfile.GFile(model_path, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf_v1.import_graph_def(od_graph_def, name='')
        print('TFDetector: Detection graph loaded.')

        return detection_graph

    def _generate_detections_one_image(self, image):
        """Runs the graph on a single image and returns the raw (batched) box,
        score and class outputs."""
        np_im = np.asarray(image, np.uint8)
        im_w_batch_dim = np.expand_dims(np_im, axis=0)

        # need to change the above line to the following if supporting a batch size > 1 and resizing to the same size
        # np_images = [np.asarray(image, np.uint8) for image in images]
        # images_stacked = np.stack(np_images, axis=0) if len(images) > 1 else np.expand_dims(np_images[0], axis=0)

        # performs inference
        (box_tensor_out, score_tensor_out, class_tensor_out) = self.tf_session.run(
            [self.box_tensor, self.score_tensor, self.class_tensor],
            feed_dict={self.image_tensor: im_w_batch_dim})

        return box_tensor_out, score_tensor_out, class_tensor_out

    def generate_detections_one_image(self, image, image_id,
                                      detection_threshold=DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD):
        """Apply the detector to an image.

        Args:
            image: the PIL Image object
            image_id: a path to identify the image; will be in the "file" field of the output object
            detection_threshold: confidence above which to include the detection proposal

        Returns:
        A dict with the following fields, see the 'images' key in https://github.com/agentmorris/MegaDetector/tree/main/megadetector/api/batch_processing#batch-processing-api-output-format
            - 'file' (always present)
            - 'max_detection_conf'
            - 'detections', which is a list of detection objects containing keys 'category', 'conf' and 'bbox'
            - 'failure'
        """
        result = {
            'file': image_id
        }
        try:
            b_box, b_score, b_class = self._generate_detections_one_image(image)

            # our batch size is 1; need to loop the batch dim if supporting batch size > 1
            boxes, scores, classes = b_box[0], b_score[0], b_class[0]

            detections_cur_image = []  # will be empty for an image with no confident detections
            max_detection_conf = 0.0
            for b, s, c in zip(boxes, scores, classes):
                if s > detection_threshold:
                    detection_entry = {
                        'category': str(int(c)),  # use string type for the numerical class label, not int
                        'conf': truncate_float(float(s),  # cast to float for json serialization
                                               precision=TFDetector.CONF_DIGITS),
                        'bbox': TFDetector.__convert_coords(b)
                    }
                    detections_cur_image.append(detection_entry)
                    if s > max_detection_conf:
                        max_detection_conf = s

            result['max_detection_conf'] = truncate_float(float(max_detection_conf),
                                                          precision=TFDetector.CONF_DIGITS)
            result['detections'] = detections_cur_image

        except Exception as e:
            result['failure'] = TFDetector.FAILURE_TF_INFER
            print('TFDetector: image {} failed during inference: {}'.format(image_id, str(e)))

        return result
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
#%% Scoring script
|
|
281
|
+
|
|
282
|
+
class BatchScorer:
    """
    Coordinates scoring images in this Task.

    Images are downloaded one at a time (from the input blob container, or from
    public URLs when use_url is set), decoded in memory, and passed to a
    TFDetector; nothing is written to local disk.

    TODO: have a synchronized queue that download tasks enqueue and the scoring
    function dequeues, with a bounded queue size, so that downloading overlaps
    with inference. We do not want to write images to disk and then load them
    in the scoring function.
    """

    def __init__(self, **kwargs):
        """
        Keyword args:
            detector_path: path to the .pb TensorFlow detector model file
            use_url: if truthy, image ids are public URLs; otherwise they are full
                paths from the root of the container given by input_container_sas
            input_container_sas: SAS URL of the input container (only used when
                use_url is falsy)
            detection_threshold: confidence above which detections are reported
            image_ids_to_score: list whose items are either image ids, or
                [image_id, metadata] lists
        """
        print('score.py BatchScorer, __init__()')

        detector_path = kwargs.get('detector_path')
        self.detector = TFDetector(detector_path)

        self.use_url = kwargs.get('use_url')
        if not self.use_url:
            input_container_sas = kwargs.get('input_container_sas')
            self.input_container_client = ContainerClient.from_container_url(input_container_sas)

        self.detection_threshold = kwargs.get('detection_threshold')

        self.image_ids_to_score = kwargs.get('image_ids_to_score')

        # Whether metadata is attached to the (first) image id. Kept for callers;
        # score_images() checks each item individually. An empty list has no
        # metadata (indexing [0] would raise IndexError).
        self.metadata_available = bool(self.image_ids_to_score) and \
            isinstance(self.image_ids_to_score[0], list)

    def _download_image(self, image_file) -> Image.Image:
        """
        Downloads an image and fully decodes it in memory.

        Decoding happens here, rather than lazily on first use inside the
        detector, so that corrupt or truncated images are reported as
        image-access failures rather than as inference failures.

        Args:
            image_file: Public URL if use_url, else the full path from container root

        Returns:
            PIL image, loaded into memory, in RGB mode
        """
        if not self.use_url:
            downloader = self.input_container_client.download_blob(image_file)
            image_file = io.BytesIO()
            downloader.download_to_stream(image_file)
            # Rewind so reading starts at the first byte of the downloaded blob
            image_file.seek(0)

        return load_image(image_file)

    def score_images(self) -> list:
        """
        Scores every item of image_ids_to_score.

        Returns:
            list of per-image result dicts (see TFDetector.generate_detections_one_image);
            images that cannot be downloaded or decoded get a 'failure' of
            TFDetector.FAILURE_IMAGE_OPEN, and items given as [image_id, metadata]
            get their metadata under 'meta'.
        """
        detections = []

        for item in self.image_ids_to_score:

            # Items can be plain image ids or [image_id, metadata] lists
            has_metadata = isinstance(item, list)
            if has_metadata:
                image_id, image_metadata = item[0], item[1]
            else:
                image_id = item

            try:
                image = self._download_image(image_id)
            except Exception as e:
                print(f'score.py BatchScorer, score_images, download_image exception: {e}')
                result = {
                    'file': image_id,
                    'failure': TFDetector.FAILURE_IMAGE_OPEN
                }
            else:
                result = self.detector.generate_detections_one_image(
                    image, image_id, detection_threshold=self.detection_threshold)

            if has_metadata:
                result['meta'] = image_metadata

            detections.append(result)
            if len(detections) % PRINT_EVERY == 0:
                print(f'scored {len(detections)} images')

        return detections
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def main():
    """
    Entry point for one Azure Batch task.

    Reads its configuration from environment variables, scores the shard
    [TASK_BEGIN_INDEX, TASK_END_INDEX) of the job's image list, and writes the
    results as JSON to
    <AZ_BATCH_NODE_MOUNTS_DIR>/batch-api/api_<API_INSTANCE_NAME>/job_<job_id>/task_outputs/job_<job_id>_task_<task_id>.json.
    An empty list is written first, so the output exists even when there is
    nothing to score.

    Raises:
        KeyError: if a required environment variable is missing
        FileNotFoundError: if the detector model file does not exist
        RuntimeError: if scoring fails
    """
    print('score.py, main()')

    # information to determine input and output locations
    api_instance_name = os.environ['API_INSTANCE_NAME']
    job_id = os.environ['AZ_BATCH_JOB_ID']
    task_id = os.environ['AZ_BATCH_TASK_ID']
    mount_point = os.environ['AZ_BATCH_NODE_MOUNTS_DIR']

    # other parameters for the task
    begin_index = int(os.environ['TASK_BEGIN_INDEX'])
    end_index = int(os.environ['TASK_END_INDEX'])

    input_container_sas = os.environ.get('JOB_CONTAINER_SAS', None)  # could be None if use_url
    use_url_str = os.environ.get('JOB_USE_URL', None)

    # bool of any non-empty string is True, so compare against 'true' explicitly
    use_url = bool(use_url_str) and use_url_str.lower() == 'true'

    detection_threshold = float(os.environ['DETECTION_CONF_THRESHOLD'])

    # Never log the SAS query string: it is a credential for the input container.
    # The container URL alone is enough for debugging.
    if input_container_sas:
        input_container_for_log = input_container_sas.split('?', 1)[0]
    else:
        input_container_for_log = input_container_sas

    print(f'score.py, main(), api_instance_name: {api_instance_name}, job_id: {job_id}, task_id: {task_id}, '
          f'mount_point: {mount_point}, begin_index: {begin_index}, end_index: {end_index}, '
          f'input_container (SAS token omitted): {input_container_for_log}, use_url (parsed): {use_url}, '
          f'detection_threshold: {detection_threshold}')

    job_folder_mounted = os.path.join(mount_point, 'batch-api', f'api_{api_instance_name}', f'job_{job_id}')
    task_out_dir = os.path.join(job_folder_mounted, 'task_outputs')
    os.makedirs(task_out_dir, exist_ok=True)
    task_output_path = os.path.join(task_out_dir, f'job_{job_id}_task_{task_id}.json')

    # test that we can write to output path; also in case there is no image to process
    with open(task_output_path, 'w', encoding='utf-8') as f:
        json.dump([], f)

    # list images to process
    list_images_path = os.path.join(job_folder_mounted, f'{job_id}_images.json')
    with open(list_images_path, encoding='utf-8') as f:
        list_images = json.load(f)

    # Validate the type before calling len(), which fails on e.g. a JSON number
    if (not isinstance(list_images, list)) or len(list_images) == 0:
        print('score.py, main(), zero images in specified overall list, exiting...')
        sys.exit(0)
    print(f'score.py, main(), length of list_images: {len(list_images)}')

    # items in this list can be strings or [image_id, metadata]
    list_images = list_images[begin_index: end_index]
    if len(list_images) == 0:
        print('score.py, main(), zero images in the shard, exiting')
        sys.exit(0)

    print(f'score.py, main(), processing {len(list_images)} images in this Task')

    # model path
    # Path to .pb TensorFlow detector model file, relative to the
    # models/megadetector_copies folder in mounted container
    detector_model_rel_path = os.environ['DETECTOR_REL_PATH']
    detector_path = os.path.join(mount_point, 'models', 'megadetector_copies', detector_model_rel_path)
    if not os.path.exists(detector_path):
        # explicit raise rather than assert, which is stripped under python -O
        raise FileNotFoundError(f'detector is not found at the specified path: {detector_path}')

    # score the images
    scorer = BatchScorer(
        detector_path=detector_path,
        use_url=use_url,
        input_container_sas=input_container_sas,
        detection_threshold=detection_threshold,
        image_ids_to_score=list_images
    )

    try:
        tick = datetime.now()
        detections = scorer.score_images()
        duration = datetime.now() - tick
        print(f'score.py, main(), score_images() duration: {duration}')
    except Exception as e:
        raise RuntimeError(f'score.py, main(), exception in score_images(): {e}') from e

    with open(task_output_path, 'w', encoding='utf-8') as f:
        json.dump(detections, f, ensure_ascii=False)


if __name__ == '__main__':
    main()
|