python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
wml/wtorch/nn.py
ADDED
|
@@ -0,0 +1,896 @@
|
|
|
1
|
+
import math
from collections import OrderedDict

try:
    # ``Iterable`` lives in ``collections.abc``; the alias in ``collections``
    # was removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # fallback for very old Python versions
    from collections import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from .conv_ws import ConvWS2d
#from einops import rearrange
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _clone_tensors(x):
|
|
14
|
+
if isinstance(x,(list,tuple)):
|
|
15
|
+
return [v.clone() for v in x]
|
|
16
|
+
return x.clone()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: ``channels_first`` (the
    default) or ``channels_last``.

    ``channels_last`` expects the channel dimension to be the last one, e.g.
    (batch_size, height, width, channels). ``channels_first`` expects the
    channel dimension at dim 1, e.g. (batch_size, channels, height, width);
    in that mode any number of trailing dimensions after the channel dim is
    accepted (including none, i.e. (batch_size, channels)).

    Args:
        normalized_shape (int): number of channels C.
        eps (float): value added to the variance for numerical stability.
        data_format (str): ``"channels_first"`` or ``"channels_last"``.

    Raises:
        NotImplementedError: for any other ``data_format``.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data_format: {data_format}")
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalize over dim 1 using the biased (population)
        # variance, like F.layer_norm does.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        # Give the per-channel affine parameters shape (C, 1, ..., 1) so they
        # broadcast over every trailing dim, whatever the input rank is.
        trailing = (1,) * (x.dim() - 2)
        weight = self.weight.view(-1, *trailing)
        bias = self.bias.view(-1, *trailing)
        return weight * x + bias
|
|
47
|
+
|
|
48
|
+
class Identity(nn.Module):
    """Pass-through module that records activations and gradients for debugging.

    ``forward`` stores its input in ``self.cache`` and returns a clone of it.
    A backward hook stores clones of the hook's ``grad_input`` and
    ``grad_output`` in ``self.grad_input`` / ``self.grad_output``.
    ``repr(module)`` is the ``name`` given at construction time.
    """

    def __init__(self, name="Identity"):
        super().__init__()
        self.name = name
        self.cache = None        # input of the most recent forward call
        self.grad_input = None   # cloned grad_input from the most recent backward
        self.grad_output = None  # cloned grad_output from the most recent backward
        self.register_backward_hook(self.backward_hook)

    def backward_hook(self, model, grad_input, grad_output):
        # Clone so the stored gradients do not share storage with autograd's.
        self.grad_input, self.grad_output = (
            _clone_tensors(grad_input),
            _clone_tensors(grad_output),
        )

    def forward(self, x):
        self.cache = x
        return x.clone()

    def __repr__(self):
        return self.name
|
|
68
|
+
|
|
69
|
+
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm on channels for 2d images.

    Args:
        num_features (int): The number of channels of the input tensor.
        **kwargs: forwarded to :class:`torch.nn.LayerNorm`, e.g. ``eps``
            (a value added to the denominator for numerical stability,
            defaults to 1e-5) and ``elementwise_affine`` (learnable
            per-element affine parameters initialized to ones/zeros,
            defaults to True).
    """

    def __init__(self, num_features: int, **kwargs) -> None:
        super().__init__(num_features, **kwargs)
        self.num_channels = self.normalized_shape[0]


    def forward(self, x, data_format='channel_first'):
        """Forward method.

        Args:
            x (torch.Tensor): The input tensor.
            data_format (str): The format of the input tensor. If
                ``"channel_first"``, the shape of the input tensor should be
                (B, C, H, W). If ``"channel_last"``, the shape of the input
                tensor should be (B, H, W, C). Defaults to "channel_first".

        Returns:
            torch.Tensor: normalized tensor in the same layout as ``x``.

        Raises:
            ValueError: if ``data_format`` is neither ``"channel_first"`` nor
                ``"channel_last"`` (previously such input was returned
                unnormalized without any error).
        """
        assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \
            f'(N, C, H, W), but got tensor with shape {x.shape}'
        if data_format == 'channel_first':
            x = x.permute(0, 2, 3, 1)
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
                             self.eps)
            # If the output is discontiguous, it may cause some unexpected
            # problem in the downstream tasks
            x = x.permute(0, 3, 1, 2).contiguous()
        elif data_format == 'channel_last':
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
                             self.eps)
        else:
            raise ValueError(
                "data_format must be 'channel_first' or 'channel_last', "
                f"got {data_format!r}")
        return x
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class EvoNormS0(nn.Module):
    """EvoNorm-S0 style normalization + activation for (N, C, H, W) inputs.

    The C channels are split into ``num_groups`` groups. The output is
    ``x * sigmoid(x * v1) * gain + beta``, where
    ``gain = rsqrt(std + eps)`` (times ``gamma`` when ``scale`` is True) and
    ``std`` is the unbiased standard deviation over each group's channels and
    spatial positions.

    Args:
        num_groups (int): number of channel groups; must divide ``num_features``.
        num_features (int): number of input channels C.
        eps (float): added to ``std`` before the reciprocal square root.
        scale (bool): if True, a learnable ``gamma`` multiplies the gain.

    Raises:
        ValueError: if ``num_features`` is not divisible by ``num_groups``.
    """
    def __init__(self, num_groups,num_features, eps=1e-6, scale=True):
        super().__init__()
        if num_features % num_groups != 0:
            raise ValueError(
                f"num_features ({num_features}) must be divisible by "
                f"num_groups ({num_groups})")
        self.num_groups = num_groups
        self.num_features = num_features
        param_shape = [1, num_groups, num_features // num_groups, 1, 1]
        if scale:
            self.gamma = nn.Parameter(torch.ones(param_shape))
        # beta and v1 are used by forward unconditionally.
        self.beta = nn.Parameter(torch.zeros(param_shape))
        self.v1 = nn.Parameter(torch.ones(param_shape))
        self.eps = eps
        self.scale = scale

    def forward(self, x):
        N,C,H,W = x.shape
        G = self.num_groups
        # reshape (rather than view) also accepts layouts view cannot handle.
        x = x.reshape(N, G, C // G, H, W)
        # NOTE(review): this is the standard deviation fed to rsqrt, i.e.
        # gain = 1 / sqrt(std + eps); the EvoNorm formulation divides by
        # sqrt(var + eps). Kept unchanged so existing checkpoints behave the same.
        std = x.std(dim=(2, 3, 4), keepdim=True)
        gain = torch.rsqrt(std + self.eps)
        if self.scale:
            gain = gain*self.gamma

        x = x*torch.sigmoid(x*self.v1)*gain+self.beta

        return x.reshape(N, C, H, W).contiguous()

    def __repr__(self):
        return f"EvoNormS0 (num_features={self.num_features}, num_groups={self.num_groups}, eps={self.eps})"
|
|
138
|
+
|
|
139
|
+
class GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` that also accepts 2-D (N, C) input.

    torch.nn.GroupNorm only supports tensors with ndim >= 3 when exported to
    ONNX, so a 2-D input gets a trailing singleton dimension for the
    normalization, which is removed again afterwards.
    """

    def forward(self, x):
        if x.ndim != 2:
            return super().forward(x)
        normalized = super().forward(x.unsqueeze(-1))
        return normalized.squeeze(-1)
|
|
148
|
+
|
|
149
|
+
class EvoNormS01D(nn.Module):
    """EvoNorm-S0 style normalization + activation for 2-D (N, C) inputs.

    The C features are split into ``num_groups`` groups. The output is
    ``x * sigmoid(x * v1) * gain + beta``, where
    ``gain = rsqrt(std + eps)`` (times ``gamma`` when ``scale`` is True) and
    ``std`` is the unbiased standard deviation within each group.

    Args:
        num_groups (int): number of feature groups; must divide ``num_features``.
        num_features (int): number of input features C.
        eps (float): added to ``std`` before the reciprocal square root.
        scale (bool): if True, a learnable ``gamma`` multiplies the gain.

    Raises:
        ValueError: if ``num_features`` is not divisible by ``num_groups``.
    """
    def __init__(self, num_groups,num_features, eps=1e-6, scale=True):
        super().__init__()
        if num_features % num_groups != 0:
            raise ValueError(
                f"num_features ({num_features}) must be divisible by "
                f"num_groups ({num_groups})")
        self.num_groups = num_groups
        self.num_features = num_features
        param_shape = [1, num_groups, num_features // num_groups]
        if scale:
            self.gamma = nn.Parameter(torch.ones(param_shape))
        # beta and v1 are used by forward unconditionally.
        self.beta = nn.Parameter(torch.zeros(param_shape))
        self.v1 = nn.Parameter(torch.ones(param_shape))
        self.eps = eps
        self.scale = scale

    def forward(self, x):
        N,C = x.shape
        G = self.num_groups
        # reshape (rather than view) also accepts layouts view cannot handle.
        x = x.reshape(N, G, C // G)
        # NOTE(review): standard deviation (not variance) is fed to rsqrt;
        # kept unchanged so existing checkpoints behave the same.
        std = x.std(dim=2, keepdim=True)
        gain = torch.rsqrt(std + self.eps)
        if self.scale:
            gain = gain*self.gamma

        x = x*torch.sigmoid(x*self.v1)*gain+self.beta

        return x.reshape(N, C).contiguous()

    def __repr__(self):
        return f"EvoNormS01D (num_features={self.num_features}, num_groups={self.num_groups}, eps={self.eps})"
|
|
176
|
+
|
|
177
|
+
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gating for (N, C, H, W) feature maps.

    Each channel is average-pooled over H and W. The resulting (N, C)
    descriptor goes through ``Linear(C, C // r) -> ReLU -> Linear(C // r, C)
    -> sigmoid``, and the input is multiplied by these per-channel gates.
    """
    def __init__(self,channels,r=16):
        super().__init__()
        self.channels = channels
        self.r = r
        hidden = self.channels // r
        self.fc0 = nn.Linear(self.channels, hidden)
        self.fc1 = nn.Linear(hidden, self.channels)

    def forward(self,net):
        pooled = net.mean(dim=(2, 3), keepdim=False)
        hidden = F.relu(self.fc0(pooled), inplace=True)
        gate = torch.sigmoid(self.fc1(hidden))
        # (N, C) -> (N, C, 1, 1) so the gate broadcasts over H and W.
        return gate[:, :, None, None] * net
|
|
195
|
+
|
|
196
|
+
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.

    Separate learned tables for column (x) and row (y) positions. The forward
    output concatenates the column embedding and the row embedding for each
    spatial location along the channel axis.
    """
    def __init__(self, num_pos_feats=256,max_rows=50,max_cols=50):
        super().__init__()
        self.row_embed = nn.Embedding(max_rows, num_pos_feats)
        self.col_embed = nn.Embedding(max_cols, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        for table in (self.row_embed, self.col_embed):
            nn.init.uniform_(table.weight)

    def forward(self, x):
        '''
        Args:
            x: [...,C,H,W]; only its last two sizes and ``x.shape[0]`` are used.
               H must not exceed ``max_rows`` and W must not exceed ``max_cols``.
        Returns:
            Tensor of shape [x.shape[0], 2*num_pos_feats, H, W]; channels
            [0, num_pos_feats) hold the column embedding and the remaining
            channels hold the row embedding.
        '''
        height, width = x.shape[-2:]
        cols = self.col_embed(torch.arange(width, device=x.device))   # (W, F)
        rows = self.row_embed(torch.arange(height, device=x.device))  # (H, F)
        col_grid = cols.unsqueeze(0).expand(height, -1, -1)  # (H, W, F)
        row_grid = rows.unsqueeze(1).expand(-1, width, -1)   # (H, W, F)
        grid = torch.cat((col_grid, row_grid), dim=-1)        # (H, W, 2F)
        grid = grid.permute(2, 0, 1).unsqueeze(0)             # (1, 2F, H, W)
        return grid.repeat(x.shape[0], 1, 1, 1)
|
|
229
|
+
|
|
230
|
+
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It holds "weight", "bias", "running_mean" and "running_var", initialized
    to perform an identity transformation. "running_mean" / "running_var" are
    buffers that ``forward`` never updates. In this implementation "weight"
    and "bias" are ``nn.Parameter`` objects (the buffer registration in
    ``__init__`` is commented out), so they are returned by ``parameters()``;
    stopping their gradients is left to the caller.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)` when the
    input does not require grad, and by the equivalent explicit affine
    transform otherwise.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        #self.register_buffer("weight", torch.ones(num_features))
        #self.register_buffer("bias", torch.zeros(num_features))
        self.weight = Parameter(torch.ones(num_features))
        self.bias = Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # ones - eps so that running_var + eps == 1 (identity transform).
        self.register_buffer("running_var", torch.ones(num_features) - eps)


    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight.float() * (self.running_var.float() + self.eps).rsqrt()
            bias = self.bias.float() - self.running_mean.float() * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silent the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        # NOTE: if a checkpoint is trained with BatchNorm and loaded (together with
        # version number) to FrozenBatchNorm, running_var will be wrong. One solution
        # is to remove the version number from the checkpoint.
        if version is not None and version < 3:
            print("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var are used without +eps.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        The new module's weight/bias are copies of the original affine
        parameters; running_mean/running_var share data with the original
        buffers. A BatchNorm built with ``track_running_stats=False`` has no
        running buffers (they are ``None``); the converted module then keeps
        the identity statistics created by ``__init__``.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            if module.running_mean is not None:
                res.running_mean.data = module.running_mean.data
            if module.running_var is not None:
                res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
|
|
347
|
+
|
|
348
|
+
class WSConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization.

    On every forward pass each output filter of ``self.weight`` is centred to
    zero mean and divided by its (unbiased) standard deviation plus 1e-5
    before the convolution. The stored parameter itself is not modified.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias)

    def forward(self, x):
        w = self.weight
        n_out = w.size(0)
        centred = w - w.view(n_out, -1).mean(dim=1).view(-1, 1, 1, 1)
        denom = centred.view(n_out, -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        return F.conv2d(x, centred / denom, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    @classmethod
    def convert_wsconv(cls, module,exclude=None,parent=""):
        """Recursively replace ``nn.Conv2d`` submodules with ``cls`` instances.

        Args:
            module: root module. If it is itself an ``nn.Conv2d``, a new ``cls``
                with the same hyper-parameters is returned.
            exclude: optional container of child names or dotted paths
                (relative to the root) to leave untouched; each skipped child
                is printed.
            parent: dotted path of ``module``; used by the recursion.

        Returns:
            The converted module (``module`` itself when it is not a conv).
            NOTE: weights/bias are NOT copied from the original convolution;
            the replacement keeps its own fresh initialization.
        """
        if isinstance(module, nn.Conv2d):
            return cls(module.in_channels, module.out_channels, module.kernel_size,
                       module.stride, module.padding, module.dilation,
                       module.groups, module.bias is not None)
        for child_name, child in module.named_children():
            path = f"{parent}.{child_name}" if parent else child_name
            if exclude is not None and (child_name in exclude or path in exclude):
                print(f"Skip: {path}")
                continue
            replacement = cls.convert_wsconv(child, exclude=exclude, parent=path)
            if replacement is not child:
                module.add_module(child_name, replacement)
        return module
|
|
385
|
+
|
|
386
|
+
class BCNorm(nn.Module):
    """Batch-Channel normalization: ``nn.BatchNorm2d`` followed by ``nn.GroupNorm``.

    ``eps`` and ``affine`` are shared by both layers; ``momentum`` and
    ``track_running_stats`` configure the BatchNorm and ``num_groups`` the
    GroupNorm.
    """
    def __init__(self,num_features, num_groups=32,eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__()
        bn_kwargs = dict(eps=eps, momentum=momentum, affine=affine,
                         track_running_stats=track_running_stats)
        self.bn = nn.BatchNorm2d(num_features=num_features, **bn_kwargs)
        self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=num_features,
                               eps=eps, affine=affine)

    def forward(self,x):
        return self.gn(self.bn(x))
|
|
400
|
+
|
|
401
|
+
def get_norm(norm, out_channels,norm_args=None):
    """Build a 2-D normalization layer.

    Args:
        norm (str, dict, callable or None): one of "BN", "SyncBN", "FrozenBN",
            "SyncBatchNorm", "GN", "LayerNorm2d", "LayerNorm" (maps to
            LayerNorm2d), "EvoNormS0", "InstanceNorm"; or a dict
            ``{"type": <name>, **kwargs}`` whose remaining keys replace
            ``norm_args``; or a callable, invoked as
            ``norm(num_features=out_channels, **norm_args)``.
            ``None`` or ``""`` means no normalization.
        out_channels (int): number of channels to normalize.
        norm_args (dict, optional): extra keyword arguments for the layer.
            For "GN" / "EvoNormS0" an empty value defaults to
            ``{"num_groups": 32}``.

    Returns:
        nn.Module or None: the normalization layer

    Raises:
        KeyError: if ``norm`` is a string not listed above.
    """
    # None instead of a shared mutable ``{}`` default.
    if norm_args is None:
        norm_args = {}
    if isinstance(norm,dict):
        cfg = dict(norm)  # copy: do not mutate the caller's config
        norm = cfg.pop('type')
        norm_args = cfg
    if norm is None:
        return None
    if norm in ["GN","EvoNormS0"] and len(norm_args)==0:
        norm_args = {"num_groups":32}

    if norm == 'GN':
        # Local GroupNorm also handles 2-D input (ONNX-export friendly).
        return GroupNorm(num_channels=out_channels,**norm_args)

    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": torch.nn.BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            # for debugging:
            "SyncBatchNorm": nn.SyncBatchNorm,
            "LayerNorm2d": LayerNorm2d,
            "LayerNorm":LayerNorm2d,
            "EvoNormS0": EvoNormS0,
            "InstanceNorm":nn.InstanceNorm2d,
        }[norm]
    return norm(num_features=out_channels,**norm_args)
|
|
441
|
+
|
|
442
|
+
def get_norm1d(norm, out_channels,norm_args=None):
    """Build a normalization layer for (N, C) features.

    Args:
        norm (str, dict, callable or None): one of "BN", "GN", "LayerNorm",
            "EvoNormS0"; or a dict ``{"type": <name>, **kwargs}`` whose
            remaining keys replace ``norm_args`` (same convention as
            ``get_norm``); or a callable, invoked as
            ``norm(num_features=out_channels, **norm_args)``.
            ``None`` or ``""`` means no normalization.
        out_channels (int): number of features to normalize.
        norm_args (dict, optional): extra keyword arguments for the layer.
            For "GN" / "EvoNormS0" an empty value defaults to
            ``{"num_groups": 32}``.

    Returns:
        nn.Module or None: the normalization layer

    Raises:
        KeyError: if ``norm`` is a string not listed above.
    """
    # None instead of a shared mutable ``{}`` default.
    if norm_args is None:
        norm_args = {}
    if isinstance(norm, dict):
        cfg = dict(norm)  # copy: do not mutate the caller's config
        norm = cfg.pop('type')
        norm_args = cfg
    if norm is None:
        return None
    if norm in ["GN","EvoNormS0"] and len(norm_args)==0:
        norm_args = {"num_groups":32}

    if norm == 'GN':
        # Local GroupNorm accepts 2-D (N, C) input.
        return GroupNorm(num_channels=out_channels,**norm_args)

    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        if norm == "LayerNorm":
            # nn.LayerNorm's first parameter is ``normalized_shape``; it has no
            # ``num_features`` keyword, so it cannot use the generic call below.
            return nn.LayerNorm(out_channels, **norm_args)
        norm = {
            "BN": torch.nn.BatchNorm1d,
            "EvoNormS0": EvoNormS01D,
        }[norm]
    return norm(num_features=out_channels,**norm_args)
|
|
470
|
+
|
|
471
|
+
def get_conv_type(conv_cfg):
    """Return the convolution class selected by ``conv_cfg``.

    Args:
        conv_cfg (dict or None): ``None`` selects ``nn.Conv2d``. Otherwise
            ``conv_cfg['type']`` must be ``"Conv"`` / ``"Conv2d"``
            (``nn.Conv2d``) or ``"ConvWS"`` (``ConvWS2d``).

    Returns:
        type: the convolution class.

    Raises:
        ValueError: for any other ``type``, instead of silently returning None.
    """
    if conv_cfg is None:
        return nn.Conv2d
    conv_type = conv_cfg['type']
    if conv_type == "ConvWS":
        return ConvWS2d
    if conv_type in ("Conv", "Conv2d"):
        return nn.Conv2d
    raise ValueError(f"Unsupported conv type: {conv_type}")
|
|
476
|
+
|
|
477
|
+
class SiLU(nn.Module):
    """Export-friendly SiLU / Swish: ``x * sigmoid(x)`` built from primitive ops
    instead of ``nn.SiLU``."""

    @staticmethod
    def forward(x):
        gate = torch.sigmoid(x)
        return x * gate
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def get_activation(name="SiLU", inplace=True):
    """Create an activation module by name.

    Args:
        name (str or dict): activation name ("SiLU"/"Swish", "ReLU",
            "LeakyReLU" (negative slope 0.1), "Hardswish", "GELU", "HSigmoid"),
            or a dict with a ``'type'`` key and an optional ``'inplace'`` key
            (default True). Any other dict key fails an assertion.
        inplace (bool): forwarded to activations that take it (not GELU).
            When ``name`` is a dict, its ``'inplace'`` entry is used instead.

    Returns:
        nn.Module: the activation layer.

    Raises:
        AttributeError: for an unsupported activation name.
    """
    if isinstance(name, dict):
        cfg = dict(name)
        name = cfg.pop('type')
        inplace = cfg.pop('inplace', True)
        assert len(cfg) == 0, f"ERROR: activation cfg {cfg}"

    factories = (
        (("SiLU", "Swish"), lambda: nn.SiLU(inplace=inplace)),
        (("ReLU",), lambda: nn.ReLU(inplace=inplace)),
        (("LeakyReLU",), lambda: nn.LeakyReLU(0.1, inplace=inplace)),
        (("Hardswish",), lambda: nn.Hardswish(inplace=inplace)),
        (("GELU",), lambda: nn.GELU()),
        (("HSigmoid",), lambda: nn.Hardsigmoid(inplace=inplace)),
    )
    for aliases, build in factories:
        if name in aliases:
            return build()
    raise AttributeError("Unsupported act type: {}".format(name))
|
|
506
|
+
|
|
507
|
+
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.

    After the convolution it optionally applies ``norm`` and then
    ``activation``. Non-zero ``padding_mode`` values ("reflect", "replicate",
    "circular") are honoured the same way :class:`torch.nn.Conv2d` does.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"

        if self.padding_mode != "zeros":
            # F.conv2d can only zero-pad: apply the configured padding mode
            # explicitly first (as torch.nn.Conv2d does), then convolve
            # without additional padding.
            x = F.conv2d(
                F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                self.weight, self.bias, self.stride, (0, 0), self.dilation, self.groups
            )
        else:
            x = F.conv2d(
                x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, in_channels: int = 2048, num_heads: int=8, out_channels: int = 1024):
        '''
        Attention pooling of an (N, C, H, W) feature map to (N, out_channels).

        Args:
            spacial_dim: int or iterable. An int means h = w = spacial_dim; an
                iterable is read as (h, w). It fixes the length of the learned
                positional embedding (h*w + 1 rows), so inputs must have
                H*W == h*w.
            in_channels: input channels C.
            num_heads: number of attention heads.
            out_channels: output channels; a falsy value falls back to in_channels.
        '''
        super().__init__()
        if isinstance(spacial_dim, Iterable):
            num_positions = spacial_dim[0] * spacial_dim[1]
        else:
            num_positions = spacial_dim ** 2
        # One extra position for the pooled/query token prepended in forward.
        self.positional_embedding = nn.Parameter(
            torch.randn(num_positions + 1, in_channels) / in_channels ** 0.5)
        self.k_proj = nn.Linear(in_channels, in_channels)
        self.q_proj = nn.Linear(in_channels, in_channels)
        self.v_proj = nn.Linear(in_channels, in_channels)
        self.c_proj = nn.Linear(in_channels, out_channels or in_channels)
        self.num_heads = num_heads

    def forward(self, x,query=None):
        '''
        Args:
            x: [N, C, H, W] feature map.
            query: optional [B,C] query token; defaults to the spatial mean of x.
        return:
            [B,C] (C here is the output channel count of c_proj)
        '''
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        if query is None:
            lead = tokens.mean(dim=0, keepdim=True)
        else:
            lead = query.unsqueeze(0)
        tokens = torch.cat([lead, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)
        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return pooled.squeeze(0)
|
|
605
|
+
|
|
606
|
+
class NormalizedLinear(nn.Module):
    """Bias-free linear layer on L2-normalised inputs and weights.

    Each output is the cosine similarity between one input row and one
    weight row, so the values lie in [-1, 1].

    Args:
        in_channels: size of the last input dimension.
        out_channels: number of output units (rows of ``weight``).
        eps: stored as ``self.eps``. It is not used by forward,
            normalize_weight or loss in this implementation.
    """

    def __init__(self, in_channels, out_channels, eps=1e-5):
        super().__init__()
        self.weight = Parameter(torch.FloatTensor(out_channels, in_channels))
        nn.init.xavier_uniform_(self.weight)
        self.eps = eps

    @torch.cuda.amp.autocast(False)
    def forward(self, x):
        """Return the cosine similarity of every row of ``x`` with every weight row.

        Runs in float32 with CUDA autocast disabled.
        """
        unit_x = F.normalize(x.float(), dim=-1)
        unit_w = F.normalize(self.weight, dim=1)
        return F.linear(unit_x, unit_w)

    def normalize_weight(self):
        """Return ``weight`` with each row scaled to unit L2 norm (not in place)."""
        return F.normalize(self.weight, dim=1)

    def loss(self):
        """Penalty on positive cosine similarity between distinct weight rows.

        The mean is taken over all out_channels x out_channels entries,
        with the diagonal zeroed.
        """
        n_rows = self.weight.shape[0]
        off_diagonal = 1 - torch.eye(n_rows)  # 0 on the diagonal, 1 elsewhere
        unit_w = F.normalize(self.weight, dim=1)
        similarity = (unit_w @ unit_w.T).clamp(min=0)
        return torch.mean(similarity * off_diagonal.to(similarity))
|
|
638
|
+
|
|
639
|
+
class ArcMarginProductIF(nn.Module):
    '''
    ArcFace margin, following the insightface implementation.

    Adds the angular margin ``m`` to the target-class angle, then rescales by
    ``s``. Works IN PLACE on the ``cosine`` tensor passed to forward and
    returns that same tensor.
    '''

    def __init__(self, s=64.0, m=0.5):
        super().__init__()
        self.s = s
        self.m = m

    def forward(self, cosine: torch.Tensor, label):
        '''
        cosine: [M, C] cosine logits in [-1, 1]; modified in place.
        label: [M] integer class ids in [0, C); rows with label -1 get no margin.
        return: ``cosine`` itself, now holding s * cos(theta + m * one_hot).
        '''
        labelled_rows = torch.where(label != -1)[0]
        # One-hot margin matrix: m at each labelled row's target column, 0 elsewhere.
        margin = torch.zeros(labelled_rows.size()[0], cosine.size()[1], device=cosine.device)
        margin.scatter_(1, label[labelled_rows, None], self.m)
        # cos -> angle, shift the target angles, angle -> cos, scale; all in place.
        cosine.acos_()
        cosine[labelled_rows] += margin
        cosine.cos_().mul_(self.s)
        return cosine
|
|
660
|
+
|
|
661
|
+
class ArcMarginProduct(nn.Module):
    r"""Additive angular margin (ArcFace) applied to cosine logits.

    For the target class the logit becomes ``s * cos(theta + m)``. Every
    other class gets ``s * cos(theta)``.

    Args:
        s: scale applied to every output logit.
        m: additive angular margin in radians.
        easy_margin: if True, apply the margin only where cos(theta) > 0.
            Otherwise, once theta + m would pass pi, use the linear fallback
            ``cos(theta) - sin(m) * m``.
    """

    def __init__(self, s=30.0, m=0.50, easy_margin=False):
        super().__init__()
        self.s = s
        self.m = m

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # theta + m < pi exactly when cos(theta) > cos(pi - m).
        self.th = math.cos(math.pi - m)
        # Penalty for the fallback branch used beyond that threshold.
        self.mm = math.sin(math.pi - m) * m

    @torch.cuda.amp.autocast(False)
    def forward(self, *, cosine, label):
        """Return margin-adjusted, scaled logits.

        Args:
            cosine: [N, C] cosine similarities (cast to float32).
            label: target class index per row; reshaped to (N, 1).
        """
        # b > 1 keeps sqrt(b - cos^2) away from 0, so the gradient stays
        # finite when cosine == 1 or -1.
        b = 1.0001
        max_v = 1.00004
        cos_t = cosine.float().clamp(-max_v, max_v)
        sin_t = torch.sqrt((b - torch.pow(cos_t, 2)).clamp(0, 1)).to(cos_t.dtype)
        cos_t_plus_m = cos_t * self.cos_m - sin_t * self.sin_m
        if self.easy_margin:
            use_margin, fallback = cos_t > 0, cos_t
        else:
            use_margin, fallback = cos_t > self.th, cos_t - self.mm
        phi = torch.where(use_margin, cos_t_plus_m, fallback)

        target_mask = torch.zeros(cos_t.size(), device=cos_t.device)
        target_mask.scatter_(1, label.view(-1, 1).long(), 1)
        # Blend: phi in the target column, the plain cosine everywhere else.
        logits = (target_mask * phi) + ((1.0 - target_mask) * cos_t)
        logits *= self.s

        return logits
|
|
702
|
+
|
|
703
|
+
class ArcMarginProduct_(nn.Module):
    r"""Variant of the additive angular margin (ArcFace) with probe points.

    Target-class logit: ``s * cos(theta + m)``; other logits: ``s * cos(theta)``.
    Intermediate tensors pass through ``Identity`` modules ``id0``..``id5``
    (defined elsewhere in this file), presumably so they can be hooked or
    inspected. Unlike ``ArcMarginProduct``, the input cosine is not clamped.

    Args:
        s: scale applied to the output logits.
        m: additive angular margin in radians.
        easy_margin: if True, apply the margin only where cos(theta) > 0.
    """
    def __init__(self, s=30.0, m=0.50, easy_margin=False):
        # Zero-argument super(): super(ArcMarginProduct, self) raised
        # TypeError because this class does not derive from ArcMarginProduct.
        super().__init__()
        self.s = s
        self.m = m

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # theta + m < pi exactly when cos(theta) > cos(pi - m).
        self.th = math.cos(math.pi - m)
        # Penalty for the fallback branch used beyond that threshold.
        self.mm = math.sin(math.pi - m) * m
        self.id0 = Identity("id0")
        self.id1 = Identity("id1")
        self.id2 = Identity("id2")
        self.id3 = Identity("id3")
        self.id4 = Identity("id4")
        self.id5 = Identity("id5")

    @torch.cuda.amp.autocast(False)
    def forward(self, *,cosine, label):
        """Return margin-adjusted, scaled logits.

        Args:
            cosine: [N, C] cosine similarities (cast to float32).
            label: target class index per row; reshaped to (N, 1).
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        b = 1.00001  # b > 1 keeps the gradient finite when cosine == 1 or -1
        cosine = cosine.float()
        cosine = self.id0(cosine)
        sine = torch.sqrt((b - torch.pow(cosine, 2)).clamp(0, 1)).to(cosine.dtype)
        sine = self.id1(sine)
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        phi = self.id2(phi)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        phi = self.id3(phi)
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # phi in the target column, plain cosine everywhere else.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = self.id4(output)
        output *= self.s
        output = self.id5(output)

        return output
|
|
754
|
+
|
|
755
|
+
# softmax_one: softmax with an extra 1 in the denominator,
#     softmax_one(x)_i = exp(x_i) / (1 + sum_j exp(x_j)),
# so the outputs can all go towards 0 (they need not sum to 1) when every
# input is very negative, which reduces the impact of tiny values.
def softmax_one(x, dim=None, _stacklevel=3, dtype=None):
    """Softmax with an additional ``+1`` term in the denominator.

    Args:
        x: input tensor.
        dim: dimension to normalise over. ``None`` follows the legacy
            implicit rule of ``torch.nn.functional.softmax``: dim 0 for
            0-, 1- and 3-D inputs, dim 1 otherwise.
        _stacklevel: accepted for signature compatibility with
            ``F.softmax``; unused.
        dtype: if given, ``x`` is cast to this dtype before computing.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if dtype is not None:
        x = x.to(dtype)
    if dim is None:
        dim = 0 if x.dim() in (0, 1, 3) else 1
    # Shift by s = max(max(x), 0) for numerical stability. Unlike plain
    # softmax, the "+1" term does not cancel under a shift, so it has to be
    # shifted as well:
    #   exp(x_i) / (1 + sum exp(x_j)) == exp(x_i - s) / (exp(-s) + sum exp(x_j - s))
    # Clamping s at 0 keeps exp(-s) <= 1 and exp(x - s) <= 1, so nothing overflows.
    shift = x.max(dim=dim, keepdim=True).values.clamp(min=0)
    exp_x = torch.exp(x - shift)
    return exp_x / (torch.exp(-shift) + exp_x.sum(dim=dim, keepdim=True))
|
|
764
|
+
|
|
765
|
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples of ``x`` (stochastic depth, per sample).

    Meant for the main branch of a residual block. Each sample (each index
    along dim 0) is zeroed with probability ``drop_prob``. With
    ``scale_by_keep``, surviving samples are divided by ``1 - drop_prob``.

    EfficientNet-style code sometimes calls this "DropConnect". The name
    "drop path" is used instead because DropConnect refers to a different
    form of dropout; see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 .

    Args:
        x: input tensor of any rank >= 1.
        drop_prob: probability of dropping a sample.
        training: the op is a no-op unless this is True.
        scale_by_keep: rescale survivors by ``1 / (1 - drop_prob)``.

    Returns:
        ``x`` itself when inactive, otherwise a new tensor.
    """
    if not (training and drop_prob != 0.):
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask
|
|
783
|
+
|
|
784
|
+
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth per sample).

    The drop is active only while the module is in training mode
    (``self.training``).

    Args:
        drop_prob: probability of dropping a sample.
        scale_by_keep: rescale surviving samples by ``1 / (1 - drop_prob)``.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        # Shown inside the module's repr(), e.g. "DropPath(drop_prob=0.100)".
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
class CHW2HWC(nn.Module):
    """Move the channel axis last: (N, C, H, W) -> (N, H, W, C).

    Returns a permuted view; no data is copied.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        channels_last_order = (0, 2, 3, 1)
        return x.permute(*channels_last_order)
|
|
805
|
+
|
|
806
|
+
class HWC2CHW(nn.Module):
    """Move the channel axis back to position 1: (N, H, W, C) -> (N, C, H, W).

    Returns a permuted view; no data is copied.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        channels_first_order = (0, 3, 1, 2)
        return x.permute(*channels_first_order)
|
|
812
|
+
|
|
813
|
+
class ParallelModule(nn.Module):
    """Apply several sub-modules to the same input and collect their outputs.

    Sub-modules are given either positionally (registered as "0", "1", ...)
    or as a single dict / OrderedDict mapping names to modules.
    ``forward`` returns a list with one output per registered entry, in
    registration order.
    """

    def __init__(self, *args):
        super().__init__()
        given_mapping = len(args) == 1 and isinstance(args[0], (OrderedDict, dict))
        if given_mapping:
            named_branches = args[0].items()
        else:
            named_branches = ((str(i), branch) for i, branch in enumerate(args))
        for name, branch in named_branches:
            self.add_module(name, branch)

    def forward(self, x):
        # Iterate _modules directly (not children()) so a module registered
        # under several names is called once per entry.
        return [branch(x) for branch in self._modules.values()]
|
|
828
|
+
|
|
829
|
+
class SumModule(nn.Module):
    """Element-wise sum of a sequence of tensors (e.g. the list from ParallelModule)."""

    def __init__(self):
        super().__init__()

    def forward(self, xs):
        """Return ``xs[0] + xs[1] + ...`` without modifying any input.

        A one-element sequence returns ``xs[0]`` itself. An empty sequence
        raises IndexError.
        """
        res = xs[0]
        for x in xs[1:]:
            # Out-of-place on purpose: `res += x` would write into xs[0],
            # silently changing the caller's tensor (and possibly breaking
            # autograd if that tensor is saved for backward elsewhere).
            res = res + x
        return res
|
|
838
|
+
|
|
839
|
+
def hard_sigmoid(x):
    """Piecewise-linear sigmoid: ``clamp(x / 6 + 0.5, 0, 1)``.

    Returns 0 for x <= -3, 1 for x >= 3, and is linear in between.
    """
    return torch.clamp(x / 6 + 0.5, min=0, max=1)
|
|
843
|
+
|
|
844
|
+
class ChannelAttention(nn.Module):
    """Channel attention Module.

    A gate built from global average pooling, a 1x1 conv and a hard-sigmoid
    produces one weight per channel, which then rescales the input.

    Args:
        channels (int): The input (and output) channels of the attention layer.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Accepted for interface compatibility; not used by this class.
            Defaults to None
    """

    def __init__(self, channels: int, init_cfg = None) -> None:
        super().__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Hardsigmoid(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for ChannelAttention.

        Returns ``x`` scaled channel-wise by the attention weights.
        """
        # Only the pooling runs with CUDA autocast disabled. The 1x1 conv and
        # the gate stay under the caller's autocast state: inside a disabled
        # region, a float16 ``x`` (the usual case under mixed-precision
        # training) would meet the float32 conv weights and raise a dtype
        # mismatch error.
        with torch.cuda.amp.autocast(enabled=False):
            out = self.global_avgpool(x)
        out = self.fc(out)
        out = self.act(out)
        return x * out
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
class MParent:
    """Plain proxy that forwards attribute and item access to ``model``.

    MParent is not an ``nn.Module``. Storing it as an attribute of a module
    therefore does not register ``model`` as a child of that module.
    """

    def __init__(self, model):
        self.model = model

    def __getattr__(self, name):
        # Called only when normal lookup on the proxy fails. Guarding "model"
        # makes a missing self.model raise AttributeError instead of
        # recursing forever (e.g. during copy/unpickling, before __init__ runs).
        if name == "model":
            raise AttributeError(name)
        # Use getattr() rather than model.__getattr__: nn.Module.__getattr__
        # only resolves parameters, buffers and sub-modules, so plain
        # attributes and methods (and non-Module models) need normal lookup.
        return getattr(self.model, name)

    def __getitem__(self, name):
        return self.model[name]
|
|
881
|
+
|
|
882
|
+
class Unsqueeze(nn.Module):
    """Module form of ``torch.unsqueeze``: insert a size-1 axis at ``dim``."""

    def __init__(self, dim) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x.unsqueeze(self.dim)
|
|
889
|
+
|
|
890
|
+
class Squeeze(nn.Module):
|
|
891
|
+
def __init__(self,dim) -> None:
|
|
892
|
+
super().__init__()
|
|
893
|
+
self.dim = dim
|
|
894
|
+
|
|
895
|
+
def forward(self,x):
|
|
896
|
+
return torch.squeeze(x,dim=self.dim)
|