python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164):
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
wml/wtorch/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .transformer_blocks import TransformerBlock,Attention
2
+ from .conv_module import ConvModule
3
+ from .conv_ws import ConvAWS2d,conv_ws_2d
4
+ from .fc_module import FCModule
5
+ from .summary import *
6
+ from .nn import CHW2HWC,HWC2CHW,LayerNorm,ParallelModule,SumModule,AttentionPool2d
7
+ from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
8
+ from .functional import soft_one_hot
wml/wtorch/bboxes.py ADDED
@@ -0,0 +1,104 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
def cxywh2xy(bboxes):
    """Convert boxes from (cx, cy, w, h) to corner form (xmin, ymin, xmax, ymax).

    Args:
        bboxes: tensor [..., 4] holding center coordinates and sizes.
    Returns:
        Tensor of the same shape in corner form.
    """
    centers = bboxes[..., :2]
    half_sizes = bboxes[..., 2:] / 2
    corners = (centers - half_sizes, centers + half_sizes)
    return torch.cat(corners, dim=-1)
10
+
11
def xy2cxywh(bboxes):
    """Convert boxes from corner form (xmin, ymin, xmax, ymax) to (cx, cy, w, h).

    Args:
        bboxes: tensor [..., 4] in corner form.
    Returns:
        Tensor of the same shape with centers and sizes.
    """
    mins, maxs = bboxes[..., :2], bboxes[..., 2:]
    sizes = maxs - mins
    centers = (maxs + mins) / 2
    return torch.cat((centers, sizes), dim=-1)
15
+
16
def distored_boxes(bboxes:torch.Tensor,scale=[0.8,1.2],offset=0.2):
    """Randomly jitter boxes: scale each width/height by a factor drawn from
    ``[scale[0], scale[1])`` and shift each center by a value in ``[0, offset)``,
    then clip negative coordinates to zero.

    Args:
        bboxes: tensor [..., 4] in corner form (xmin, ymin, xmax, ymax).
        scale: [low, high] multiplicative range for width/height.
        offset: upper bound of the (non-negative) center shift.
    Returns:
        Jittered boxes in corner form, clamped to be >= 0.
    """
    boxes_cxywh = xy2cxywh(bboxes)
    centers, sizes = torch.split(boxes_cxywh, 2, dim=-1)
    # draw size factors on CPU with the boxes' dtype, then move to the right device
    size_factors = torch.rand(list(sizes.shape), dtype=boxes_cxywh.dtype) \
        * (scale[1] - scale[0]) + scale[0]
    sizes = sizes * size_factors.to(sizes.device)
    # NOTE(review): shift is drawn from [0, offset), i.e. only non-negative --
    # confirm a symmetric jitter was not intended.
    shifts = torch.rand(list(centers.shape), dtype=centers.dtype) * offset
    centers = centers + shifts.to(centers.device)
    jittered = cxywh2xy(torch.cat([centers, sizes], dim=-1))
    return torch.nn.functional.relu(jittered)
29
+
30
def flip(bboxes,size):
    """Mirror boxes horizontally within an image of the given size.

    Args:
        bboxes: [N, 4] tensor (xmin, ymin, xmax, ymax).
        size: (H, W) of the image.
    Returns:
        New tensor with x-coordinates reflected about the image width.
    """
    width = size[1]
    flipped = torch.clone(bboxes)
    # xmin/xmax swap roles under reflection
    flipped[..., 0], flipped[..., 2] = width - bboxes[..., 2], width - bboxes[..., 0]
    return flipped
39
+
40
def bboxes_ious(bboxesa, bboxesb):
    """Element-wise IoU between two aligned (broadcastable) sets of boxes.

    Args:
        bboxesa: [N, 4] or [1, 4] tensor (xmin, ymin, xmax, ymax).
        bboxesb: [N, 4] or [1, 4] tensor (xmin, ymin, xmax, ymax).
    Returns:
        [N] tensor of IoU values.
    """
    ax0, ay0, ax1, ay1 = torch.unbind(bboxesa, -1)
    bx0, by0, bx1, by1 = torch.unbind(bboxesb, -1)
    # intersection rectangle, clipped to non-negative extents
    inter_w = F.relu(torch.minimum(ax1, bx1) - torch.maximum(ax0, bx0))
    inter_h = F.relu(torch.minimum(ay1, by1) - torch.maximum(ay0, by0))
    inter = inter_w * inter_h
    union = (ax1 - ax0) * (ay1 - ay0) + (bx1 - bx0) * (by1 - by0) - inter
    # guard against division by zero for degenerate boxes
    union.clamp_(min=1e-8)
    return inter / union
63
+
64
+
65
def bboxes_ious_matrix(bboxes0,bboxes1):
    """Pairwise IoU matrix between two sets of boxes.

    Args:
        bboxes0: [N, 4] tensor (xmin, ymin, xmax, ymax).
        bboxes1: [M, 4] tensor (xmin, ymin, xmax, ymax).
    Returns:
        [N, M] tensor where entry (i, j) is IoU(bboxes0[i], bboxes1[j]).
    """
    a = torch.unsqueeze(bboxes0, dim=1)  # [N, 1, 4]
    b = torch.unsqueeze(bboxes1, dim=0)  # [1, M, 4]

    # broadcasted intersection extents, clamped to >= 0
    inter_w = torch.minimum(a[..., 2], b[..., 2]) - torch.maximum(a[..., 0], b[..., 0])
    inter_h = torch.minimum(a[..., 3], b[..., 3]) - torch.maximum(a[..., 1], b[..., 1])
    inter_w.clamp_(min=0.0)
    inter_h.clamp_(min=0.0)
    inter = inter_w * inter_h

    area_a = torch.prod(a[..., 2:] - a[..., :2], dim=-1)
    area_b = torch.prod(b[..., 2:] - b[..., :2], dim=-1)
    union = area_a + area_b - inter
    # avoid division by zero for degenerate boxes
    union.clamp_(min=1e-8)
    return inter / union


# backward-compatible alias
iou_matrix = bboxes_ious_matrix
88
+
89
def correct_bbox(bboxes,w,h):
    """Clip boxes in place to the image rectangle [0, w] x [0, h].

    Args:
        bboxes: [N, 4] tensor (x0, y0, x1, y1); modified in place.
        w: image width (max x).
        h: image height (max y).
    Returns:
        The same tensor, clipped.
    """
    # columns 0 and 2 are x-coordinates, 1 and 3 are y-coordinates
    bboxes[:, 0:4:2].clamp_(min=0, max=w)
    bboxes[:, 1:4:2].clamp_(min=0, max=h)
    return bboxes
97
+
98
def area(bboxes):
    """Return the area of each box.

    Args:
        bboxes: [N, 4] tensor (xmin, ymin, xmax, ymax).
    Returns:
        [N] tensor of widths * heights.
    """
    widths = bboxes[:, 2] - bboxes[:, 0]
    heights = bboxes[:, 3] - bboxes[:, 1]
    return widths * heights
103
+
104
+
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+
3
+ '''
4
+ 根据scores得分,从输入中仅取一个类别
5
+ '''
6
def classes_suppression(bboxes,labels,scores,test_nr=1):
    """Keep only the detections belonging to the single best-scoring class.

    Each class is ranked by the mean of its top ``test_nr`` scores; detections
    of every other class are discarded.

    Args:
        bboxes: [N, 4] array of boxes.
        labels: [N] array of class labels.
        scores: [N] array of confidence scores.
        test_nr: number of top scores averaged per class.
    Returns:
        (bboxes, labels, scores) filtered to the winning class; inputs are
        returned unchanged when ``labels`` is empty.
    """
    if len(labels) == 0:
        return bboxes, labels, scores
    best_score = -1
    best_label = -1
    for lab in set(labels.tolist()):
        mask = labels == lab
        top_idx = np.argsort(-scores[mask])[:test_nr]
        avg = np.mean(scores[mask][top_idx])
        if avg > best_score:
            best_score = avg
            best_label = lab

    keep = labels == best_label
    return bboxes[keep], labels[keep], scores[keep]
23
+
24
+
@@ -0,0 +1,181 @@
1
+ import warnings
2
+ import copy
3
+ import torch.nn as nn
4
+ import wml.wtorch.nn as wnn
5
+ from torch.nn.modules.batchnorm import _BatchNorm
6
+ from torch.nn.modules.instancenorm import _InstanceNorm
7
+ from .nn import get_conv_type
8
+
9
+
10
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    It is based upon three build methods: `build_conv_layer()`,
    `build_norm_layer()` and `build_activation_layer()`.

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
       supports zero and circular padding, and we add "reflect" padding mode.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
            False. Default: "auto".
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether use spectral norm in conv module.
            Default: False.
        padding_mode (str): If the `padding_mode` has not been supported by
            current `Conv2d` in PyTorch, we will use our own padding layer
            instead. Currently, we support ['zeros', 'circular'] with official
            implementation and ['reflect'] with our own implementation.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
    """

    _abbr_ = 'conv_block'

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')):
        super(ConvModule, self).__init__()
        # deep-copy so the pop('type') below does not mutate the caller's dict
        norm_cfg = copy.deepcopy(norm_cfg)
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert act_cfg is None or isinstance(act_cfg, dict)
        official_padding_mode = ['zeros', 'circular']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_cfg is not None
        self.with_activation = act_cfg is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias

        # NOTE(review): despite the docstring, explicit padding modes such as
        # 'reflect' are not implemented here -- any unofficial mode raises.
        if self.with_explicit_padding:
            pad_cfg = dict(type=padding_mode)
            raise RuntimeError("Unsupport with_explicit_padding")

        # reset padding to 0 for conv module
        conv_padding = 0 if self.with_explicit_padding else padding
        # build convolution layer (conv_cfg selects the conv class; None -> Conv2d)
        self.conv = get_conv_type(conv_cfg)(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            # pop('type') leaves only the norm-specific kwargs in norm_cfg
            norm_type = norm_cfg.pop('type')
            norm = wnn.get_norm(norm_type, norm_channels, norm_args=norm_cfg)
            # the norm submodule is registered under its type name so that
            # the `norm` property below can fetch it via getattr
            self.norm_name = norm_type
            self.add_module(self.norm_name, norm)
            if self.with_bias:
                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn(
                        'Unnecessary conv bias before batch/instance norm')
        else:
            self.norm_name = None

        # build activation layer
        if self.with_activation:
            act_cfg_ = act_cfg.copy()
            # nn.Tanh has no 'inplace' argument
            if act_cfg_['type'] not in [
                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
            ]:
                act_cfg_.setdefault('inplace', inplace)
            act_type = act_cfg_.pop('type')
            # NOTE(review): for the types excluded above, 'inplace' was never
            # set, so this get() falls back to True -- confirm wnn.get_activation
            # ignores the flag for those activations.
            inplace = act_cfg_.get('inplace', True)
            self.activate = wnn.get_activation(act_type, inplace)

    @property
    def norm(self):
        # Returns the registered norm submodule, or None when no norm was built.
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def forward(self, x, activate=True, norm=True):
        # Apply conv/norm/act in the configured order; `activate`/`norm` let
        # callers skip those stages for a single call.
        for layer in self.order:
            if layer == 'conv':
                # NOTE(review): with_explicit_padding can never be True here
                # (the constructor raises), so self.padding_layer is dead code.
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activation:
                x = self.activate(x)
        return x
wml/wtorch/conv_ws.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
def conv_ws_2d(input,
               weight,
               bias=None,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               eps=1e-5):
    """2-D convolution with Weight Standardization.

    The kernel is standardized per output channel (zero mean, unit std over
    the remaining dimensions) before being passed to ``F.conv2d``.

    Args:
        input: input feature map.
        weight: conv kernel [out_channels, in_channels/groups, kH, kW].
        bias, stride, padding, dilation, groups: as in ``F.conv2d``.
        eps: small constant added to the std for numerical stability.
    """
    out_ch = weight.size(0)
    flat = weight.view(out_ch, -1)
    mu = flat.mean(dim=1, keepdim=True).view(out_ch, 1, 1, 1)
    sigma = flat.std(dim=1, keepdim=True).view(out_ch, 1, 1, 1)
    standardized = (weight - mu) / (sigma + eps)
    return F.conv2d(input, standardized, bias, stride, padding, dilation, groups)
21
+
22
+
23
class ConvWS2d(nn.Conv2d):
    """``nn.Conv2d`` variant that applies Weight Standardization on every
    forward pass (see https://arxiv.org/pdf/1903.10520.pdf).

    Accepts the same constructor arguments as ``nn.Conv2d`` plus ``eps``,
    the stabilizing constant added to the per-channel weight std.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5):
        super(ConvWS2d, self).__init__(in_channels,
                                       out_channels,
                                       kernel_size,
                                       stride=stride,
                                       padding=padding,
                                       dilation=dilation,
                                       groups=groups,
                                       bias=bias)
        # kept as an attribute so forward() can pass it through unchanged
        self.eps = eps

    def forward(self, x):
        # delegate to the functional form with the module's own parameters
        return conv_ws_2d(x, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups, self.eps)
49
+
50
+
51
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # non-learnable per-channel affine parameters applied after
        # standardization; buffers so they are saved/loaded with state_dict
        self.register_buffer('weight_gamma',
                             torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        # Standardize the kernel per output channel, then rescale/shift it
        # with the weight_gamma / weight_beta buffers.
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = (weight - mean) / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        # Convolve with the adaptively standardized weight.
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """

        # Sentinel: fill gamma with -1 before loading; if the checkpoint
        # contains weight_gamma, the load overwrites it and mean() > 0 below.
        self.weight_gamma.data.fill_(-1)
        local_missing_keys = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            # gamma was loaded from the checkpoint -- forward any genuinely
            # missing keys and stop here.
            for k in local_missing_keys:
                missing_keys.append(k)
            return
        # Checkpoint predates AWS: derive beta/gamma from the pretrained
        # weights so _get_weight reproduces them initially.
        weight = self.weight.data
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)
        # Drop the now-recovered gamma/beta keys from the missing list before
        # propagating it to the caller.
        missing_gamma_beta = [
            k for k in local_missing_keys
            if k.endswith('weight_gamma') or k.endswith('weight_beta')
        ]
        for k in missing_gamma_beta:
            local_missing_keys.remove(k)
        for k in local_missing_keys:
            missing_keys.append(k)
@@ -0,0 +1,16 @@
1
+ from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler,InfiniteSequentialSampler
2
+ from .dataset import (Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset, BufferedShuffleDataset,
3
+ Subset, random_split)
4
+ from .base_data_loader_iter import _BaseDataLoaderIter, _DatasetKind
5
+ from .single_process_data_loader_iter import _SingleProcessDataLoaderIter
6
+ from .multi_processing_data_loader_iter import _MultiProcessingDataLoaderIter
7
+ from .dataset import IterableDataset as IterDataPipe
8
+ from .distributed import DistributedSampler
9
+ from .dataloader import DataLoader, _DatasetKind, get_worker_info
10
+
11
+ __all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
12
+ 'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler',
13
+ 'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
14
+ 'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset',
15
+ 'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info',
16
+ 'IterDataPipe']
@@ -0,0 +1,45 @@
1
+ r"""Utility classes & functions for data loading. Code in this folder is mostly
2
+ used by ../dataloder.py.
3
+
4
+ A lot of multiprocessing is used in data loading, which only supports running
5
+ functions defined in global environment (py2 can't serialize static methods).
6
+ Therefore, for code tidiness we put these functions into different files in this
7
+ folder.
8
+ """
9
+
10
+ import sys
11
+ import atexit
12
+
13
+ # old private location of the ExceptionWrapper that some users rely on:
14
+ from torch._utils import ExceptionWrapper
15
+
16
+
17
+ IS_WINDOWS = sys.platform == "win32"
18
+
19
+
20
+ MP_STATUS_CHECK_INTERVAL = 5.0
21
+ r"""Interval (in seconds) to check status of processes to avoid hanging in
22
+ multiprocessing data loading. This is mainly used in getting data from
23
+ another process, in which case we need to periodically check whether the
24
+ sender is alive to prevent hanging."""
25
+
26
+
27
+ python_exit_status = False
28
+ r"""Whether Python is shutting down. This flag is guaranteed to be set before
29
+ the Python core library resources are freed, but Python may already be exiting
30
+ for some time when this is set.
31
+
32
+ Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
33
+ hook in Python 3.7 multiprocessing library:
34
+ https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
35
+ """
36
+
37
+
38
def _set_python_exit_flag():
    # atexit hook: record that the interpreter is shutting down so data-loader
    # worker cleanup can avoid touching core resources that are being freed.
    global python_exit_status
    python_exit_status = True

atexit.register(_set_python_exit_flag)
43
+
44
+
45
+ from . import worker, signal_handling, pin_memory, collate, fetch