megadetector 10.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- megadetector/__init__.py +0 -0
- megadetector/api/__init__.py +0 -0
- megadetector/api/batch_processing/integration/digiKam/setup.py +6 -0
- megadetector/api/batch_processing/integration/digiKam/xmp_integration.py +465 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/config_template.py +5 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +125 -0
- megadetector/api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +55 -0
- megadetector/classification/__init__.py +0 -0
- megadetector/classification/aggregate_classifier_probs.py +108 -0
- megadetector/classification/analyze_failed_images.py +227 -0
- megadetector/classification/cache_batchapi_outputs.py +198 -0
- megadetector/classification/create_classification_dataset.py +626 -0
- megadetector/classification/crop_detections.py +516 -0
- megadetector/classification/csv_to_json.py +226 -0
- megadetector/classification/detect_and_crop.py +853 -0
- megadetector/classification/efficientnet/__init__.py +9 -0
- megadetector/classification/efficientnet/model.py +415 -0
- megadetector/classification/efficientnet/utils.py +608 -0
- megadetector/classification/evaluate_model.py +520 -0
- megadetector/classification/identify_mislabeled_candidates.py +152 -0
- megadetector/classification/json_to_azcopy_list.py +63 -0
- megadetector/classification/json_validator.py +696 -0
- megadetector/classification/map_classification_categories.py +276 -0
- megadetector/classification/merge_classification_detection_output.py +509 -0
- megadetector/classification/prepare_classification_script.py +194 -0
- megadetector/classification/prepare_classification_script_mc.py +228 -0
- megadetector/classification/run_classifier.py +287 -0
- megadetector/classification/save_mislabeled.py +110 -0
- megadetector/classification/train_classifier.py +827 -0
- megadetector/classification/train_classifier_tf.py +725 -0
- megadetector/classification/train_utils.py +323 -0
- megadetector/data_management/__init__.py +0 -0
- megadetector/data_management/animl_to_md.py +161 -0
- megadetector/data_management/annotations/__init__.py +0 -0
- megadetector/data_management/annotations/annotation_constants.py +33 -0
- megadetector/data_management/camtrap_dp_to_coco.py +270 -0
- megadetector/data_management/cct_json_utils.py +566 -0
- megadetector/data_management/cct_to_md.py +184 -0
- megadetector/data_management/cct_to_wi.py +293 -0
- megadetector/data_management/coco_to_labelme.py +284 -0
- megadetector/data_management/coco_to_yolo.py +701 -0
- megadetector/data_management/databases/__init__.py +0 -0
- megadetector/data_management/databases/add_width_and_height_to_db.py +107 -0
- megadetector/data_management/databases/combine_coco_camera_traps_files.py +210 -0
- megadetector/data_management/databases/integrity_check_json_db.py +563 -0
- megadetector/data_management/databases/subset_json_db.py +195 -0
- megadetector/data_management/generate_crops_from_cct.py +200 -0
- megadetector/data_management/get_image_sizes.py +164 -0
- megadetector/data_management/labelme_to_coco.py +559 -0
- megadetector/data_management/labelme_to_yolo.py +349 -0
- megadetector/data_management/lila/__init__.py +0 -0
- megadetector/data_management/lila/create_lila_blank_set.py +556 -0
- megadetector/data_management/lila/create_lila_test_set.py +192 -0
- megadetector/data_management/lila/create_links_to_md_results_files.py +106 -0
- megadetector/data_management/lila/download_lila_subset.py +182 -0
- megadetector/data_management/lila/generate_lila_per_image_labels.py +777 -0
- megadetector/data_management/lila/get_lila_annotation_counts.py +174 -0
- megadetector/data_management/lila/get_lila_image_counts.py +112 -0
- megadetector/data_management/lila/lila_common.py +319 -0
- megadetector/data_management/lila/test_lila_metadata_urls.py +164 -0
- megadetector/data_management/mewc_to_md.py +344 -0
- megadetector/data_management/ocr_tools.py +873 -0
- megadetector/data_management/read_exif.py +964 -0
- megadetector/data_management/remap_coco_categories.py +195 -0
- megadetector/data_management/remove_exif.py +156 -0
- megadetector/data_management/rename_images.py +194 -0
- megadetector/data_management/resize_coco_dataset.py +665 -0
- megadetector/data_management/speciesnet_to_md.py +41 -0
- megadetector/data_management/wi_download_csv_to_coco.py +247 -0
- megadetector/data_management/yolo_output_to_md_output.py +594 -0
- megadetector/data_management/yolo_to_coco.py +984 -0
- megadetector/data_management/zamba_to_md.py +188 -0
- megadetector/detection/__init__.py +0 -0
- megadetector/detection/change_detection.py +840 -0
- megadetector/detection/process_video.py +479 -0
- megadetector/detection/pytorch_detector.py +1451 -0
- megadetector/detection/run_detector.py +1267 -0
- megadetector/detection/run_detector_batch.py +2172 -0
- megadetector/detection/run_inference_with_yolov5_val.py +1314 -0
- megadetector/detection/run_md_and_speciesnet.py +1604 -0
- megadetector/detection/run_tiled_inference.py +1044 -0
- megadetector/detection/tf_detector.py +209 -0
- megadetector/detection/video_utils.py +1379 -0
- megadetector/postprocessing/__init__.py +0 -0
- megadetector/postprocessing/add_max_conf.py +72 -0
- megadetector/postprocessing/categorize_detections_by_size.py +166 -0
- megadetector/postprocessing/classification_postprocessing.py +1943 -0
- megadetector/postprocessing/combine_batch_outputs.py +249 -0
- megadetector/postprocessing/compare_batch_results.py +2110 -0
- megadetector/postprocessing/convert_output_format.py +403 -0
- megadetector/postprocessing/create_crop_folder.py +629 -0
- megadetector/postprocessing/detector_calibration.py +570 -0
- megadetector/postprocessing/generate_csv_report.py +522 -0
- megadetector/postprocessing/load_api_results.py +223 -0
- megadetector/postprocessing/md_to_coco.py +428 -0
- megadetector/postprocessing/md_to_labelme.py +351 -0
- megadetector/postprocessing/md_to_wi.py +41 -0
- megadetector/postprocessing/merge_detections.py +392 -0
- megadetector/postprocessing/postprocess_batch_results.py +2140 -0
- megadetector/postprocessing/remap_detection_categories.py +226 -0
- megadetector/postprocessing/render_detection_confusion_matrix.py +677 -0
- megadetector/postprocessing/repeat_detection_elimination/find_repeat_detections.py +206 -0
- megadetector/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +82 -0
- megadetector/postprocessing/repeat_detection_elimination/repeat_detections_core.py +1665 -0
- megadetector/postprocessing/separate_detections_into_folders.py +795 -0
- megadetector/postprocessing/subset_json_detector_output.py +964 -0
- megadetector/postprocessing/top_folders_to_bottom.py +238 -0
- megadetector/postprocessing/validate_batch_results.py +332 -0
- megadetector/taxonomy_mapping/__init__.py +0 -0
- megadetector/taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +491 -0
- megadetector/taxonomy_mapping/map_new_lila_datasets.py +211 -0
- megadetector/taxonomy_mapping/prepare_lila_taxonomy_release.py +165 -0
- megadetector/taxonomy_mapping/preview_lila_taxonomy.py +543 -0
- megadetector/taxonomy_mapping/retrieve_sample_image.py +71 -0
- megadetector/taxonomy_mapping/simple_image_download.py +231 -0
- megadetector/taxonomy_mapping/species_lookup.py +1008 -0
- megadetector/taxonomy_mapping/taxonomy_csv_checker.py +159 -0
- megadetector/taxonomy_mapping/taxonomy_graph.py +346 -0
- megadetector/taxonomy_mapping/validate_lila_category_mappings.py +83 -0
- megadetector/tests/__init__.py +0 -0
- megadetector/tests/test_nms_synthetic.py +335 -0
- megadetector/utils/__init__.py +0 -0
- megadetector/utils/ct_utils.py +1857 -0
- megadetector/utils/directory_listing.py +199 -0
- megadetector/utils/extract_frames_from_video.py +307 -0
- megadetector/utils/gpu_test.py +125 -0
- megadetector/utils/md_tests.py +2072 -0
- megadetector/utils/path_utils.py +2872 -0
- megadetector/utils/process_utils.py +172 -0
- megadetector/utils/split_locations_into_train_val.py +237 -0
- megadetector/utils/string_utils.py +234 -0
- megadetector/utils/url_utils.py +825 -0
- megadetector/utils/wi_platform_utils.py +968 -0
- megadetector/utils/wi_taxonomy_utils.py +1766 -0
- megadetector/utils/write_html_image_list.py +239 -0
- megadetector/visualization/__init__.py +0 -0
- megadetector/visualization/plot_utils.py +309 -0
- megadetector/visualization/render_images_with_thumbnails.py +243 -0
- megadetector/visualization/visualization_utils.py +1973 -0
- megadetector/visualization/visualize_db.py +630 -0
- megadetector/visualization/visualize_detector_output.py +498 -0
- megadetector/visualization/visualize_video_output.py +705 -0
- megadetector-10.0.15.dist-info/METADATA +115 -0
- megadetector-10.0.15.dist-info/RECORD +147 -0
- megadetector-10.0.15.dist-info/WHEEL +5 -0
- megadetector-10.0.15.dist-info/licenses/LICENSE +19 -0
- megadetector-10.0.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,415 @@
|
|
|
1
|
+
"""model.py - Model and module class for EfficientNet.
|
|
2
|
+
They are built to mirror those in the official TensorFlow implementation.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Author: lukemelas (github username)
|
|
6
|
+
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
|
|
7
|
+
# With adjustments and added comments by workingcoder (github username).
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch import nn
|
|
11
|
+
from torch.nn import functional as F
|
|
12
|
+
from .utils import (
|
|
13
|
+
round_filters,
|
|
14
|
+
round_repeats,
|
|
15
|
+
drop_connect,
|
|
16
|
+
get_same_padding_conv2d,
|
|
17
|
+
get_model_params,
|
|
18
|
+
efficientnet_params,
|
|
19
|
+
load_pretrained_weights,
|
|
20
|
+
Swish,
|
|
21
|
+
MemoryEfficientSwish,
|
|
22
|
+
calculate_output_image_size
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Model names accepted by _check_model_name_is_valid, which from_name,
# from_pretrained (via from_name) and get_image_size all call.
VALID_MODELS = (
    # Compound-scaled variants b0 .. b8.
    'efficientnet-b0',
    'efficientnet-b1',
    'efficientnet-b2',
    'efficientnet-b3',
    'efficientnet-b4',
    'efficientnet-b5',
    'efficientnet-b6',
    'efficientnet-b7',
    'efficientnet-b8',
    # 'efficientnet-l2' can be constructed, but without pretrained weights.
    'efficientnet-l2',
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParam, defined in utils.py.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # PyTorch's BatchNorm momentum convention differs from TensorFlow's,
        # so the TF-style value from global_params is converted as 1 - m.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        # Squeeze-and-excitation is enabled only for ratios in (0, 1].
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            # The 1x1 expansion conv uses stride 1, so image_size is
            # deliberately not recomputed here.

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        # Only the depthwise conv carries the block's stride, so the spatial
        # size seen by later layers is updated here.
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired. These convs act on a
        # globally pooled 1x1 map, hence image_size=(1, 1).
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float, optional): Drop connect rate, between 0 and 1.
                A falsy value (None or 0) disables drop connect.

        Returns:
            Output of this block after processing.
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation: pool to 1x1, reduce -> swish -> expand, then
        # gate x channel-wise with the sigmoid of the result.
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution (no activation after the projection)
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect: the residual add is only used when
        # the block has stride 1 and matching input/output channel counts.
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            # The combination of skip connection and drop connect brings about stochastic depth.
            if drop_connect_rate:
                # self.training is forwarded so the helper can distinguish
                # training from eval mode.
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool, optional): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class EfficientNet(nn.Module):
    """EfficientNet model.
    Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.

    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)

    Example:
        >>> import torch
        >>> from efficientnet.model import EfficientNet
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
        >>> model.eval()
        >>> outputs = model(inputs)
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters (momentum converted from the TF-style value
        # stored in global_params to PyTorch's convention).
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            # Repeated blocks use stride 1, so image_size does not change inside this loop.
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool, optional): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def _scaled_drop_connect_rate(self, idx):
        """Return the drop-connect rate to use for the block at position ``idx``.

        The global drop_connect_rate is scaled linearly by ``idx / len(self._blocks)``,
        so deeper blocks are dropped more often. A falsy global rate (None or 0)
        is returned unchanged, which disables drop connect in MBConvBlock.forward.

        Args:
            idx (int): Zero-based index of the block in self._blocks.

        Returns:
            The scaled rate, or the unmodified falsy global rate.
        """
        drop_connect_rate = self._global_params.drop_connect_rate
        if drop_connect_rate:
            drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
        return drop_connect_rate

    def extract_endpoints(self, inputs):
        """Use convolution layer to extract features
        from reduction levels i in [1, 2, 3, 4, 5].

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Dictionary of last intermediate features
            with reduction levels i in [1, 2, 3, 4, 5].

        Example:
            >>> import torch
            >>> from efficientnet.model import EfficientNet
            >>> inputs = torch.rand(1, 3, 224, 224)
            >>> model = EfficientNet.from_pretrained('efficientnet-b0')
            >>> endpoints = model.extract_endpoints(inputs)
            >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
            >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
            >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
            >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
            >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 1280, 7, 7])
        """
        endpoints = {}

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks: whenever a block shrinks dim 2 (height, for NCHW input), the
        # feature map produced just before it is recorded as the next endpoint.
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=self._scaled_drop_connect_rate(idx))
            if prev_x.size(2) > x.size(2):
                endpoints[f'reduction_{len(endpoints) + 1}'] = prev_x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints[f'reduction_{len(endpoints) + 1}'] = x

        return endpoints

    def extract_features(self, inputs):
        """Use the convolution layers to extract features.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of the final convolution
            layer in the efficientnet model.
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=self._scaled_drop_connect_rate(idx))

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs):
        """EfficientNet's forward function.
        Calls extract_features to extract features, applies final linear layer, and returns logits.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of this model after processing. When include_top is false,
            this is the pooled feature map (no flatten, dropout or linear layer).
        """
        # Convolution layers
        x = self.extract_features(inputs)
        # Pooling and final linear layer
        x = self._avg_pooling(x)
        if self._global_params.include_top:
            x = x.flatten(start_dim=1)
            x = self._dropout(x)
            x = self._fc(x)
        return x

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet.
            in_channels (int, optional): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'num_classes', 'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            An efficientnet model.

        Raises:
            ValueError: If model_name is not in VALID_MODELS.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """Create an efficientnet model according to name and load pretrained weights.

        Args:
            model_name (str): Name for efficientnet.
            weights_path (None or str, optional):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            advprop (bool, optional):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
            in_channels (int, optional): Input data's channel number.
            num_classes (int, optional):
                Number of categories for classification.
                It controls the output size for final linear layer.
                The pretrained final linear layer is only loaded when this is 1000.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            A pretrained efficientnet model.

        Raises:
            ValueError: If model_name is not in VALID_MODELS.
        """
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        # The FC layer's shape only matches the checkpoint for 1000 classes,
        # so it is skipped otherwise and left at its fresh initialization.
        load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
        # Done after loading so the pretrained 3-channel stem weights can be
        # loaded first; a non-3 in_channels then replaces the stem conv.
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Get the input image size for a given efficientnet model.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            Input image size (resolution).

        Raises:
            ValueError: If model_name is not in VALID_MODELS.
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validate a model name against VALID_MODELS.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            None. Returns normally if the name is valid.

        Raises:
            ValueError: If model_name is not one of VALID_MODELS.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        Note that the replacement stem conv is freshly initialized, so any
        pretrained stem weights are discarded when in_channels != 3.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
|