megadetector 5.0.10__py3-none-any.whl → 5.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (226) hide show
  1. {megadetector-5.0.10.dist-info → megadetector-5.0.11.dist-info}/LICENSE +0 -0
  2. {megadetector-5.0.10.dist-info → megadetector-5.0.11.dist-info}/METADATA +12 -11
  3. megadetector-5.0.11.dist-info/RECORD +5 -0
  4. megadetector-5.0.11.dist-info/top_level.txt +1 -0
  5. api/__init__.py +0 -0
  6. api/batch_processing/__init__.py +0 -0
  7. api/batch_processing/api_core/__init__.py +0 -0
  8. api/batch_processing/api_core/batch_service/__init__.py +0 -0
  9. api/batch_processing/api_core/batch_service/score.py +0 -439
  10. api/batch_processing/api_core/server.py +0 -294
  11. api/batch_processing/api_core/server_api_config.py +0 -98
  12. api/batch_processing/api_core/server_app_config.py +0 -55
  13. api/batch_processing/api_core/server_batch_job_manager.py +0 -220
  14. api/batch_processing/api_core/server_job_status_table.py +0 -152
  15. api/batch_processing/api_core/server_orchestration.py +0 -360
  16. api/batch_processing/api_core/server_utils.py +0 -92
  17. api/batch_processing/api_core_support/__init__.py +0 -0
  18. api/batch_processing/api_core_support/aggregate_results_manually.py +0 -46
  19. api/batch_processing/api_support/__init__.py +0 -0
  20. api/batch_processing/api_support/summarize_daily_activity.py +0 -152
  21. api/batch_processing/data_preparation/__init__.py +0 -0
  22. api/batch_processing/data_preparation/manage_local_batch.py +0 -2391
  23. api/batch_processing/data_preparation/manage_video_batch.py +0 -327
  24. api/batch_processing/integration/digiKam/setup.py +0 -6
  25. api/batch_processing/integration/digiKam/xmp_integration.py +0 -465
  26. api/batch_processing/integration/eMammal/test_scripts/config_template.py +0 -5
  27. api/batch_processing/integration/eMammal/test_scripts/push_annotations_to_emammal.py +0 -126
  28. api/batch_processing/integration/eMammal/test_scripts/select_images_for_testing.py +0 -55
  29. api/batch_processing/postprocessing/__init__.py +0 -0
  30. api/batch_processing/postprocessing/add_max_conf.py +0 -64
  31. api/batch_processing/postprocessing/categorize_detections_by_size.py +0 -163
  32. api/batch_processing/postprocessing/combine_api_outputs.py +0 -249
  33. api/batch_processing/postprocessing/compare_batch_results.py +0 -958
  34. api/batch_processing/postprocessing/convert_output_format.py +0 -397
  35. api/batch_processing/postprocessing/load_api_results.py +0 -195
  36. api/batch_processing/postprocessing/md_to_coco.py +0 -310
  37. api/batch_processing/postprocessing/md_to_labelme.py +0 -330
  38. api/batch_processing/postprocessing/merge_detections.py +0 -401
  39. api/batch_processing/postprocessing/postprocess_batch_results.py +0 -1904
  40. api/batch_processing/postprocessing/remap_detection_categories.py +0 -170
  41. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +0 -661
  42. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +0 -211
  43. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +0 -82
  44. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +0 -1631
  45. api/batch_processing/postprocessing/separate_detections_into_folders.py +0 -731
  46. api/batch_processing/postprocessing/subset_json_detector_output.py +0 -696
  47. api/batch_processing/postprocessing/top_folders_to_bottom.py +0 -223
  48. api/synchronous/__init__.py +0 -0
  49. api/synchronous/api_core/animal_detection_api/__init__.py +0 -0
  50. api/synchronous/api_core/animal_detection_api/api_backend.py +0 -152
  51. api/synchronous/api_core/animal_detection_api/api_frontend.py +0 -266
  52. api/synchronous/api_core/animal_detection_api/config.py +0 -35
  53. api/synchronous/api_core/animal_detection_api/data_management/annotations/annotation_constants.py +0 -47
  54. api/synchronous/api_core/animal_detection_api/detection/detector_training/copy_checkpoints.py +0 -43
  55. api/synchronous/api_core/animal_detection_api/detection/detector_training/model_main_tf2.py +0 -114
  56. api/synchronous/api_core/animal_detection_api/detection/process_video.py +0 -543
  57. api/synchronous/api_core/animal_detection_api/detection/pytorch_detector.py +0 -304
  58. api/synchronous/api_core/animal_detection_api/detection/run_detector.py +0 -627
  59. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +0 -1029
  60. api/synchronous/api_core/animal_detection_api/detection/run_inference_with_yolov5_val.py +0 -581
  61. api/synchronous/api_core/animal_detection_api/detection/run_tiled_inference.py +0 -754
  62. api/synchronous/api_core/animal_detection_api/detection/tf_detector.py +0 -165
  63. api/synchronous/api_core/animal_detection_api/detection/video_utils.py +0 -495
  64. api/synchronous/api_core/animal_detection_api/md_utils/azure_utils.py +0 -174
  65. api/synchronous/api_core/animal_detection_api/md_utils/ct_utils.py +0 -262
  66. api/synchronous/api_core/animal_detection_api/md_utils/directory_listing.py +0 -251
  67. api/synchronous/api_core/animal_detection_api/md_utils/matlab_porting_tools.py +0 -97
  68. api/synchronous/api_core/animal_detection_api/md_utils/path_utils.py +0 -416
  69. api/synchronous/api_core/animal_detection_api/md_utils/process_utils.py +0 -110
  70. api/synchronous/api_core/animal_detection_api/md_utils/sas_blob_utils.py +0 -509
  71. api/synchronous/api_core/animal_detection_api/md_utils/string_utils.py +0 -59
  72. api/synchronous/api_core/animal_detection_api/md_utils/url_utils.py +0 -144
  73. api/synchronous/api_core/animal_detection_api/md_utils/write_html_image_list.py +0 -226
  74. api/synchronous/api_core/animal_detection_api/md_visualization/visualization_utils.py +0 -841
  75. api/synchronous/api_core/tests/__init__.py +0 -0
  76. api/synchronous/api_core/tests/load_test.py +0 -110
  77. classification/__init__.py +0 -0
  78. classification/aggregate_classifier_probs.py +0 -108
  79. classification/analyze_failed_images.py +0 -227
  80. classification/cache_batchapi_outputs.py +0 -198
  81. classification/create_classification_dataset.py +0 -627
  82. classification/crop_detections.py +0 -516
  83. classification/csv_to_json.py +0 -226
  84. classification/detect_and_crop.py +0 -855
  85. classification/efficientnet/__init__.py +0 -9
  86. classification/efficientnet/model.py +0 -415
  87. classification/efficientnet/utils.py +0 -610
  88. classification/evaluate_model.py +0 -520
  89. classification/identify_mislabeled_candidates.py +0 -152
  90. classification/json_to_azcopy_list.py +0 -63
  91. classification/json_validator.py +0 -695
  92. classification/map_classification_categories.py +0 -276
  93. classification/merge_classification_detection_output.py +0 -506
  94. classification/prepare_classification_script.py +0 -194
  95. classification/prepare_classification_script_mc.py +0 -228
  96. classification/run_classifier.py +0 -286
  97. classification/save_mislabeled.py +0 -110
  98. classification/train_classifier.py +0 -825
  99. classification/train_classifier_tf.py +0 -724
  100. classification/train_utils.py +0 -322
  101. data_management/__init__.py +0 -0
  102. data_management/annotations/__init__.py +0 -0
  103. data_management/annotations/annotation_constants.py +0 -34
  104. data_management/camtrap_dp_to_coco.py +0 -238
  105. data_management/cct_json_utils.py +0 -395
  106. data_management/cct_to_md.py +0 -176
  107. data_management/cct_to_wi.py +0 -289
  108. data_management/coco_to_labelme.py +0 -272
  109. data_management/coco_to_yolo.py +0 -662
  110. data_management/databases/__init__.py +0 -0
  111. data_management/databases/add_width_and_height_to_db.py +0 -33
  112. data_management/databases/combine_coco_camera_traps_files.py +0 -206
  113. data_management/databases/integrity_check_json_db.py +0 -477
  114. data_management/databases/subset_json_db.py +0 -115
  115. data_management/generate_crops_from_cct.py +0 -149
  116. data_management/get_image_sizes.py +0 -188
  117. data_management/importers/add_nacti_sizes.py +0 -52
  118. data_management/importers/add_timestamps_to_icct.py +0 -79
  119. data_management/importers/animl_results_to_md_results.py +0 -158
  120. data_management/importers/auckland_doc_test_to_json.py +0 -372
  121. data_management/importers/auckland_doc_to_json.py +0 -200
  122. data_management/importers/awc_to_json.py +0 -189
  123. data_management/importers/bellevue_to_json.py +0 -273
  124. data_management/importers/cacophony-thermal-importer.py +0 -796
  125. data_management/importers/carrizo_shrubfree_2018.py +0 -268
  126. data_management/importers/carrizo_trail_cam_2017.py +0 -287
  127. data_management/importers/cct_field_adjustments.py +0 -57
  128. data_management/importers/channel_islands_to_cct.py +0 -913
  129. data_management/importers/eMammal/copy_and_unzip_emammal.py +0 -180
  130. data_management/importers/eMammal/eMammal_helpers.py +0 -249
  131. data_management/importers/eMammal/make_eMammal_json.py +0 -223
  132. data_management/importers/ena24_to_json.py +0 -275
  133. data_management/importers/filenames_to_json.py +0 -385
  134. data_management/importers/helena_to_cct.py +0 -282
  135. data_management/importers/idaho-camera-traps.py +0 -1407
  136. data_management/importers/idfg_iwildcam_lila_prep.py +0 -294
  137. data_management/importers/jb_csv_to_json.py +0 -150
  138. data_management/importers/mcgill_to_json.py +0 -250
  139. data_management/importers/missouri_to_json.py +0 -489
  140. data_management/importers/nacti_fieldname_adjustments.py +0 -79
  141. data_management/importers/noaa_seals_2019.py +0 -181
  142. data_management/importers/pc_to_json.py +0 -365
  143. data_management/importers/plot_wni_giraffes.py +0 -123
  144. data_management/importers/prepare-noaa-fish-data-for-lila.py +0 -359
  145. data_management/importers/prepare_zsl_imerit.py +0 -131
  146. data_management/importers/rspb_to_json.py +0 -356
  147. data_management/importers/save_the_elephants_survey_A.py +0 -320
  148. data_management/importers/save_the_elephants_survey_B.py +0 -332
  149. data_management/importers/snapshot_safari_importer.py +0 -758
  150. data_management/importers/snapshot_safari_importer_reprise.py +0 -665
  151. data_management/importers/snapshot_serengeti_lila.py +0 -1067
  152. data_management/importers/snapshotserengeti/make_full_SS_json.py +0 -150
  153. data_management/importers/snapshotserengeti/make_per_season_SS_json.py +0 -153
  154. data_management/importers/sulross_get_exif.py +0 -65
  155. data_management/importers/timelapse_csv_set_to_json.py +0 -490
  156. data_management/importers/ubc_to_json.py +0 -399
  157. data_management/importers/umn_to_json.py +0 -507
  158. data_management/importers/wellington_to_json.py +0 -263
  159. data_management/importers/wi_to_json.py +0 -441
  160. data_management/importers/zamba_results_to_md_results.py +0 -181
  161. data_management/labelme_to_coco.py +0 -548
  162. data_management/labelme_to_yolo.py +0 -272
  163. data_management/lila/__init__.py +0 -0
  164. data_management/lila/add_locations_to_island_camera_traps.py +0 -97
  165. data_management/lila/add_locations_to_nacti.py +0 -147
  166. data_management/lila/create_lila_blank_set.py +0 -557
  167. data_management/lila/create_lila_test_set.py +0 -151
  168. data_management/lila/create_links_to_md_results_files.py +0 -106
  169. data_management/lila/download_lila_subset.py +0 -177
  170. data_management/lila/generate_lila_per_image_labels.py +0 -515
  171. data_management/lila/get_lila_annotation_counts.py +0 -170
  172. data_management/lila/get_lila_image_counts.py +0 -111
  173. data_management/lila/lila_common.py +0 -300
  174. data_management/lila/test_lila_metadata_urls.py +0 -132
  175. data_management/ocr_tools.py +0 -874
  176. data_management/read_exif.py +0 -681
  177. data_management/remap_coco_categories.py +0 -84
  178. data_management/remove_exif.py +0 -66
  179. data_management/resize_coco_dataset.py +0 -189
  180. data_management/wi_download_csv_to_coco.py +0 -246
  181. data_management/yolo_output_to_md_output.py +0 -441
  182. data_management/yolo_to_coco.py +0 -676
  183. detection/__init__.py +0 -0
  184. detection/detector_training/__init__.py +0 -0
  185. detection/detector_training/model_main_tf2.py +0 -114
  186. detection/process_video.py +0 -703
  187. detection/pytorch_detector.py +0 -337
  188. detection/run_detector.py +0 -779
  189. detection/run_detector_batch.py +0 -1219
  190. detection/run_inference_with_yolov5_val.py +0 -917
  191. detection/run_tiled_inference.py +0 -935
  192. detection/tf_detector.py +0 -188
  193. detection/video_utils.py +0 -606
  194. docs/source/conf.py +0 -43
  195. md_utils/__init__.py +0 -0
  196. md_utils/azure_utils.py +0 -174
  197. md_utils/ct_utils.py +0 -612
  198. md_utils/directory_listing.py +0 -246
  199. md_utils/md_tests.py +0 -968
  200. md_utils/path_utils.py +0 -1044
  201. md_utils/process_utils.py +0 -157
  202. md_utils/sas_blob_utils.py +0 -509
  203. md_utils/split_locations_into_train_val.py +0 -228
  204. md_utils/string_utils.py +0 -92
  205. md_utils/url_utils.py +0 -323
  206. md_utils/write_html_image_list.py +0 -225
  207. md_visualization/__init__.py +0 -0
  208. md_visualization/plot_utils.py +0 -293
  209. md_visualization/render_images_with_thumbnails.py +0 -275
  210. md_visualization/visualization_utils.py +0 -1537
  211. md_visualization/visualize_db.py +0 -551
  212. md_visualization/visualize_detector_output.py +0 -406
  213. megadetector-5.0.10.dist-info/RECORD +0 -224
  214. megadetector-5.0.10.dist-info/top_level.txt +0 -8
  215. taxonomy_mapping/__init__.py +0 -0
  216. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +0 -491
  217. taxonomy_mapping/map_new_lila_datasets.py +0 -154
  218. taxonomy_mapping/prepare_lila_taxonomy_release.py +0 -142
  219. taxonomy_mapping/preview_lila_taxonomy.py +0 -591
  220. taxonomy_mapping/retrieve_sample_image.py +0 -71
  221. taxonomy_mapping/simple_image_download.py +0 -218
  222. taxonomy_mapping/species_lookup.py +0 -834
  223. taxonomy_mapping/taxonomy_csv_checker.py +0 -159
  224. taxonomy_mapping/taxonomy_graph.py +0 -346
  225. taxonomy_mapping/validate_lila_category_mappings.py +0 -83
  226. {megadetector-5.0.10.dist-info → megadetector-5.0.11.dist-info}/WHEEL +0 -0
@@ -1,610 +0,0 @@
1
- """utils.py - Helper functions for building the model and for loading model parameters.
2
- These helper functions are built to mirror those in the official TensorFlow implementation.
3
- """
4
-
5
- # Author: lukemelas (github username)
6
- # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7
- # With adjustments and added comments by workingcoder (github username).
8
-
9
- import re
10
- import math
11
- import collections
12
- from functools import partial
13
- import torch
14
- from torch import nn
15
- from torch.nn import functional as F
16
- from torch.utils import model_zoo
17
-
18
-
19
- ################################################################################
20
- ### Helper functions for model architecture
21
- ################################################################################
22
-
23
- # GlobalParams and BlockArgs: Two namedtuples
24
- # Swish and MemoryEfficientSwish: Two implementations of the method
25
- # round_filters and round_repeats:
26
- # Functions to calculate params for scaling model width and depth ! ! !
27
- # get_width_and_height_from_size and calculate_output_image_size
28
- # drop_connect: A structural design
29
- # get_same_padding_conv2d:
30
- # Conv2dDynamicSamePadding
31
- # Conv2dStaticSamePadding
32
- # get_same_padding_maxPool2d:
33
- # MaxPool2dDynamicSamePadding
34
- # MaxPool2dStaticSamePadding
35
- # It's an additional function, not used in EfficientNet,
36
- # but can be used in other model (such as EfficientDet).
37
-
38
# Settings that apply to the entire model (stem, all blocks, and head).
GlobalParams = collections.namedtuple(
    'GlobalParams',
    [
        'width_coefficient',
        'depth_coefficient',
        'image_size',
        'dropout_rate',
        'num_classes',
        'batch_norm_momentum',
        'batch_norm_epsilon',
        'drop_connect_rate',
        'depth_divisor',
        'min_depth',
        'include_top',
    ],
)

# Settings that describe one individual model block.
BlockArgs = collections.namedtuple(
    'BlockArgs',
    [
        'num_repeat',
        'kernel_size',
        'stride',
        'expand_ratio',
        'input_filters',
        'output_filters',
        'se_ratio',
        'id_skip',
    ],
)

# Give every field of both namedtuples a default of None, so either type can be
# constructed with any subset of its fields.
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
52
-
53
-
54
- # An ordinary implementation of Swish function
55
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``, built from ordinary autograd ops."""

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
58
-
59
-
60
- # A memory-efficient implementation of Swish function
61
class SwishImplementation(torch.autograd.Function):
    """Swish as a custom autograd function.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there instead of being stored.
    """

    @staticmethod
    def forward(ctx, i):
        """Save the input and return ``i * sigmoid(i)``."""
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of Swish with respect to its input."""
        (i,) = ctx.saved_tensors
        sig = torch.sigmoid(i)
        # d/di [i * sigmoid(i)] = sigmoid(i) * (1 + i * (1 - sigmoid(i)))
        local_grad = sig * (1 + i * (1 - sig))
        return grad_output * local_grad
73
-
74
class MemoryEfficientSwish(nn.Module):
    """Module wrapper that applies Swish through ``SwishImplementation``."""

    def forward(self, x):
        """Run ``x`` through the custom autograd function."""
        activated = SwishImplementation.apply(x)
        return activated
77
-
78
-
79
def round_filters(filters, global_params):
    """Scale a filter count by the width multiplier and round it to the divisor.

    Reads width_coefficient, depth_divisor and min_depth from global_params.

    Args:
        filters (int): Filter count to scale.
        global_params (namedtuple): Global params of the model.

    Returns:
        new_filters: ``filters`` unchanged when no width multiplier is set,
            otherwise the scaled count rounded to a multiple of the divisor.
    """
    width_mult = global_params.width_coefficient
    if not width_mult:
        return filters

    # TODO: modify the params names.
    #       maybe the names (width_divisor,min_width)
    #       are more suitable than (depth_divisor,min_depth).
    divisor = global_params.depth_divisor
    # A missing (falsy) min_depth falls back to the divisor itself.
    lower_bound = global_params.min_depth or divisor

    scaled = filters * width_mult
    # Rounding formula carried over from the official TensorFlow implementation.
    rounded = int(scaled + divisor / 2) // divisor * divisor
    rounded = max(lower_bound, rounded)
    # Never round down by more than 10%.
    if rounded < 0.9 * scaled:
        rounded += divisor
    return int(rounded)
105
-
106
-
107
def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the depth multiplier.

    Reads depth_coefficient from global_params.

    Args:
        repeats (int): num_repeat to scale.
        global_params (namedtuple): Global params of the model.

    Returns:
        new repeat: ``repeats`` unchanged when no depth multiplier is set,
            otherwise the scaled count rounded up to an integer.
    """
    depth_mult = global_params.depth_coefficient
    # Ceil rounding follows the official TensorFlow implementation.
    return int(math.ceil(depth_mult * repeats)) if depth_mult else repeats
123
-
124
-
125
def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples of a batch and rescale the rest.

    Args:
        inputs (tensor: BCWH): Input of this structure.
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.

    Returns:
        output: ``inputs`` itself when not training; otherwise a tensor in which
            each sample is either zeroed or divided by ``1 - p``.
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'

    if not training:
        return inputs

    keep_prob = 1 - p
    if keep_prob == 0:
        # p == 1 drops every sample. Dividing by keep_prob here would compute
        # inputs / 0 * 0, which is NaN, so return zeros explicitly.
        return torch.zeros_like(inputs)

    batch_size = inputs.shape[0]

    # One random value per sample, broadcast over the remaining dimensions.
    # floor(keep_prob + U[0,1)) is 1 with probability keep_prob and 0 otherwise.
    random_tensor = keep_prob + torch.rand(
        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)

    # Dividing by keep_prob rescales the samples that were kept.
    return inputs / keep_prob * binary_tensor
151
-
152
-
153
def get_width_and_height_from_size(x):
    """Turn a size specification into a pair.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: ``(x, x)`` when x is an int; x itself when it is a list or tuple.

    Raises:
        TypeError: If x is neither an int, a list, nor a tuple.
    """
    if isinstance(x, int):
        return x, x
    if not isinstance(x, (list, tuple)):
        raise TypeError()
    return x
168
-
169
-
170
def calculate_output_image_size(input_image_size, stride):
    """Compute the output image size of a Conv2dSamePadding layer with a stride.

    Necessary for static padding. Thanks to mannatsingh for pointing this out.

    Args:
        input_image_size (int, tuple or list): Size of input image.
        stride (int, tuple or list): Conv2d operation's stride.

    Returns:
        output_image_size: A list [H,W], or None when input_image_size is None.
    """
    if input_image_size is None:
        return None

    height, width = get_width_and_height_from_size(input_image_size)
    # A sequence stride contributes only its first element, used for both axes.
    step = stride if isinstance(stride, int) else stride[0]
    out_height = int(math.ceil(height / step))
    out_width = int(math.ceil(width / step))
    return [out_height, out_width]
188
-
189
-
190
- # Note:
191
- # The following 'SamePadding' functions make output size equal ceil(input size/stride).
192
- # Only when stride equals 1, can the output size be the same as input size.
193
- # Don't be confused by their function names ! ! !
194
-
195
def get_same_padding_conv2d(image_size=None):
    """Pick the convolution class that matches whether an image size is known.

    Static padding is used when an image size is given, dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        Conv2dDynamicSamePadding, or Conv2dStaticSamePadding with image_size
        pre-bound through functools.partial.
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
209
-
210
-
211
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D convolution like TensorFlow's 'SAME' mode, for a dynamic image size.

    The padding is computed in forward() from the actual input size.
    """

    # How the 'SAME' padding is derived, per axis, with
    #     i: input width or height    s: stride    k: kernel size
    #     d: dilation                 p: total padding
    # Conv2d produces
    #     o = floor((i + p - ((k - 1) * d + 1)) / s + 1)
    # Solving for p at a wanted output size o gives
    #     p = (o - 1) * s + ((k - 1) * d + 1) - i
    # forward() uses o = ceil(i / s).

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        # The parent is always built with padding 0; padding is applied in forward().
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    def forward(self, x):
        """Pad x so the output is ceil(input / stride) per axis, then convolve."""
        in_h, in_w = x.size()[-2:]
        k_h, k_w = self.weight.size()[-2:]
        s_h, s_w = self.stride
        d_h, d_w = self.dilation
        # The output size depends on the stride, not only on the input size.
        out_h = math.ceil(in_h / s_h)
        out_w = math.ceil(in_w / s_w)
        pad_h = max((out_h - 1) * s_h + (k_h - 1) * d_h + 1 - in_h, 0)
        pad_w = max((out_w - 1) * s_w + (k_w - 1) * d_w + 1 - in_w, 0)
        if pad_h > 0 or pad_w > 0:
            # An odd total puts the extra pixel on the right / bottom.
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
242
-
243
-
244
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution like TensorFlow's 'SAME' mode, with the given input image size.

    The padding module is built once in the constructor from image_size and
    then applied in forward(), whatever the actual input size is.
    """

    # Same padding calculation as Conv2dDynamicSamePadding.

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        """Build the convolution and precompute its 'SAME' padding.

        Args:
            in_channels, out_channels, kernel_size, stride: Passed to nn.Conv2d.
            image_size (int or tuple): Input size as one int or an (H, W) pair.
                Required; the None default only keeps the keyword optional in
                the signature.
            **kwargs: Forwarded to nn.Conv2d.
        """
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * sh + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * sw + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # ZeroPad2d takes (left, right, top, bottom). The two sides of an
            # axis must add up to the total pad, so an odd total puts the
            # extra pixel on the right / bottom, as in the dynamic variant.
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
                                                pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        """Apply the precomputed padding, then convolve."""
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x
273
-
274
-
275
def get_same_padding_maxPool2d(image_size=None):
    """Pick the max-pooling class that matches whether an image size is known.

    Static padding is used when an image size is given, dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        MaxPool2dDynamicSamePadding, or MaxPool2dStaticSamePadding with
        image_size pre-bound through functools.partial.
    """
    if image_size is not None:
        return partial(MaxPool2dStaticSamePadding, image_size=image_size)
    return MaxPool2dDynamicSamePadding
289
-
290
-
291
class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D max pooling like TensorFlow's 'SAME' mode, with a dynamic image size.

    The padding is computed in forward() from the actual input size.
    """

    def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
        super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        # Expand int settings to per-axis pairs so forward() can unpack and index them.
        for attr in ('stride', 'kernel_size', 'dilation'):
            value = getattr(self, attr)
            if isinstance(value, int):
                setattr(self, attr, [value] * 2)

    def forward(self, x):
        """Pad x so the output is ceil(input / stride) per axis, then pool."""
        in_h, in_w = x.size()[-2:]
        k_h, k_w = self.kernel_size
        s_h, s_w = self.stride
        d_h, d_w = self.dilation[0], self.dilation[1]
        out_h = math.ceil(in_h / s_h)
        out_w = math.ceil(in_w / s_w)
        pad_h = max((out_h - 1) * s_h + (k_h - 1) * d_h + 1 - in_h, 0)
        pad_w = max((out_w - 1) * s_w + (k_w - 1) * d_w + 1 - in_w, 0)
        if pad_h > 0 or pad_w > 0:
            # An odd total puts the extra pixel on the right / bottom.
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])
        return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                            self.dilation, self.ceil_mode, self.return_indices)
313
-
314
class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D max pooling like TensorFlow's 'SAME' mode, with the given input image size.

    The padding module is built once in the constructor from image_size and
    then applied in forward().
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        # Expand int settings to per-axis pairs so they can be unpacked and indexed below.
        for attr in ('stride', 'kernel_size', 'dilation'):
            value = getattr(self, attr)
            if isinstance(value, int):
                setattr(self, attr, [value] * 2)

        # Derive the padding from the image size and keep it as a module.
        assert image_size is not None
        if isinstance(image_size, int):
            in_h = in_w = image_size
        else:
            in_h, in_w = image_size
        k_h, k_w = self.kernel_size
        s_h, s_w = self.stride
        out_h = math.ceil(in_h / s_h)
        out_w = math.ceil(in_w / s_w)
        pad_h = max((out_h - 1) * s_h + (k_h - 1) * self.dilation[0] + 1 - in_h, 0)
        pad_w = max((out_w - 1) * s_w + (k_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if pad_h > 0 or pad_w > 0:
            # ZeroPad2d takes (left, right, top, bottom); an odd total puts
            # the extra pixel on the right / bottom.
            left, top = pad_w // 2, pad_h // 2
            self.static_padding = nn.ZeroPad2d((left, pad_w - left, top, pad_h - top))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        """Apply the precomputed padding, then pool."""
        padded = self.static_padding(x)
        return F.max_pool2d(padded, self.kernel_size, self.stride, self.padding,
                            self.dilation, self.ceil_mode, self.return_indices)
343
-
344
-
345
- ################################################################################
346
- ### Helper functions for loading model params
347
- ################################################################################
348
-
349
- # BlockDecoder: A Class for encoding and decoding BlockArgs
350
- # efficientnet_params: A function to query compound coefficient
351
- # get_model_params and efficientnet:
352
- # Functions to get BlockArgs and GlobalParams for efficientnet
353
- # url_map and url_map_advprop: Dicts of url_map for pretrained weights
354
- # load_pretrained_weights: A function to load pretrained weights
355
-
356
class BlockDecoder(object):
    """Block Decoder for readability,
    straight from the official TensorFlow repository.

    Converts between BlockArgs namedtuples and their compact string notation,
    e.g. 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file.

        Raises:
            AssertionError: If block_string is not a str, or its stride token
                is missing or is not one digit or two equal digits.
            KeyError: If one of the r/k/e/i/o tokens is missing.
        """
        assert isinstance(block_string, str)

        options = {}
        for op in block_string.split('_'):
            # Split at the first digit: 'se0.25' -> ['se', '0.25', ''].
            # A token without digits (e.g. 'noskip') stays in one piece and is skipped.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride. A missing 's' token is treated as an empty string so
        # that it fails this assertion instead of raising KeyError.
        stride = options.get('s', '')
        assert (len(stride) == 1 or
                (len(stride) == 2 and stride[0] == stride[1])), \
            'stride must be one digit or two equal digits, got %r' % stride

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(stride[0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument.

        Returns:
            block_string: A String form of BlockArgs.
        """
        # The BlockArgs field is named 'stride'. _decode_block_string stores it
        # as a one-element list; a plain int or a two-element sequence is
        # accepted as well.
        stride = block.stride
        if isinstance(stride, int):
            stride_h = stride_w = stride
        else:
            stride_h, stride_w = stride[0], stride[-1]

        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride_h, stride_w),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        # se_ratio is None when the block string had no 'se' token.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        return [BlockDecoder._decode_block_string(block_string)
                for block_string in string_list]

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]
450
-
451
-
452
def efficientnet_params(model_name):
    """Look up the scaling coefficients of a named EfficientNet variant.

    Args:
        model_name (str): Model name to be queried, e.g. 'efficientnet-b0'.

    Returns:
        A (width, depth, res, dropout) tuple for ``model_name``.

    Raises:
        KeyError: If ``model_name`` is not one of the known variants.
    """
    variants = (
        # name, width, depth, res, dropout
        ('efficientnet-b0', 1.0, 1.0, 224, 0.2),
        ('efficientnet-b1', 1.0, 1.1, 240, 0.2),
        ('efficientnet-b2', 1.1, 1.2, 260, 0.3),
        ('efficientnet-b3', 1.2, 1.4, 300, 0.3),
        ('efficientnet-b4', 1.4, 1.8, 380, 0.4),
        ('efficientnet-b5', 1.6, 2.2, 456, 0.4),
        ('efficientnet-b6', 1.8, 2.6, 528, 0.5),
        ('efficientnet-b7', 2.0, 3.1, 600, 0.5),
        ('efficientnet-b8', 2.2, 3.6, 672, 0.5),
        ('efficientnet-l2', 4.3, 5.3, 800, 0.5),
    )
    coefficients_by_name = {name: tuple(coeffs) for name, *coeffs in variants}
    return coefficients_by_name[model_name]
475
-
476
-
477
def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
    """Build the decoded block list and the GlobalParams for an EfficientNet.

    Args:
        width_coefficient (float): Stored as ``GlobalParams.width_coefficient``.
        depth_coefficient (float): Stored as ``GlobalParams.depth_coefficient``.
        image_size (int): Stored as ``GlobalParams.image_size``.
        dropout_rate (float): Stored as ``GlobalParams.dropout_rate``.
        drop_connect_rate (float): Stored as ``GlobalParams.drop_connect_rate``.
        num_classes (int): Stored as ``GlobalParams.num_classes``.
        include_top (bool): Stored as ``GlobalParams.include_top``.

    Returns:
        blocks_args, global_params.
    """
    # Stage layout of the baseline model (efficientnet-b0 by default). Per the
    # original note, it is adjusted later, when the EfficientNet class is
    # constructed for a specific model.
    stage_notations = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    decoded_blocks = BlockDecoder.decode(stage_notations)

    # Caller-supplied values first, then the constants fixed by this helper.
    hyper_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        drop_connect_rate=drop_connect_rate,
        num_classes=num_classes,
        include_top=include_top,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        depth_divisor=8,
        min_depth=None,
    )

    return decoded_blocks, hyper_params
524
-
525
-
526
def get_model_params(model_name, override_params):
    """Return the block args and global params for a named model.

    Args:
        model_name (str): Model's name; must start with 'efficientnet'.
        override_params (dict): Fields of global_params to replace. A falsy
            value leaves global_params untouched.

    Returns:
        blocks_args, global_params

    Raises:
        NotImplementedError: If ``model_name`` does not start with 'efficientnet'.
    """
    if not model_name.startswith('efficientnet'):
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))

    width, depth, resolution, dropout = efficientnet_params(model_name)
    # drop_connect_rate is not passed, so every model keeps the 0.2 default.
    blocks_args, global_params = efficientnet(
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout,
        image_size=resolution,
    )

    if override_params:
        # _replace raises ValueError for names that are not global_params fields.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
547
-
548
-
549
# Download URLs for weights trained with the standard method; see the paper
# "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks".
url_map = {
    'efficientnet-b0':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
    'efficientnet-b2':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
    'efficientnet-b6':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
}
561
-
562
# Download URLs for weights trained with adversarial examples (AdvProp); see
# the paper "Adversarial Examples Improve Image Recognition".
url_map_advprop = {
    'efficientnet-b0':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
    'efficientnet-b1':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
    'efficientnet-b2':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
    'efficientnet-b3':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
    'efficientnet-b4':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
    'efficientnet-b5':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
    'efficientnet-b6':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
    'efficientnet-b7':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
    'efficientnet-b8':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
}

# TODO: add the pretrained weights URL map for 'efficientnet-l2'
577
-
578
-
579
def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False):
    """Load pretrained weights into ``model`` from a local file or a download URL.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
        advprop (bool): Whether to load pretrained weights
                        trained with advprop (valid when weights_path is None).

    Raises:
        AssertionError: If the model is missing keys other than the ones
            deliberately skipped, or the weights contain unexpected keys.
        KeyError: If weights must be downloaded and no URL is registered
            for ``model_name``.
    """
    if isinstance(weights_path, str):
        # NOTE(review): torch.load deserializes the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(weights_path)
    else:
        # AutoAugment or Advprop (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        state_dict = model_zoo.load_url(url_map_[model_name])

    fc_keys = ('_fc.weight', '_fc.bias')
    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        # Drop the classifier weights; a checkpoint saved without them is
        # fine too, so a missing key must not raise here.
        for key in fc_keys:
            state_dict.pop(key, None)
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) == set(fc_keys), \
            'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    assert not ret.unexpected_keys, \
        'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    print('Loaded pretrained weights for {}'.format(model_name))