megadetector-10.0.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of megadetector may be problematic.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +702 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +528 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +187 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +663 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +876 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2159 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1494 -0
- megadetector/detection/run_tiled_inference.py +1038 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1752 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2077 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +213 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +224 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2832 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1759 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1940 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +479 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.13.dist-info/METADATA +134 -0
- megadetector-10.0.13.dist-info/RECORD +147 -0
- megadetector-10.0.13.dist-info/WHEEL +5 -0
- megadetector-10.0.13.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.13.dist-info/top_level.txt +1 -0
megadetector/classification/efficientnet/utils.py
@@ -0,0 +1,608 @@
"""utils.py - Helper functions for building the model and for loading model parameters.

These helper functions are built to mirror those in the official TensorFlow implementation.
"""

# Author: lukemelas (github username)
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# With adjustments and added comments by workingcoder (github username).

import re
import math
import collections
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo


################################################################################
### Helper functions for model architecture
################################################################################

# GlobalParams and BlockArgs: Two namedtuples
# Swish and MemoryEfficientSwish: Two implementations of the method
# round_filters and round_repeats:
#     Functions to calculate params for scaling model width and depth ! ! !
# get_width_and_height_from_size and calculate_output_image_size
# drop_connect: A structural design
# get_same_padding_conv2d:
#     Conv2dDynamicSamePadding
#     Conv2dStaticSamePadding
# get_same_padding_maxPool2d:
#     MaxPool2dDynamicSamePadding
#     MaxPool2dStaticSamePadding
#     It's an additional function, not used in EfficientNet,
#     but can be used in other models (such as EfficientDet).

# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple('GlobalParams', [
    'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
    'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
    'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])

# Parameters for an individual model block
BlockArgs = collections.namedtuple('BlockArgs', [
    'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
    'input_filters', 'output_filters', 'se_ratio', 'id_skip'])

# Set GlobalParams and BlockArgs's defaults
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)


# An ordinary implementation of Swish function
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


# A memory-efficient implementation of Swish function
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)
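
# Illustrative annotation (not in the package source): the two Swish variants are
# numerically identical; the autograd.Function version just avoids caching the
# intermediate x * sigmoid(x) activation, which saves memory during training.
#   >>> x = torch.randn(8, requires_grad=True)
#   >>> torch.allclose(Swish()(x), MemoryEfficientSwish()(x))
#   True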


def round_filters(filters, global_params):
    """Calculate and round number of filters based on width multiplier.
    Use width_coefficient, depth_divisor and min_depth of global_params.

    Args:
        filters (int): Filters number to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new_filters: New filters number after calculating.
    """
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    divisor = global_params.depth_divisor
    min_depth = global_params.min_depth
    filters *= multiplier
    min_depth = min_depth or divisor  # pay attention to this line when using min_depth
    # follow the formula transferred from official TensorFlow implementation
    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, global_params):
    """Calculate module's repeat number of a block based on depth multiplier.
    Use depth_coefficient of global_params.

    Args:
        repeats (int): num_repeat to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new repeat: New repeat number after calculating.
    """
    multiplier = global_params.depth_coefficient
    if not multiplier:
        return repeats
    # follow the formula transferred from official TensorFlow implementation
    return int(math.ceil(multiplier * repeats))
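
# Illustrative annotation (not in the package source): compound scaling for
# efficientnet-b3 (width 1.2, depth 1.4, divisor 8) widens a 32-filter stem to 40
# and deepens a 3-repeat stage to 5.
#   >>> gp = GlobalParams(width_coefficient=1.2, depth_coefficient=1.4, depth_divisor=8)
#   >>> round_filters(32, gp)  # 32 * 1.2 = 38.4, rounded to a multiple of 8
#   40
#   >>> round_repeats(3, gp)   # ceil(3 * 1.4)
#   5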


def drop_connect(inputs, p, training):
    """Drop connect.

    Args:
        inputs (tensor: BCWH): Input of this structure.
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.

    Returns:
        output: Output after drop connection.
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'

    if not training:
        return inputs

    batch_size = inputs.shape[0]
    keep_prob = 1 - p

    # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)

    output = inputs / keep_prob * binary_tensor
    return output
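
# Illustrative annotation (not in the package source): whole samples are either
# zeroed or rescaled by 1/keep_prob, so the expected value of the output matches
# the input; at inference time the tensor passes through unchanged.
#   >>> x = torch.ones(4, 3, 8, 8)
#   >>> y = drop_connect(x, p=0.2, training=True)   # each sample is all 0.0 or all 1.25
#   >>> torch.equal(drop_connect(x, p=0.2, training=False), x)
#   True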


def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: A tuple or list (H,W).
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, list) or isinstance(x, tuple):
        return x
    else:
        raise TypeError()


def calculate_output_image_size(input_image_size, stride):
    """Calculates the output image size when using Conv2dSamePadding with a stride.
    Necessary for static padding. Thanks to mannatsingh for pointing this out.

    Args:
        input_image_size (int, tuple or list): Size of input image.
        stride (int, tuple or list): Conv2d operation's stride.

    Returns:
        output_image_size: A list [H,W].
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    stride = stride if isinstance(stride, int) else stride[0]
    image_height = int(math.ceil(image_height / stride))
    image_width = int(math.ceil(image_width / stride))
    return [image_height, image_width]
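
# Illustrative annotation (not in the package source): with 'SAME'-style padding
# the spatial size shrinks only by the stride.
#   >>> calculate_output_image_size(224, 2)
#   [112, 112]
#   >>> calculate_output_image_size([112, 112], 1)
#   [112, 112]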


# Note:
# The following 'SamePadding' functions make output size equal ceil(input size/stride).
# Only when stride equals 1, can the output size be the same as input size.
# Don't be confused by their function names ! ! !

def get_same_padding_conv2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple, optional): Size of the image.

    Returns:
        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
    """
    if image_size is None:
        return Conv2dDynamicSamePadding
    else:
        return partial(Conv2dStaticSamePadding, image_size=image_size)


class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow, for a dynamic image size.
    The padding is operated in forward function by calculating dynamically.
    """

    # Tips for 'SAME' mode padding.
    #     Given the following:
    #         i: width or height
    #         s: stride
    #         k: kernel size
    #         d: dilation
    #         p: padding
    #     Output after Conv2d:
    #         o = floor((i+p-((k-1)*d+1))/s+1)
    # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
    # => p = (i-1)*s+((k-1)*d+1)-i

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)  # change the output size according to stride ! ! !
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
    The padding module is calculated in construction function, then used in forward.
    """

    # With the same calculation as Conv2dDynamicSamePadding

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
                                                pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x
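
# Illustrative annotation (not in the package source): both factory variants
# reproduce TensorFlow's ceil(input_size / stride) output size; the static one is
# the ONNX-exportable choice when the input resolution is known up front.
#   >>> conv = get_same_padding_conv2d(image_size=224)(3, 16, kernel_size=3, stride=2)
#   >>> conv(torch.randn(1, 3, 224, 224)).shape
#   torch.Size([1, 16, 112, 112])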


def get_same_padding_maxPool2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple, optional): Size of the image.

    Returns:
        MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
    """
    if image_size is None:
        return MaxPool2dDynamicSamePadding
    else:
        return partial(MaxPool2dStaticSamePadding, image_size=image_size)


class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
    The padding is operated in forward function by calculating dynamically.
    """

    def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
        super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
        self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                            self.dilation, self.ceil_mode, self.return_indices)


class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
    The padding module is calculated in construction function, then used in forward.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
        self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                         self.dilation, self.ceil_mode, self.return_indices)
        return x


################################################################################
### Helper functions for loading model params
################################################################################

# BlockDecoder: A Class for encoding and decoding BlockArgs
# efficientnet_params: A function to query compound coefficient
# get_model_params and efficientnet:
#     Functions to get BlockArgs and GlobalParams for efficientnet
# url_map and url_map_advprop: Dicts of url_map for pretrained weights
# load_pretrained_weights: A function to load pretrained weights


class BlockDecoder(object):
    """Block Decoder for readability,
    straight from the official TensorFlow repository.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file.
        """
        assert isinstance(block_string, str)

        ops = block_string.split('_')
        options = {}
        for op in ops:
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride
        assert (('s' in options and len(options['s']) == 1) or
                (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(options['s'][0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument.

        Returns:
            block_string: A String form of BlockArgs.
        """
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (block.stride[0], block.stride[0]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        if 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(BlockDecoder._decode_block_string(block_string))
        return blocks_args

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(BlockDecoder._encode_block_string(block))
        return block_strings
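
# Illustrative annotation (not in the package source): each block string packs
# repeats, kernel size, stride, expand ratio, channels, and squeeze-excite ratio,
# and decode/encode round-trip cleanly.
#   >>> args = BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25'])
#   >>> args[0].kernel_size, args[0].input_filters, args[0].se_ratio
#   (3, 32, 0.25)
#   >>> BlockDecoder.encode(args)
#   ['r1_k3_s11_e1_i32_o16_se0.25']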


def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.

    Args:
        model_name (str): Model name to be queried.

    Returns:
        params_dict[model_name]: A (width,depth,res,dropout) tuple.
    """
    params_dict = {
        # Coefficients:   width,depth,res,dropout
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    return params_dict[model_name]


def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
    """Create BlockArgs and GlobalParams for efficientnet model.

    Args:
        width_coefficient (float)
        depth_coefficient (float)
        image_size (int)
        dropout_rate (float)
        drop_connect_rate (float)
        num_classes (int)
        include_top (bool)

        Meaning as the name suggests.

    Returns:
        blocks_args, global_params.
    """

    # Blocks args for the whole model (efficientnet-b0 by default)
    # It will be modified in the construction of EfficientNet Class according to model
    blocks_args = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    blocks_args = BlockDecoder.decode(blocks_args)

    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )

    return blocks_args, global_params


def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model name.

    Args:
        model_name (str): Model's name.
        override_params (dict): A dict to modify global_params.

    Returns:
        blocks_args, global_params
    """
    if model_name.startswith('efficientnet'):
        w, d, s, p = efficientnet_params(model_name)
        # note: all models have drop connect rate = 0.2
        blocks_args, global_params = efficientnet(
            width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
    else:
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
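
# Illustrative annotation (not in the package source): resolving a model name
# yields the seven decoded block stages plus a GlobalParams, with any overrides
# applied on top.
#   >>> blocks_args, global_params = get_model_params('efficientnet-b0', {'num_classes': 10})
#   >>> len(blocks_args), global_params.image_size, global_params.num_classes
#   (7, 224, 10)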


# Trained with standard methods
# (see paper: EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
url_map = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
}

# Trained with Adversarial Examples (AdvProp)
# (see paper: Adversarial Examples Improve Image Recognition)
url_map_advprop = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
    'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
}

# TODO: add the pretrained weights url map of 'efficientnet-l2'


def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False):
    """Loads pretrained weights from weights path or download using url.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
        weights_path (None or str, optional):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool, optional): Whether to load pretrained weights for fc layer at the end
            of the model.
        advprop (bool, optional): Whether to load pretrained weights
            trained with advprop (valid when weights_path is None).
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path)
    else:
        # AutoAugment or AdvProp (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        state_dict = model_zoo.load_url(url_map_[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        state_dict.pop('_fc.weight')
        state_dict.pop('_fc.bias')
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) == set(
            ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    print('Loaded pretrained weights for {}'.format(model_name))
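
# Illustrative annotation (not in the package source), assuming this package's
# model.py mirrors the upstream EfficientNet class and its from_name constructor:
# to fine-tune on a new label set, keep the freshly initialized head and load
# only the backbone weights.
#   >>> from megadetector.classification.efficientnet.model import EfficientNet
#   >>> net = EfficientNet.from_name('efficientnet-b0', num_classes=10)
#   >>> load_pretrained_weights(net, 'efficientnet-b0', load_fc=False)
#   Loaded pretrained weights for efficientnet-b0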