python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164)
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,249 @@
1
+ import warnings
2
+ from typing import Callable, List, Optional
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
# Module-level alias for torch.nn.functional.interpolate so callers can
# reference it directly from this module.
interpolate = torch.nn.functional.interpolate
8
+
9
+
10
# This is not in nn
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed

    Args:
        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.eps = eps
        # All four BN parameters are plain buffers: they are never updated by
        # the optimizer and never refreshed from batch statistics.
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        # Checkpoints written by a regular BatchNorm2d carry this extra key;
        # drop it so loading does not report it as unexpected.
        tracked_key = prefix + "num_batches_tracked"
        if tracked_key in state_dict:
            del state_dict[tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x: Tensor) -> Tensor:
        # Reshape the buffers up front so the whole transform is a pair of
        # elementwise ops, which keeps it fuser-friendly.
        shape = (1, -1, 1, 1)
        w = self.weight.reshape(shape)
        b = self.bias.reshape(shape)
        rv = self.running_var.reshape(shape)
        rm = self.running_mean.reshape(shape)
        scale = w * (rv + self.eps).rsqrt()
        shift = b - rm * scale
        return x * scale + shift

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.weight.shape[0]}, eps={self.eps})"
63
+
64
+
65
class ConvNormActivation(torch.nn.Sequential):
    """Base ``conv -> norm -> activation`` stack; use the 2d/3d subclasses."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: int = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        # "Same"-style padding for the given kernel/dilation when unspecified.
        if padding is None:
            padding = (kernel_size - 1) // 2 * dilation
        # A conv bias is redundant when a norm layer directly follows it.
        if bias is None:
            bias = norm_layer is None

        modules: List[torch.nn.Module] = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]
        if norm_layer is not None:
            modules.append(norm_layer(out_channels))
        if activation_layer is not None:
            # inplace=None means: do not pass the kwarg at all (some
            # activations do not accept it).
            kwargs = {} if inplace is None else {"inplace": inplace}
            modules.append(activation_layer(**kwargs))
        super().__init__(*modules)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )
113
+
114
+
115
class Conv2dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution2d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: int = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        # Delegate to the shared base, fixing the conv type to 2d.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv2d,
        )
164
+
165
class Conv3dNormActivation(ConvNormActivation):
    """
    Configurable block used for Convolution3d-Normalization-Activation blocks.

    Args:
        in_channels (int): Number of channels in the input video.
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int): Spacing between kernel elements. Default: 1
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: int = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        # Delegate to the shared base, fixing the conv type to 3d.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            norm_layer,
            activation_layer,
            dilation,
            inplace,
            bias,
            torch.nn.Conv3d,
        )
212
+
213
+
214
class SqueezeExcitation(torch.nn.Module):
    """
    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
    Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.

    Args:
        input_channels (int): Number of channels in the input image
        squeeze_channels (int): Number of squeeze channels
        activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
        scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
    """

    def __init__(
        self,
        input_channels: int,
        squeeze_channels: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        super().__init__()
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        # 1x1 convs act as per-channel fully-connected layers.
        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        # Squeeze: global average pool to (N, C, 1, 1), then the two-layer
        # bottleneck that produces a per-channel gate.
        scale = self.avgpool(input)
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale)

    def forward(self, input: Tensor) -> Tensor:
        # Excite: re-weight each input channel by its learned gate.
        scale = self._scale(input)
        return scale * input
File without changes
@@ -0,0 +1,17 @@
1
# Copyright (c) Facebook, Inc. and its affiliates.
# Public surface of the layers package: re-export the norm helpers, wrapped
# ops, building blocks and ASPP from their submodules.
from .batch_norm import FrozenBatchNorm2d, get_norm, NaiveSyncBatchNorm
from wml.wtorch.nets.shape_spec import ShapeSpec
from .wrappers import (
    BatchNorm2d,
    Conv2d,
    ConvTranspose2d,
    cat,
    interpolate,
    Linear,
    nonzero_tuple,
    cross_entropy,
)
from .blocks import CNNBlockBase, DepthwiseSeparableConv2d
from .aspp import ASPP

# Everything imported above (that is not underscore-private) is the API.
__all__ = [k for k in globals().keys() if not k.startswith("_")]
@@ -0,0 +1,144 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ from copy import deepcopy
4
+ import fvcore.nn.weight_init as weight_init
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from .batch_norm import get_norm
10
+ from .blocks import DepthwiseSeparableConv2d
11
+ from .wrappers import Conv2d
12
+
13
+
14
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP).

    Builds five parallel branches (one 1x1 conv, three 3x3 atrous convs, one
    image-pooling branch), concatenates their outputs and projects back to
    ``out_channels`` with a final 1x1 conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dilations,
        *,
        norm,
        activation,
        pool_kernel_size=None,
        dropout: float = 0.0,
        use_depthwise_separable_conv=False,
    ):
        """
        Args:
            in_channels (int): number of input channels for ASPP.
            out_channels (int): number of output channels.
            dilations (list): a list of 3 dilations in ASPP.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format. norm is
                applied to all conv layers except the conv following
                global average pooling.
            activation (callable): activation function.
            pool_kernel_size (tuple, list): the average pooling size (kh, kw)
                for image pooling layer in ASPP. If set to None, it always
                performs global average pooling. If not None, it must be
                divisible by the shape of inputs in forward(). It is recommended
                to use a fixed input feature size in training, and set this
                option to match this size, so that it performs global average
                pooling in training, and the size of the pooling window stays
                consistent in inference.
            dropout (float): apply dropout on the output of ASPP. It is used in
                the official DeepLab implementation with a rate of 0.1:
                https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532 # noqa
            use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d
                for 3x3 convs in ASPP, proposed in :paper:`DeepLabV3+`.
        """
        super(ASPP, self).__init__()
        assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations))
        self.pool_kernel_size = pool_kernel_size
        self.dropout = dropout
        # An empty norm string means "no norm", so the convs keep their bias.
        use_bias = norm == ""
        self.convs = nn.ModuleList()
        # conv 1x1
        self.convs.append(
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                bias=use_bias,
                norm=get_norm(norm, out_channels),
                activation=deepcopy(activation),
            )
        )
        weight_init.c2_xavier_fill(self.convs[-1])
        # atrous convs: padding == dilation keeps the spatial size unchanged
        for dilation in dilations:
            if use_depthwise_separable_conv:
                self.convs.append(
                    DepthwiseSeparableConv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        norm1=norm,
                        activation1=deepcopy(activation),
                        norm2=norm,
                        activation2=deepcopy(activation),
                    )
                )
            else:
                self.convs.append(
                    Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                        activation=deepcopy(activation),
                    )
                )
                weight_init.c2_xavier_fill(self.convs[-1])
        # image pooling
        # We do not add BatchNorm because the spatial resolution is 1x1,
        # the original TF implementation has BatchNorm.
        if pool_kernel_size is None:
            image_pooling = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
            )
        else:
            image_pooling = nn.Sequential(
                nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1),
                Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
            )
        weight_init.c2_xavier_fill(image_pooling[1])
        self.convs.append(image_pooling)

        # Fuse the 5 concatenated branches back to out_channels.
        self.project = Conv2d(
            5 * out_channels,
            out_channels,
            kernel_size=1,
            bias=use_bias,
            norm=get_norm(norm, out_channels),
            activation=deepcopy(activation),
        )
        weight_init.c2_xavier_fill(self.project)

    def forward(self, x):
        """Run all branches on ``x``, concatenate along channels, and project."""
        size = x.shape[-2:]
        if self.pool_kernel_size is not None:
            if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]:
                raise ValueError(
                    "`pool_kernel_size` must be divisible by the shape of inputs. "
                    "Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size)
                )
        res = []
        for conv in self.convs:
            res.append(conv(x))
        # The pooling branch (last) lost spatial resolution; upsample it back.
        res[-1] = F.interpolate(res[-1], size=size, mode="bilinear", align_corners=False)
        res = torch.cat(res, dim=1)
        res = self.project(res)
        res = F.dropout(res, self.dropout, training=self.training) if self.dropout > 0 else res
        return res
@@ -0,0 +1,231 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ import torch
4
+ import torch.distributed as dist
5
+ from fvcore.nn.distributed import differentiable_all_reduce
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from wml.wtorch.dist import get_world_size
9
+ from wml.wtorch.utils import TORCH_VERSION
10
+
11
+ from .wrappers import BatchNorm2d
12
+
13
+
14
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    # Bumped whenever the buffer semantics change; see _load_from_state_dict.
    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Stored as (var - eps) so that forward(), which adds eps back,
        # is an exact identity with the default buffers.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if not x.requires_grad:
            # No gradients needed: F.batch_norm is a single fused op and
            # provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )
        # Gradients needed: F.batch_norm would also compute gradients for
        # weight/bias and use extra memory, so apply the affine form directly.
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        shift = self.bias - self.running_mean * scale
        out_dtype = x.dtype  # may be half
        scale = scale.reshape(1, -1, 1, 1).to(out_dtype)
        shift = shift.reshape(1, -1, 1, 1).to(out_dtype)
        return x * scale + shift

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Early versions carried no running stats; insert identity stats
            # so the superclass does not warn about missing keys.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        # NOTE: if a checkpoint is trained with BatchNorm and loaded (together with
        # version number) to FrozenBatchNorm, running_var will be wrong. One solution
        # is to remove the version number from the checkpoint.
        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var was used without the +eps.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return f"FrozenBatchNorm2d(num_features={self.num_features}, eps={self.eps})"

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_types = (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
        if isinstance(module, bn_types):
            frozen = cls(module.num_features)
            if module.affine:
                frozen.weight.data = module.weight.data.clone().detach()
                frozen.bias.data = module.bias.data.clone().detach()
            frozen.running_mean.data = module.running_mean.data
            frozen.running_var.data = module.running_var.data
            frozen.eps = module.eps
            return frozen
        # Recurse into children, replacing any that were converted.
        for name, child in module.named_children():
            converted = cls.convert_frozen_batchnorm(child)
            if converted is not child:
                module.add_module(name, converted)
        return module
129
+
130
+
131
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        # Empty string is an explicit "no normalization".
        if not norm:
            return None
        factories = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": NaiveSyncBatchNorm if TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "naiveSyncBN": NaiveSyncBatchNorm,
        }
        norm = factories[norm]
    return norm(out_channels)
157
+
158
+
159
class NaiveSyncBatchNorm(BatchNorm2d):
    """
    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different.
    (e.g., when scale augmentation is used, or when it is applied to mask head).

    This is a slower but correct alternative to `nn.SyncBatchNorm`.

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, this module computes overall statistics by using
        statistics of each worker with equal weight.  The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (N, H, W). This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, this module computes overall statistics by weighting
        the statistics of each worker by their ``N``. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of all samples,
        it may still be reasonable because it might be preferable to assign equal weights
        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
        on larger images. From preliminary experiments, little difference is found between such
        a simplified implementation and an accurate computation of overall mean & variance.
    """

    def __init__(self, *args, stats_mode="", **kwargs):
        super().__init__(*args, **kwargs)
        assert stats_mode in ["", "N"]
        self._stats_mode = stats_mode

    def forward(self, input):
        # Single worker or eval mode: plain BatchNorm2d behavior.
        if get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        # Per-worker first and second moments over (N, H, W).
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            # Equal-weight average of the per-worker statistics.
            vec = torch.cat([mean, meansqr], dim=0)
            vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                # Contribute zeros but keep the graph connected to input.
                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
                vec = vec + input.sum()  # make sure there is gradient w.r.t input
            else:
                # Trailing 1 accumulates, after all-reduce, the total batch
                # size across workers.
                vec = torch.cat(
                    [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
                )
            # Weight each worker's statistics by its local batch size B.
            vec = differentiable_all_reduce(vec * B)

            total_batch = vec[-1].detach()
            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
            total_batch = torch.max(total_batch, torch.ones_like(total_batch))  # avoid div-by-zero
            mean, meansqr, _ = torch.split(vec / total_batch, C)

        # Standard BN affine transform from the synchronized statistics.
        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        # In-place running-stats update (buffers), then normalize.
        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        return input * scale + bias