python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164):
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,691 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import numpy as np
3
+ import fvcore.nn.weight_init as weight_init
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from .r50_config import _C as r50_config
8
+ from collections import OrderedDict
9
+ from wml.wtorch.nets.shape_spec import ShapeSpec
10
+
11
+ from .layers import (
12
+ CNNBlockBase,
13
+ Conv2d,
14
+ get_norm,
15
+ )
16
# detectron2 reference: https://dl.fbaipublicfiles.com/detectron2

# Names exported by ``from ... import *``.
__all__ = [
    "ResNetBlockBase", "BasicBlock", "BottleneckBlock", "DeformBottleneckBlock",
    "BasicStem", "ResNet", "make_stage", "build_resnet_backbone",
]
28
+
29
+
30
class BasicBlock(CNNBlockBase):
    """
    The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`,
    with two 3x3 conv layers and a 1x1 projection shortcut when the residual branch
    changes the tensor shape.
    """

    def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride for the first conv (and for the projection shortcut).
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
        """
        super().__init__(in_channels, out_channels, stride)

        # The identity path must be projected whenever the residual branch changes
        # the shape: a channel change, or a spatial downsample (stride != 1).
        # Previously only the channel change was checked, so a stride>1 block with
        # equal channels failed at ``out += shortcut`` with a shape mismatch.
        if in_channels != out_channels or stride != 1:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        else:
            self.shortcut = None

        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        self.conv2 = Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in [self.conv1, self.conv2, self.shortcut]:
            if layer is not None:  # shortcut can be None
                weight_init.c2_msra_fill(layer)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)
        out = self.conv2(out)

        shortcut = x if self.shortcut is None else self.shortcut(x)

        out += shortcut
        out = F.relu_(out)
        return out
96
+
97
+
98
class BottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block used by ResNet-50, 101 and 152
    defined in :paper:`ResNet`. It contains 3 conv layers with kernels
    1x1, 3x3, 1x1, and a 1x1 projection shortcut when the residual branch
    changes the tensor shape.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            stride (int): overall stride of the block (applied in the 1x1 or the
                3x3 conv, see ``stride_in_1x1``, and in the projection shortcut).
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer (its padding
                equals the dilation so the spatial size only depends on stride).
        """
        super().__init__(in_channels, out_channels, stride)

        # Project the identity path on a channel change OR a spatial downsample;
        # checking channels alone made stride>1 blocks with equal channels fail
        # at ``out += shortcut`` with a shape mismatch.
        if in_channels != out_channels or stride != 1:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        else:
            self.shortcut = None

        # The original MSRA ResNet models have stride in the first 1x1 conv
        # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
        # stride in the 3x3 conv
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:  # shortcut can be None
                weight_init.c2_msra_fill(layer)

        # NOTE: zero-initialising the last norm of the residual branch (Sec 5.1 of
        # "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour") is
        # deliberately not done here: it was observed to hurt GN models trained
        # from scratch. TODO: expose it as an option for training a backbone.

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)

        out = self.conv2(out)
        out = F.relu_(out)

        out = self.conv3(out)

        shortcut = x if self.shortcut is None else self.shortcut(x)

        out += shortcut
        out = F.relu_(out)
        return out
209
+
210
+
211
class DeformBottleneckBlock(CNNBlockBase):
    """
    Similar to :class:`BottleneckBlock`, but with :paper:`deformable conv <deformconv>`
    in the 3x3 convolution.

    The deformable conv class (``DeformConv``, or ``ModulatedDeformConv`` when
    ``deform_modulated=True``) is looked up in the sibling ``layers`` package when
    the block is constructed; if it is not available, construction raises
    ``ImportError`` with an explanatory message.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
        deform_modulated=False,
        deform_num_groups=1,
    ):
        """
        Args:
            in_channels, out_channels, bottleneck_channels, stride, num_groups,
            norm, stride_in_1x1, dilation: same meaning as in :class:`BottleneckBlock`.
            deform_modulated (bool): use the modulated variant, whose offset branch
                additionally predicts a mask that is passed through a sigmoid.
            deform_num_groups (int): number of deformable groups; the offset conv
                predicts one set of offsets (and masks) per group.

        Raises:
            ImportError: if the required deformable conv class cannot be found.
        """
        super().__init__(in_channels, out_channels, stride)
        self.deform_modulated = deform_modulated
        # Resolve the op before any layer is built. These names are not imported
        # at module level, so referencing them as globals used to fail with a bare
        # NameError halfway through construction.
        deform_conv_op = self._load_deform_conv_op(deform_modulated)

        # Project the identity path on a channel change OR a spatial downsample.
        if in_channels != out_channels or stride != 1:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
        else:
            self.shortcut = None

        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        # Offsets per deformable group: 2 (x, y), or 3 (x, y, mask) when modulated,
        # for each of the 3 * 3 kernel positions.
        offset_channels = 27 if deform_modulated else 18

        self.conv2_offset = Conv2d(
            bottleneck_channels,
            offset_channels * deform_num_groups,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            dilation=dilation,
        )
        self.conv2 = deform_conv_op(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            deformable_groups=deform_num_groups,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
            if layer is not None:  # shortcut can be None
                weight_init.c2_msra_fill(layer)

        # Zero offsets at init: the deformable conv starts as a regular conv.
        nn.init.constant_(self.conv2_offset.weight, 0)
        nn.init.constant_(self.conv2_offset.bias, 0)

    @staticmethod
    def _load_deform_conv_op(modulated):
        """Import and return the deformable conv class for the requested variant."""
        op_name = "ModulatedDeformConv" if modulated else "DeformConv"
        try:
            from . import layers

            return getattr(layers, op_name)
        except (ImportError, AttributeError) as err:
            raise ImportError(
                f"DeformBottleneckBlock(deform_modulated={modulated}) requires "
                f"'{op_name}', which could not be found in the sibling 'layers' package."
            ) from err

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)

        if self.deform_modulated:
            offset_mask = self.conv2_offset(out)
            offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
            offset = torch.cat((offset_x, offset_y), dim=1)
            mask = mask.sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = F.relu_(out)

        out = self.conv3(out)

        shortcut = x if self.shortcut is None else self.shortcut(x)

        out += shortcut
        out = F.relu_(out)
        return out
326
+
327
+
328
class BasicStem(CNNBlockBase):
    """
    The standard ResNet stem, i.e. everything before the first residual block:
    a 7x7 stride-2 conv (carrying the configured norm), a ReLU, and a 3x3 stride-2
    max-pool. It is registered with an overall stride of 4.
    """

    def __init__(self, in_channels=3, out_channels=64, norm="BN"):
        """
        Args:
            in_channels (int): channels of the input image.
            out_channels (int): channels produced by the stem conv.
            norm (str or callable): norm attached to the stem conv.
                See :func:`layers.get_norm` for supported format.
        """
        super().__init__(in_channels, out_channels, 4)
        self.in_channels = in_channels
        stem_norm = get_norm(norm, out_channels)
        self.conv1 = Conv2d(
            in_channels, out_channels,
            kernel_size=7, stride=2, padding=3,
            bias=False, norm=stem_norm,
        )
        weight_init.c2_msra_fill(self.conv1)

    def forward(self, x):
        activated = F.relu_(self.conv1(x))
        return F.max_pool2d(activated, kernel_size=3, stride=2, padding=1)
357
+
358
+
359
class ResNet(torch.nn.Module):
    """
    Implement :paper:`ResNet`.

    The model is a stem followed by residual stages registered as ``res2``,
    ``res3``, ... and, when ``num_classes`` is given, a global-average-pool plus a
    ``linear`` classifier head.
    """

    def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
        """
        Args:
            stem (nn.Module): a stem module
            stages (list[list[CNNBlockBase]]): several (typically 4) stages,
                each contains multiple :class:`CNNBlockBase`.
            num_classes (None or int): if None, will not perform classification.
                Otherwise, will create a linear layer.
            out_features (list[str]): name of the layers whose outputs should
                be returned in forward. Can be anything in "stem", "linear", or "res2" ...
                If None, will return the output of the last layer.
                Stages that are not needed to produce these outputs are dropped;
                requesting "linear" keeps every stage.
            freeze_at (int): The number of stages at the beginning to freeze.
                see :meth:`freeze` for detailed explanation.
        """
        super().__init__()
        self.stem = stem
        self.num_classes = num_classes

        current_stride = self.stem.stride
        self._out_feature_strides = {"stem": current_stride}
        self._out_feature_channels = {"stem": self.stem.out_channels}

        # Fallbacks for when no residual stage is kept (e.g. out_features=["stem"]
        # or an empty ``stages``); previously these names were left unbound.
        name = "stem"
        curr_channels = self.stem.out_channels

        self.stage_names, self.stages = [], []

        if out_features is not None:
            # Avoid keeping unused layers in this module. They consume extra memory
            # and may cause allreduce to fail
            stages = stages[: self._num_stages_to_keep(out_features, len(stages))]
        for i, blocks in enumerate(stages):
            assert len(blocks) > 0, len(blocks)
            for block in blocks:
                assert isinstance(block, CNNBlockBase), block

            name = "res" + str(i + 2)
            stage = nn.Sequential(*blocks)

            self.add_module(name, stage)
            self.stage_names.append(name)
            self.stages.append(stage)

            self._out_feature_strides[name] = current_stride = int(
                current_stride * np.prod([k.stride for k in blocks])
            )
            self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
        self.stage_names = tuple(self.stage_names)  # Make it static for scripting

        if num_classes is not None:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.linear = nn.Linear(curr_channels, num_classes)

            # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
            # "The 1000-way fully-connected layer is initialized by
            # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
            nn.init.normal_(self.linear.weight, std=0.01)
            name = "linear"

        if out_features is None:
            out_features = [name]
        self._out_features = out_features
        assert len(self._out_features)
        children = [x[0] for x in self.named_children()]
        for out_feature in self._out_features:
            assert out_feature in children, "Available children: {}".format(", ".join(children))
        self.freeze(freeze_at)

    @staticmethod
    def _num_stages_to_keep(out_features, num_stages):
        """
        Return how many leading stages are required to produce ``out_features``.

        Stage ``i`` (0-based) is registered as ``"res{i + 2}"``, so ``"resN"`` needs
        the first ``N - 1`` stages. ``"linear"`` consumes the output of the last
        stage and therefore keeps all ``num_stages`` stages (previously it kept
        none, which crashed when building the classifier). Any other name, such as
        ``"stem"``, needs no stage.
        """
        needed = 0
        for feature in out_features:
            if feature == "linear":
                return num_stages
            if isinstance(feature, str) and feature.startswith("res") and feature[3:].isdecimal():
                needed = max(needed, int(feature[3:]) - 1)
        return needed

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W).

        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        x = self.stem(x)
        if "stem" in self._out_features:
            outputs["stem"] = x
        for name, stage in zip(self.stage_names, self.stages):
            x = stage(x)
            if name in self._out_features:
                outputs[name] = x
        if self.num_classes is not None:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.linear(x)
            if "linear" in self._out_features:
                outputs["linear"] = x
        return outputs

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    def freeze(self, freeze_at=0):
        """
        Freeze the first several stages of the ResNet. Commonly used in
        fine-tuning.

        Layers that produce the same feature map spatial size are defined as one
        "stage" by :paper:`FPN`.

        Args:
            freeze_at (int): number of stages to freeze.
                `1` means freezing the stem. `2` means freezing the stem and
                one residual stage, etc.

        Returns:
            nn.Module: this ResNet itself
        """
        if freeze_at >= 1:
            self.stem.freeze()
        for idx, stage in enumerate(self.stages, start=2):
            if freeze_at >= idx:
                for block in stage.children():
                    block.freeze()
        return self

    @staticmethod
    def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
        """
        Create a list of blocks of the same type that forms one ResNet stage.

        Args:
            block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
                stage. A module of this type must not change spatial resolution of inputs unless its
                stride != 1.
            num_blocks (int): number of blocks in this stage
            in_channels (int): input channels of the entire stage.
            out_channels (int): output channels of **every block** in the stage.
            kwargs: other arguments passed to the constructor of
                `block_class`. If the argument name is "xx_per_block", the
                argument is a list of values to be passed to each block in the
                stage. Otherwise, the same argument is passed to every block
                in the stage.

        Returns:
            list[CNNBlockBase]: a list of block module.

        Examples:
        ::
            stage = ResNet.make_stage(
                BottleneckBlock, 3, in_channels=16, out_channels=64,
                bottleneck_channels=16, num_groups=1,
                stride_per_block=[2, 1, 1],
                dilation_per_block=[1, 1, 2]
            )

        Usually, layers that produce the same feature map spatial size are defined as one
        "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
        all be 1.
        """
        blocks = []
        for i in range(num_blocks):
            curr_kwargs = {}
            for k, v in kwargs.items():
                if k.endswith("_per_block"):
                    assert len(v) == num_blocks, (
                        f"Argument '{k}' of make_stage should have the "
                        f"same length as num_blocks={num_blocks}."
                    )
                    newk = k[: -len("_per_block")]
                    assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
                    curr_kwargs[newk] = v[i]
                else:
                    curr_kwargs[k] = v

            blocks.append(
                block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
            )
            in_channels = out_channels
        return blocks

    @staticmethod
    def make_default_stages(depth, block_class=None, **kwargs):
        """
        Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152).
        If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
        instead for fine-grained customization.

        Args:
            depth (int): depth of ResNet
            block_class (type): the CNN block class. Has to accept
                `bottleneck_channels` argument for depth >= 50.
                By default it is BasicBlock or BottleneckBlock, based on the
                depth.
            kwargs:
                other arguments to pass to `make_stage`. Should not contain
                stride and channels, as they are predefined for each depth.

        Returns:
            list[list[CNNBlockBase]]: modules in all stages; see arguments of
                :class:`ResNet.__init__`.
        """
        num_blocks_per_stage = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[depth]
        if block_class is None:
            block_class = BasicBlock if depth < 50 else BottleneckBlock
        if depth < 50:
            in_channels = [64, 64, 128, 256]
            out_channels = [64, 128, 256, 512]
        else:
            in_channels = [64, 256, 512, 1024]
            out_channels = [256, 512, 1024, 2048]
        ret = []
        for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
            if depth >= 50:
                kwargs["bottleneck_channels"] = o // 4
            ret.append(
                ResNet.make_stage(
                    block_class=block_class,
                    num_blocks=n,
                    stride_per_block=[s] + [1] * (n - 1),
                    in_channels=i,
                    out_channels=o,
                    **kwargs,
                )
            )
        return ret
596
+
597
+
598
# Alias kept for backward compatibility: ``ResNetBlockBase`` is the very same
# class object as ``CNNBlockBase``.
ResNetBlockBase = CNNBlockBase
602
+
603
+
604
def make_stage(*args, **kwargs):
    """
    Deprecated module-level alias of :meth:`ResNet.make_stage`, kept for backward
    compatibility. All arguments are forwarded unchanged; new code should call
    ``ResNet.make_stage`` directly.
    """
    build_stage = ResNet.make_stage
    return build_stage(*args, **kwargs)
609
+
610
+
611
def build_resnet_backbone(cfg=r50_config, in_channels=3, out_features=("res2", "res3", "res4", "res5"),
                          freeze_at=0):
    """
    Create a ResNet instance from config.

    Args:
        cfg: config node; the ``MODEL.RESNETS`` fields read below control the
            architecture. Defaults to the bundled R50 config.
        in_channels (int): number of channels of the network input.
        out_features (Sequence[str] or None): names of the outputs the returned
            model produces (e.g. "res2" ... "res5"). ``None`` means only the last
            stage's output.
        freeze_at (int): number of leading stages to freeze; see :meth:`ResNet.freeze`.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    # Copy into a fresh list: the model stores this object as ``_out_features``,
    # and the former mutable list default was shared by every model built with it.
    if out_features is not None:
        out_features = list(out_features)

    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    stem = BasicStem(
        in_channels=in_channels,
        out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
        norm=norm,
    )

    # fmt: off
    depth               = cfg.MODEL.RESNETS.DEPTH
    num_groups          = cfg.MODEL.RESNETS.NUM_GROUPS
    width_per_group     = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    stage_in_channels   = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    out_channels        = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    stride_in_1x1       = cfg.MODEL.RESNETS.STRIDE_IN_1X1
    res5_dilation       = cfg.MODEL.RESNETS.RES5_DILATION
    deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
    deform_modulated    = cfg.MODEL.RESNETS.DEFORM_MODULATED
    deform_num_groups   = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
    # fmt: on
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]

    if depth in [18, 34]:
        assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
        assert not any(
            deform_on_per_stage
        ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
        assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
        assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"

    stages = []

    for idx, stage_idx in enumerate(range(2, 6)):
        # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
        dilation = res5_dilation if stage_idx == 5 else 1
        first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
        stage_kargs = {
            "num_blocks": num_blocks_per_stage[idx],
            "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1),
            "in_channels": stage_in_channels,
            "out_channels": out_channels,
            "norm": norm,
        }
        # Use BasicBlock for R18 and R34.
        if depth in [18, 34]:
            stage_kargs["block_class"] = BasicBlock
        else:
            stage_kargs["bottleneck_channels"] = bottleneck_channels
            stage_kargs["stride_in_1x1"] = stride_in_1x1
            stage_kargs["dilation"] = dilation
            stage_kargs["num_groups"] = num_groups
            if deform_on_per_stage[idx]:
                stage_kargs["block_class"] = DeformBottleneckBlock
                stage_kargs["deform_modulated"] = deform_modulated
                stage_kargs["deform_num_groups"] = deform_num_groups
            else:
                stage_kargs["block_class"] = BottleneckBlock
        blocks = ResNet.make_stage(**stage_kargs)
        stage_in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2
        stages.append(blocks)
    return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)
@@ -0,0 +1,20 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ from collections import namedtuple
4
+
5
+
6
class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
    """
    Lightweight, immutable description of a tensor's shape.

    Modules exchange it as auxiliary input/output because PyTorch modules do not
    infer shapes for each other. Every field is optional and defaults to ``None``.

    Attributes:
        channels: number of channels.
        height: spatial height.
        width: spatial width.
        stride: feature-map stride.
    """

    def __new__(cls, channels=None, height=None, width=None, stride=None):
        values = (channels, height, width, stride)
        return super().__new__(cls, *values)