python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
wml/wtorch/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from .transformer_blocks import TransformerBlock,Attention
|
|
2
|
+
from .conv_module import ConvModule
|
|
3
|
+
from .conv_ws import ConvAWS2d,conv_ws_2d
|
|
4
|
+
from .fc_module import FCModule
|
|
5
|
+
from .summary import *
|
|
6
|
+
from .nn import CHW2HWC,HWC2CHW,LayerNorm,ParallelModule,SumModule,AttentionPool2d
|
|
7
|
+
from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
|
|
8
|
+
from .functional import soft_one_hot
|
wml/wtorch/bboxes.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
def cxywh2xy(bboxes):
    """Convert boxes from (cx, cy, w, h) to (xmin, ymin, xmax, ymax).

    The last dimension is expected to hold the 4 box values; any leading
    dimensions are kept as-is.
    """
    center = bboxes[..., :2]
    half_size = bboxes[..., 2:] / 2
    top_left = center - half_size
    bottom_right = center + half_size
    return torch.cat([top_left, bottom_right], dim=-1)
|
|
10
|
+
|
|
11
|
+
def xy2cxywh(bboxes):
    """Convert boxes from (xmin, ymin, xmax, ymax) to (cx, cy, w, h).

    Inverse of ``cxywh2xy``; works over any leading dimensions.
    """
    top_left = bboxes[..., :2]
    bottom_right = bboxes[..., 2:]
    center = (bottom_right + top_left) / 2
    size = bottom_right - top_left
    return torch.cat([center, size], dim=-1)
|
|
15
|
+
|
|
16
|
+
def distored_boxes(bboxes: torch.Tensor, scale=(0.8, 1.2), offset=0.2):
    """Randomly jitter boxes given as (xmin, ymin, xmax, ymax).

    Each box's width and height are multiplied by independent factors drawn
    uniformly from ``[scale[0], scale[1])``. Its center is shifted by amounts
    drawn uniformly from ``[0, offset)``, in the same units as the box
    coordinates. Negative output coordinates are clamped to 0 with ReLU.

    Args:
        bboxes: floating-point tensor whose last dimension holds
            (xmin, ymin, xmax, ymax).
        scale: (low, high) range of the width/height scaling factor.
        offset: upper bound of the center shift.

    Returns:
        A new tensor with the same shape, dtype and device as ``bboxes``.
    """
    low, high = scale[0], scale[1]
    cxy, wh = torch.split(xy2cxywh(bboxes), 2, dim=-1)
    # Sample directly on the boxes' device instead of sampling on the CPU
    # and copying the random tensors over.
    wh_scales = torch.rand(wh.shape, dtype=wh.dtype, device=wh.device)
    wh_scales = wh_scales * (high - low) + low
    # NOTE(review): the shift is one-sided (always towards +x / +y); confirm
    # whether a symmetric [-offset, offset) range was intended.
    cxy_offset = torch.rand(cxy.shape, dtype=cxy.dtype, device=cxy.device) * offset
    jittered = torch.cat([cxy + cxy_offset, wh * wh_scales], dim=-1)
    return F.relu(cxywh2xy(jittered))
|
|
29
|
+
|
|
30
|
+
def flip(bboxes, size):
    '''
    Mirror boxes horizontally inside an image.

    bboxes:[N,4][xmin,ymin,xmax,ymax]
    size:[H,W]

    Only the x coordinates change: new_xmin = W - xmax and
    new_xmax = W - xmin. A new tensor is returned; the input is untouched.
    '''
    width = size[1]
    flipped = bboxes.clone()
    flipped[..., 0] = width - bboxes[..., 2]
    flipped[..., 2] = width - bboxes[..., 0]
    return flipped
|
|
39
|
+
|
|
40
|
+
def bboxes_ious(bboxesa, bboxesb):
    '''
    Element-wise IoU (intersection over union) of two sets of boxes.

    bboxesa: [N,4] or [1,4] (xmin,ymin,xmax,ymax)
    bboxesb: [N,4] or [1,4] (xmin,ymin,xmax,ymax)
    A [1,4] argument is broadcast against the other one.
    return:
        [N], IoU of the i-th box of each input. The union is clamped to at
        least 1e-8 to avoid division by zero.
    '''
    ax0, ay0, ax1, ay1 = torch.unbind(bboxesa, -1)[:4]
    bx0, by0, bx1, by1 = torch.unbind(bboxesb, -1)[:4]

    inter_h = F.relu(torch.minimum(ay1, by1) - torch.maximum(ay0, by0))
    inter_w = F.relu(torch.minimum(ax1, bx1) - torch.maximum(ax0, bx0))
    inter_vol = inter_h * inter_w

    area_a = (ax1 - ax0) * (ay1 - ay0)
    area_b = (bx1 - bx0) * (by1 - by0)
    union_vol = (area_a - inter_vol + area_b).clamp(min=1e-8)
    return inter_vol / union_vol
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def bboxes_ious_matrix(bboxes0, bboxes1):
    '''
    Pairwise IoU between every box of bboxes0 and every box of bboxes1.

    bboxes0: expected [N,4] (xmin,ymin,xmax,ymax)
    bboxes1: expected [M,4] (xmin,ymin,xmax,ymax)
    return: [N,M], entry (i,j) is the IoU of bboxes0[i] and bboxes1[j].
    The union is clamped to at least 1e-8 to avoid division by zero.
    '''
    # Insert broadcast axes: lhs -> [N,1,4], rhs -> [1,M,4].
    lhs = bboxes0[:, None]
    rhs = bboxes1[None]

    x_lo = torch.maximum(lhs[..., 0], rhs[..., 0])
    x_hi = torch.minimum(lhs[..., 2], rhs[..., 2])
    y_lo = torch.maximum(lhs[..., 1], rhs[..., 1])
    y_hi = torch.minimum(lhs[..., 3], rhs[..., 3])

    inter_vol = (x_hi - x_lo).clamp(min=0.0) * (y_hi - y_lo).clamp(min=0.0)
    lhs_areas = (lhs[..., 2:] - lhs[..., :2]).prod(dim=-1)
    rhs_areas = (rhs[..., 2:] - rhs[..., :2]).prod(dim=-1)
    union_vol = (lhs_areas + rhs_areas - inter_vol).clamp(min=1e-8)
    return inter_vol / union_vol


# Short alias kept for backward compatibility.
iou_matrix = bboxes_ious_matrix
|
|
88
|
+
|
|
89
|
+
def correct_bbox(bboxes, w, h):
    '''
    Clip boxes to the image area [0, w] x [0, h].

    bboxes: [N,4](x0,y0,x1,y1)

    The clipped values are written back into ``bboxes`` (the input tensor
    is modified) and that same tensor is returned.
    '''
    # Columns 0 and 2 are x coordinates; columns 1 and 3 are y coordinates.
    for first_col, limit in ((0, w), (1, h)):
        cols = slice(first_col, 4, 2)
        bboxes[:, cols] = torch.clamp(bboxes[:, cols], min=0, max=limit)
    return bboxes
|
|
97
|
+
|
|
98
|
+
def area(bboxes):
    '''
    Area of axis-aligned boxes.

    bboxes: [..., 4] (xmin, ymin, xmax, ymax). Any number of leading
        dimensions is accepted, including the original [N,4] layout.
    return: tensor of shape [...] holding (xmax - xmin) * (ymax - ymin).
    '''
    # Index the last axis (like cxywh2xy/flip) instead of hard-coding [:, k],
    # which silently picked the wrong axis for batched [B,N,4] input.
    widths = bboxes[..., 2] - bboxes[..., 0]
    heights = bboxes[..., 3] - bboxes[..., 1]
    return widths * heights
|
|
103
|
+
|
|
104
|
+
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
def classes_suppression(bboxes, labels, scores, test_nr=1):
    '''
    Keep only the detections of a single class, chosen by score.

    For every distinct label, the mean of its ``test_nr`` highest scores is
    computed. The label with the largest mean wins and only its detections
    are returned. (Original note: "based on the scores, keep only one class
    from the input".)

    Args:
        bboxes: array indexable by a boolean mask aligned with ``labels``.
        labels: 1-D numpy array of class ids.
        scores: numpy array of scores aligned with ``labels``.
        test_nr: number of top scores per class averaged for the ranking.

    Returns:
        (bboxes, labels, scores) filtered to the winning class. When
        ``labels`` is empty, the inputs are returned unchanged.
    '''
    if len(labels) == 0:
        return bboxes, labels, scores

    # Start from -inf rather than -1 so classes whose scores are all <= -1
    # (e.g. raw logits) can still be selected.
    best_score = -np.inf
    best_label = -1
    for label in set(labels.tolist()):
        class_scores = scores[labels == label]
        top_idx = np.argsort(-class_scores)[:test_nr]
        mean_top = np.mean(class_scores[top_idx])
        if mean_top > best_score:
            best_score = mean_top
            best_label = label

    keep = labels == best_label
    return bboxes[keep], labels[keep], scores[keep]
|
|
23
|
+
|
|
24
|
+
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import copy
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import wml.wtorch.nn as wnn
|
|
5
|
+
from torch.nn.modules.batchnorm import _BatchNorm
|
|
6
|
+
from torch.nn.modules.instancenorm import _InstanceNorm
|
|
7
|
+
from .nn import get_conv_type
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    The sub-layers come from these helpers:

    - the convolution class from ``get_conv_type(conv_cfg)``;
    - the norm layer from ``wnn.get_norm``;
    - the activation from ``wnn.get_activation``.

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
            False. Default: "auto".
        conv_cfg (dict): Config dict for convolution layer, passed to
            ``get_conv_type``. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Its 'type'
            entry selects the layer. The remaining entries are forwarded to
            ``wnn.get_norm`` as ``norm_args``. Default: None (no norm layer).
        act_cfg (dict): Config dict for activation layer. Only its 'type'
            and 'inplace' entries are used; None disables the activation.
            Default: dict(type='ReLU').
        inplace (bool): Value used for the activation's 'inplace' flag when
            act_cfg does not set it (ignored for the types listed in
            ``__init__``). Default: True.
        with_spectral_norm (bool): Whether use spectral norm in conv module.
            Default: False.
        padding_mode (str): Only 'zeros' and 'circular' are supported. Any
            other value (including 'reflect') raises ``RuntimeError``.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
    """

    _abbr_ = 'conv_block'

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')):
        super(ConvModule, self).__init__()
        # Deep-copy so the caller's dict is never modified.
        norm_cfg = copy.deepcopy(norm_cfg)
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert act_cfg is None or isinstance(act_cfg, dict)
        official_padding_mode = ['zeros', 'circular']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_cfg is not None
        self.with_activation = act_cfg is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            # No explicit padding layer is implemented here, so modes such as
            # 'reflect' are rejected instead of being silently ignored.
            raise RuntimeError("Unsupport with_explicit_padding")

        # 'circular' only takes effect if it reaches the conv layer, so pass
        # padding_mode on for non-default modes. The default 'zeros' path
        # passes no extra argument, so conv types that do not accept
        # padding_mode keep working.
        conv_kwargs = {}
        if padding_mode != 'zeros':
            conv_kwargs['padding_mode'] = padding_mode

        # build convolution layer
        self.conv = get_conv_type(conv_cfg)(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **conv_kwargs)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            # Split 'type' off a copy so that self.norm_cfg still contains
            # the complete config (including 'type') after construction.
            norm_args = dict(norm_cfg)
            norm_type = norm_args.pop('type')
            norm = wnn.get_norm(norm_type, norm_channels, norm_args=norm_args)
            # The norm module is registered under its type name.
            self.norm_name = norm_type
            self.add_module(self.norm_name, norm)
            if self.with_bias:
                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn(
                        'Unnecessary conv bias before batch/instance norm')
        else:
            self.norm_name = None

        # build activation layer
        if self.with_activation:
            act_cfg_ = act_cfg.copy()
            # nn.Tanh has no 'inplace' argument
            if act_cfg_['type'] not in [
                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
            ]:
                act_cfg_.setdefault('inplace', inplace)
            act_type = act_cfg_.pop('type')
            # NOTE(review): for the types listed above this falls back to True
            # unless act_cfg sets 'inplace'; confirm get_activation ignores it.
            inplace = act_cfg_.get('inplace', True)
            self.activate = wnn.get_activation(act_type, inplace)

    @property
    def norm(self):
        """The registered norm module, or None when no norm layer was built."""
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def forward(self, x, activate=True, norm=True):
        """Apply conv/norm/act in ``self.order``.

        ``activate=False`` or ``norm=False`` skips that layer for this call.
        """
        for layer in self.order:
            if layer == 'conv':
                # Unreachable when built through __init__ (explicit padding
                # raises there); kept for subclasses that provide padding_layer.
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activation:
                x = self.activate(x)
        return x
|
wml/wtorch/conv_ws.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
# Copyright (c) OpenMMLab. All rights reserved.
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def conv_ws_2d(input,
               weight,
               bias=None,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               eps=1e-5):
    """2-D convolution with Weight Standardization.

    Each output filter of ``weight`` (indexed by dim 0) is shifted to zero
    mean and divided by its unbiased standard deviation plus ``eps``. The
    result is then used with ``F.conv2d`` and the remaining arguments.

    ``weight`` is expected to be a 4-D kernel, because the per-filter
    statistics are reshaped to (num_filters, 1, 1, 1).
    """
    num_filters = weight.size(0)
    stat_shape = (num_filters, 1, 1, 1)
    per_filter = weight.view(num_filters, -1)
    filter_mean = per_filter.mean(dim=1, keepdim=True).view(stat_shape)
    filter_std = per_filter.std(dim=1, keepdim=True).view(stat_shape)
    standardized = (weight - filter_mean) / (filter_std + eps)
    return F.conv2d(input, standardized, bias, stride, padding, dilation, groups)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ConvWS2d(nn.Conv2d):
    """``nn.Conv2d`` whose weight is standardized on every forward pass.

    Takes the same constructor arguments as ``nn.Conv2d`` plus ``eps``. The
    ``eps`` value is stored on the module and passed to ``conv_ws_2d``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5):
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=bias)
        self.eps = eps

    def forward(self, x):
        return conv_ws_2d(x, self.weight, self.bias,
                          stride=self.stride,
                          padding=self.padding,
                          dilation=self.dilation,
                          groups=self.groups,
                          eps=self.eps)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=bias)
        # Per-filter affine terms applied after standardization. They are
        # registered as buffers: saved in the state dict, not parameters().
        for buffer_name, factory in (('weight_gamma', torch.ones),
                                     ('weight_beta', torch.zeros)):
            self.register_buffer(buffer_name,
                                 factory(self.out_channels, 1, 1, 1))

    @staticmethod
    def _filter_stats(weight):
        """Return per-filter mean and sqrt(var + 1e-5), each (-1, 1, 1, 1)."""
        flat = weight.view(weight.size(0), -1)
        mean = flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        return mean, std

    def _get_weight(self, weight):
        """Standardize ``weight`` per filter, then apply gamma/beta."""
        mean, std = self._filter_stats(weight)
        return self.weight_gamma * ((weight - mean) / std) + self.weight_beta

    def forward(self, x):
        return F.conv2d(x, self._get_weight(self.weight), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """
        # Sentinel: loading weight_gamma from the checkpoint overwrites the -1s.
        self.weight_gamma.data.fill_(-1)
        local_missing_keys = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            missing_keys.extend(local_missing_keys)
            return
        # No gamma/beta in the checkpoint: derive them from the loaded weight
        # so that _get_weight reproduces exactly that weight.
        mean, std = self._filter_stats(self.weight.data)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)
        missing_keys.extend(
            key for key in local_missing_keys
            if not key.endswith(('weight_gamma', 'weight_beta')))
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler,InfiniteSequentialSampler
|
|
2
|
+
from .dataset import (Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset, BufferedShuffleDataset,
|
|
3
|
+
Subset, random_split)
|
|
4
|
+
from .base_data_loader_iter import _BaseDataLoaderIter, _DatasetKind
|
|
5
|
+
from .single_process_data_loader_iter import _SingleProcessDataLoaderIter
|
|
6
|
+
from .multi_processing_data_loader_iter import _MultiProcessingDataLoaderIter
|
|
7
|
+
from .dataset import IterableDataset as IterDataPipe
|
|
8
|
+
from .distributed import DistributedSampler
|
|
9
|
+
from .dataloader import DataLoader, _DatasetKind, get_worker_info
|
|
10
|
+
|
|
11
|
+
# Public API of this package. InfiniteSequentialSampler is imported above
# alongside the other samplers, so it is exported here as well.
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
           'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler',
           'InfiniteSequentialSampler',
           'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
           'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset',
           'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info',
           'IterDataPipe']
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
r"""Utility classes & functions for data loading. Code in this folder is mostly
|
|
2
|
+
used by ../dataloader.py.
|
|
3
|
+
|
|
4
|
+
A lot of multiprocessing is used in data loading, which only supports running
|
|
5
|
+
functions defined in global environment (py2 can't serialize static methods).
|
|
6
|
+
Therefore, for code tidiness we put these functions into different files in this
|
|
7
|
+
folder.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import sys
|
|
11
|
+
import atexit
|
|
12
|
+
|
|
13
|
+
# old private location of the ExceptionWrapper that some users rely on:
|
|
14
|
+
from torch._utils import ExceptionWrapper
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# True when the interpreter reports sys.platform == "win32".
IS_WINDOWS = sys.platform == "win32"


# Polling period (seconds) used while waiting on worker processes.
MP_STATUS_CHECK_INTERVAL = 5.0
r"""Interval (in seconds) to check status of processes to avoid hanging in
    multiprocessing data loading. This is mainly used in getting data from
    another process, in which case we need to periodically check whether the
    sender is alive to prevent hanging."""


# Set to True by _set_python_exit_flag, which is registered with atexit.
python_exit_status = False
r"""Whether Python is shutting down. This flag is guaranteed to be set before
the Python core library resources are freed, but Python may already be exiting
for some time when this is set.

Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
hook in Python 3.7 multiprocessing library:
https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
"""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _set_python_exit_flag():
|
|
39
|
+
global python_exit_status
|
|
40
|
+
python_exit_status = True
|
|
41
|
+
|
|
42
|
+
atexit.register(_set_python_exit_flag)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
from . import worker, signal_handling, pin_memory, collate, fetch
|