python_wml-3.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164)
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,494 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5
+ # ------------------------------------------------------------------------------
6
+ import os.path as osp
7
+ from .config import _C
8
+ import os
9
+ import logging
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ BN_MOMENTUM = 0.01
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def conv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 convolution with padding=1 and no bias.

    At stride 1 the spatial size is preserved; bias is omitted because a
    BatchNorm layer always follows in this network.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
23
+
24
+
25
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv+BN layers with an additive skip.

    The skip is the input itself, or `downsample(input)` when the caller
    supplies a projection (needed when stride or channel count changes).
    """

    expansion = 1  # output channels == `planes` (no expansion)

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Submodule names (conv1/bn1/...) define the state_dict key layout;
        # keep them stable so pretrained checkpoints still load.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)
55
+
56
+
57
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand (x4).

    The 3x3 conv carries the stride; the skip is identity or the supplied
    `downsample` projection when shape changes.
    """

    expansion = 4  # output channels == planes * 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Keep submodule names stable: they define checkpoint state_dict keys.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)
96
+
97
+
98
class HighResolutionModule(nn.Module):
    """One HRNet stage module: parallel per-resolution branches followed by
    a cross-resolution fusion step.

    NOTE(review): `_make_one_branch` mutates `self.num_inchannels` in place,
    so the construction order in __init__ (branches before fuse layers)
    matters — do not reorder.
    """

    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        # Fail fast on inconsistent per-branch config lists.
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method  # stored but not consulted; fusion is always additive
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        # Building branches updates self.num_inchannels to each branch's
        # output width; fuse layers are sized from those updated values.
        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        # Every per-branch list must have exactly num_branches entries.
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        """Build one branch as a Sequential of `block` units.

        Side effect: updates self.num_inchannels[branch_index] to the
        branch's output channel count.
        """
        downsample = None
        # A projection is needed when the first block changes stride or width.
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(
                    num_channels[branch_index] * block.expansion,
                    momentum=BN_MOMENTUM
                ),
            )

        layers = []
        layers.append(
            block(
                self.num_inchannels[branch_index],
                num_channels[branch_index],
                stride,
                downsample
            )
        )
        # In-place update: subsequent blocks (and the fuse layers) see the
        # expanded channel count.
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index]
                )
            )

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build all parallel branches (one per resolution)."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels)
            )

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build fuse_layers[i][j]: maps branch j's output to branch i's
        resolution/width.

        j > i: 1x1 conv + BN + nearest upsample by 2**(j-i).
        j == i: None (identity, handled in forward).
        j < i: chain of stride-2 3x3 conv+BN downsamples; ReLU on all but
               the last step of the chain.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        # With multi_scale_output=False only the highest-resolution output
        # (i == 0) is fused and returned.
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1, 1, 0, bias=False
                            ),
                            nn.BatchNorm2d(num_inchannels[i],momentum=BN_MOMENTUM),
                            nn.Upsample(scale_factor=2**(j-i), mode='nearest')
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            # Last downsample step: switch to target width,
                            # no ReLU (ReLU is applied after the fuse sum).
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3,momentum=BN_MOMENTUM)
                                )
                            )
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3, 2, 1, bias=False
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3,momentum=BN_MOMENTUM),
                                    nn.ReLU(True)
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        # Output channel counts per branch (after expansion).
        return self.num_inchannels

    def forward(self, x):
        """x: list of tensors, one per branch. Returns a list of fused
        outputs (length 1 when multi_scale_output is False)."""
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # Run each branch on its own resolution (mutates the input list).
        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []

        # For each output scale i, sum contributions from every branch j.
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]  # same scale: identity contribution
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse
263
+
264
+
265
# Maps the config's BLOCK string (e.g. STAGE2.BLOCK = 'BASIC') to the
# residual-block class used to build that stage's branches.
blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck
}
269
+
270
+
271
class PoseHighResolutionNet(nn.Module):
    """HRNet backbone for pose estimation.

    Builds a stride-4 stem, a Bottleneck layer1, then stages 2-4 of
    HighResolutionModules driven by cfg['EXTRA']. forward() returns only the
    highest-resolution feature map (stage 4 is built with
    multi_scale_output=False); its channel count is exposed as
    ``self.last_channel``.
    """

    def __init__(self, cfg, **kwargs):
        # **kwargs is accepted for caller compatibility but unused.
        self.inplanes = 64
        self.cfg = cfg
        extra = cfg['EXTRA']
        super(PoseHighResolutionNet, self).__init__()

        # Stem: two stride-2 3x3 convs -> overall stride 4, 64 channels.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        # 4 Bottleneck blocks: 64 -> 256 channels (expansion 4).
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        # Per-branch output widths after block expansion.
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        # layer1 always outputs 256 channels (Bottleneck, planes=64).
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        # multi_scale_output=False: stage 4 fuses everything into the
        # highest-resolution branch only.
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=False)
        self.last_channel = pre_stage_channels[0]
        self.pretrained_layers = extra['PRETRAINED_LAYERS']

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        """Build per-branch transition modules between stages.

        Existing branches get a 3x3 conv only if their width changes (else
        None = identity). Each new branch is created from the previous
        stage's lowest-resolution output via stride-2 3x3 conv chains.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3, 1, 1, bias=False
                            ),
                            nn.BatchNorm2d(num_channels_cur_layer[i],
                                           momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True)
                        )
                    )
                else:
                    # Same width: identity (handled as None in forward).
                    transition_layers.append(None)
            else:
                # New, lower-resolution branch: i+1-num_branches_pre
                # stride-2 steps from the last pre-stage branch.
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    # Only the final step switches to the target width.
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, outchannels, 3, 2, 1, bias=False
                            ),
                            nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True)
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; first block may project the skip.

        Side effect: updates self.inplanes to the expanded output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes, planes * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        """Build NUM_MODULES chained HighResolutionModules for one stage.

        Returns (stage Sequential, output channel counts per branch).
        """
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output only applies to the stage's last module.
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    fuse_method,
                    reset_multi_scale_output
                )
            )
            # Module construction may expand channels; propagate forward.
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        """x: (N, 3, H, W) image batch. Returns the highest-resolution
        stage-4 feature map, (N, last_channel, H/4, W/4)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                # Follows upstream HRNet: a transition (new branch) is fed
                # from the lowest-resolution previous output, y_list[-1].
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        # No head here: the final prediction layer lives in the caller.
        return y_list[0]

    def init_weights(self, pretrained=''):
        """Randomly initialize weights, then optionally overlay a checkpoint.

        Args:
            pretrained: path to a checkpoint file. Only entries whose top-level
                module name is listed in self.pretrained_layers (or all, when
                that list starts with '*') are loaded, non-strictly.

        Raises:
            ValueError: if `pretrained` is non-empty but not a file.
        """
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        if os.path.isfile(pretrained):
            # map_location='cpu' so GPU-saved checkpoints load on CPU-only
            # hosts; parameters follow the model on a later .to(device).
            pretrained_state_dict = torch.load(pretrained, map_location='cpu')
            print('=> loading pretrained model {}'.format(pretrained))

            need_init_state_dict = {}
            for name, m in pretrained_state_dict.items():
                # BUG FIX: the original compared with `is '*'` (identity on a
                # string literal — fragile and a SyntaxWarning on modern
                # CPython). Use equality.
                if name.split('.')[0] in self.pretrained_layers \
                        or self.pretrained_layers[0] == '*':
                    need_init_state_dict[name] = m
            self.load_state_dict(need_init_state_dict, strict=False)
        elif pretrained:
            print('=> please download pre-trained models first!')
            raise ValueError('{} is not exist!'.format(pretrained))
482
+
483
+
484
def get_net(type="w32", is_train=True, init_weights=True, pretrained=''):
    """Build a PoseHighResolutionNet from the bundled ``<type>.yaml`` config.

    Args:
        type: config variant name; resolved to a yaml file next to this
            module. (Name shadows the builtin, kept for caller compatibility.)
        is_train: build for training.
        init_weights: when training, run init_weights(pretrained).
        pretrained: optional checkpoint path forwarded to init_weights.
    """
    config_path = osp.join(osp.dirname(__file__), type + ".yaml")
    # NOTE(review): _C is the shared module-level config node; merging
    # mutates it globally. Preserved as-is.
    cfg = _C
    cfg.merge_from_file(config_path)

    model = PoseHighResolutionNet(cfg)
    if is_train and init_weights:
        model.init_weights(pretrained)
    return model