python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic; see the registry's advisory for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
|
@@ -0,0 +1,603 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import math
|
|
4
|
+
from functools import partial
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import time
|
|
7
|
+
import inspect
|
|
8
|
+
import sys
|
|
9
|
+
from .wlr_scheduler import *
|
|
10
|
+
from collections import OrderedDict
|
|
11
|
+
from .nn import LayerNorm,LayerNorm2d,EvoNormS0,EvoNormS01D,FrozenBatchNorm2d
|
|
12
|
+
import traceback
|
|
13
|
+
from typing import Union, Iterable
|
|
14
|
+
import re
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Normalization layer classes recognized by is_norm() and treated as
# "bn weights" (typically excluded from weight decay) when splitting
# parameters.  Combines torch.nn's built-in norms with this project's
# custom norms imported from .nn (LayerNorm, LayerNorm2d, EvoNormS0,
# EvoNormS01D, FrozenBatchNorm2d).
_NORMS = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
    nn.SyncBatchNorm,
    nn.GroupNorm,
    LayerNorm,
    LayerNorm2d,
    EvoNormS0,
    EvoNormS01D,
    FrozenBatchNorm2d,
)
|
|
32
|
+
|
|
33
|
+
def is_norm(model):
    """Return True when *model* is one of the known normalization layer types."""
    return isinstance(model, _NORMS)
|
|
35
|
+
|
|
36
|
+
def __is_name_of(name, names):
    """True when *name* starts with any prefix in *names*, with or without a
    DataParallel-style leading "module." component."""
    return any(
        name.startswith(prefix) or name.startswith("module." + prefix)
        for prefix in names
    )
|
|
41
|
+
|
|
42
|
+
def is_in_scope(name, scopes):
    """True when *name* falls under any prefix in *scopes*, with or without a
    DataParallel-style leading "module." component."""
    return any(
        name.startswith(prefix) or name.startswith("module." + prefix)
        for prefix in scopes
    )
|
|
47
|
+
|
|
48
|
+
def _get_tensor_or_tensors_shape(x):
    """Shape of a tensor, or the list of shapes of a list/tuple of tensors
    (None entries skipped).  Returns None when *x* itself is None."""
    if isinstance(x, (list, tuple)):
        return [item.shape for item in x if item is not None]
    return None if x is None else x.shape
|
|
59
|
+
|
|
60
|
+
def grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Total norm of the gradients of *parameters* (a tensor or an iterable of
    tensors), like clip_grad_norm_ but without clipping.

    Parameters without a .grad are ignored; returns tensor(0.) when none have
    gradients.  norm_type may be math.inf for the max norm.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    if norm_type == math.inf:
        # max-norm: largest absolute gradient entry across all parameters
        return max(g.abs().max().to(device) for g in grads)
    per_param = torch.stack([torch.norm(g, norm_type).to(device) for g in grads])
    return torch.norm(per_param, norm_type)
|
|
73
|
+
def _add_to_dict(v, dicts):
    """Add *v* to the first set in *dicts*, printing an error for every set
    that already contains it (duplicate-registration warning)."""
    for idx, container in enumerate(dicts):
        if v in container:
            print(f"ERROR: {v} already in dict {idx}")
    dicts[0].add(v)
|
|
78
|
+
|
|
79
|
+
def simple_split_parameters(model,filter=None,return_unused=False,silent=False):
    '''
    Split a model's parameters into (bn_weights, weights, biases) so an
    optimizer can give each group its own weight-decay setting.

    filter: optional callable (module_name, module) -> bool; modules for which
        it returns False are skipped entirely.
    return_unused: also return the frozen (requires_grad == False) groups.
    silent: suppress the "parameter not in any set" report.

    Example:
    bn_weights,weights,biases = simple_split_parameters(model)
    optimizer = optim.AdamW(weights, lr=lr,weight_decay=1e-4)
    optimizer.add_param_group(
        {"params": biases, "weight_decay": 0.0}
    ) # add pg1 with weight_decay
    optimizer.add_param_group({"params": bn_weights,"weight_decay":0.0})
    '''
    # trainable groups
    bn_weights, weights, biases = [], [], []
    # frozen counterparts, only returned when return_unused is True
    unbn_weights, unweights, unbiases = [], [], []
    parameters_set = set()          # qualified names assigned to any group
    unused_parameters_set = set()   # qualified names of frozen parameters
    print(f"Split model parameters")
    print(f"------------------------------------------")
    total_skip = 0
    for k, v in model.named_modules():
        if len(k)==0:
            # the root module has an empty name; skip it
            continue
        if filter is not None and not(filter(k,v)):
            continue
        if hasattr(v, "bias") and isinstance(v.bias, (torch.Tensor,nn.Parameter)):
            if v.bias.requires_grad is False:
                print(f"{k}.bias requires grad == False, skip.")
                unbiases.append(v.bias)
                _add_to_dict(k+".bias",[unused_parameters_set,parameters_set])
                total_skip += 1
            else:
                biases.append(v.bias) # biases
                parameters_set.add(k+".bias")
        # norm layers (or any module whose name contains "bn") put their
        # weight in bn_weights so it can be excluded from weight decay
        if (isinstance(v, _NORMS) or "bn" in k) and hasattr(v,'weight'):
            if v.weight is None:
                continue
            elif v.weight.requires_grad is False:
                print(f"{k}.weight requires grad == False, skip.")
                unbn_weights.append(v.weight)
                _add_to_dict(k+".weight",[unused_parameters_set,parameters_set])
                total_skip += 1
            else:
                bn_weights.append(v.weight) # no decay
                parameters_set.add(k+".weight")
        elif hasattr(v, "weight") and isinstance(v.weight, (torch.Tensor,nn.Parameter)):
            if v.weight.requires_grad is False:
                print(f"{k}.weight requires grad == False, skip.")
                unweights.append(v.weight)
                _add_to_dict(k+".weight",[unused_parameters_set,parameters_set])
                total_skip += 1
            else:
                weights.append(v.weight) # apply decay
                parameters_set.add(k+".weight")
        # any remaining direct parameters of this module besides weight/bias
        for k1,p in v.named_parameters(recurse=False):
            if k1 in ["weight","bias"]:
                continue
            if p.requires_grad == False:
                print(f"{k}.{k1} requires grad == False, skip.")
                total_skip += 1
                # NOTE(review): the substring tests below inspect the MODULE
                # name k, not the parameter name k1 — looks like it should be
                # k1; confirm before changing.
                if "weight" in k:
                    unweights.append(p)
                    _add_to_dict(k+f".{k1}",[unused_parameters_set,parameters_set])
                elif "bias" in k:
                    unbiases.append(p)
                    _add_to_dict(k+f".{k1}",[unused_parameters_set,parameters_set])
                else:
                    if p.ndim>1:
                        unweights.append(p)
                        _add_to_dict(k+f".{k1}",[unused_parameters_set,parameters_set])
                    else:
                        unbiases.append(p)
                        _add_to_dict(k+f".{k1}",[unused_parameters_set,parameters_set])
                continue
            if "weight" in k:
                weights.append(p)
                parameters_set.add(k+f".{k1}")
            elif "bias" in k:
                biases.append(p)
                parameters_set.add(k+f".{k1}")
            else:
                # fall back on dimensionality: matrices decay, vectors do not
                if p.ndim>1:
                    weights.append(p)
                else:
                    biases.append(p)
                parameters_set.add(k+f".{k1}")

    print(f"------------------------------------------")
    if not silent:
        # report trainable parameters that were assigned to no group
        for k,p in model.named_parameters():
            if p.requires_grad == False:
                continue
            if k not in parameters_set:
                print(f"ERROR: {k} not in any parameters set.")
    #batch norm weight, weight, bias
    print(f"Total have {len(list(model.named_parameters()))} parameters.")
    print(f"Finaly find {len(bn_weights)} bn weights, {len(weights)} weights, {len(biases)} biases, total {len(bn_weights)+len(weights)+len(biases)}, total skip {total_skip}.")
    if not return_unused:
        return bn_weights,weights,biases
    else:
        return bn_weights,weights,biases,unbn_weights,unweights,unbiases
|
|
177
|
+
|
|
178
|
+
def freeze_model(model, freeze_bn=True):
    """Disable gradients for every parameter of *model*; when freeze_bn is
    True also switch the whole model to eval() mode (stops BN stat updates)."""
    if freeze_bn:
        model.eval()
    for pname, weight in model.named_parameters():
        print(pname, weight.size(), "freeze")
        weight.requires_grad = False
|
|
184
|
+
|
|
185
|
+
def defrost_model(model, defrost_bn=True, silent=False):
    """Re-enable gradients for every parameter; when defrost_bn is True also
    put the model back in train() mode."""
    if defrost_bn:
        model.train()
    for pname, weight in model.named_parameters():
        if not silent:
            print(pname, weight.size(), "defrost")
        weight.requires_grad = True
|
|
192
|
+
|
|
193
|
+
def defrost_scope(model,scope,defrost_bn=True,silent=False):
    """Re-enable training for every parameter whose name falls inside *scope*.

    scope: list of name prefixes (see is_in_scope).
    defrost_bn: when True, also put matching BatchNorm2d layers back in
        train() mode via the module-level defrost_bn() helper.
    """
    if defrost_bn:
        # BUG FIX: the boolean parameter ``defrost_bn`` shadows the
        # module-level defrost_bn() function, so the original call
        # ``defrost_bn(model, scope)`` raised
        # ``TypeError: 'bool' object is not callable``.  Resolve the function
        # through the module globals instead.
        globals()["defrost_bn"](model, scope)
    for name, param in model.named_parameters():
        if not is_in_scope(name, scope):
            continue
        if not silent:
            print(name, param.size(), "defrost")
        param.requires_grad = True
|
|
202
|
+
|
|
203
|
+
def __set_bn_momentum(m, momentum=0.1):
    """model.apply() helper: set `momentum` on every *BatchNorm* layer
    (matched by class name, so Sync/1d/2d/3d variants all qualify)."""
    if 'BatchNorm' in m.__class__.__name__:
        m.momentum = momentum
|
|
207
|
+
|
|
208
|
+
def __set_bn_eps(m, eps=1e-3):
    """model.apply() helper: set `eps` on every *BatchNorm* layer
    (matched by class name)."""
    if 'BatchNorm' in m.__class__.__name__:
        m.eps = eps
|
|
212
|
+
|
|
213
|
+
def __fix_bn(m):
    """model.apply() helper: switch every *BatchNorm* layer (matched by class
    name) to eval() mode so running stats stop updating."""
    if 'BatchNorm' in m.__class__.__name__:
        m.eval()
|
|
217
|
+
|
|
218
|
+
def defrost_bn(model:torch.nn.Module,scopes=None):
    """Put BatchNorm2d layers whose qualified name falls under *scopes* back
    in train() mode.

    scopes: list of name prefixes (see __is_name_of).  ROBUSTNESS FIX: the
        declared default ``scopes=None`` used to crash inside __is_name_of
        (``for x in None``); None now means "defrost every BatchNorm2d".

    Returns the model for chaining.
    """
    _nr = 0
    _nr_skip = 0
    for name, ms in model.named_modules():
        if not isinstance(ms, nn.BatchNorm2d):
            continue
        if scopes is None or __is_name_of(name, scopes):
            ms.train()
            print(f"defrost bn {name}")
            _nr += 1
        else:
            _nr_skip += 1
            continue
    print(f"Total defrost {_nr} bn, total {_nr_skip} bn not defrost.")
    sys.stdout.flush()
    return model
|
|
235
|
+
|
|
236
|
+
def __freeze_bn(model: torch.nn.Module, names2freeze=None):
    """Switch BatchNorm2d layers whose qualified name falls under a prefix in
    *names2freeze* to eval() mode; report counts.  Returns the model."""
    frozen = 0
    skipped = 0
    for name, mod in model.named_modules():
        if not isinstance(mod, nn.BatchNorm2d):
            continue
        if __is_name_of(name, names2freeze):
            mod.apply(__fix_bn)
            print(f"Freeze bn {name}")
            frozen += 1
        else:
            skipped += 1
    print(f"Total freeze {frozen} bn, total {skipped} bn not freeze.")
    sys.stdout.flush()
    return model
|
|
253
|
+
|
|
254
|
+
def __freeze_bn2(model,names2freeze=None):
    '''
    Replace BatchNorm layers inside the named direct children of *model* with
    FrozenBatchNorm2d.

    names2freeze: list[str] attribute names of direct children to convert.
        NOTE(review): getattr resolves a single level only — dotted paths
        ("backbone.layer1") are not supported; confirm callers never pass them.
    '''
    for name in names2freeze:
        child = getattr(model,name)
        converted = FrozenBatchNorm2d.convert_frozen_batchnorm(child)
        # BUG FIX: convert_frozen_batchnorm returns a NEW module when the root
        # it is given is itself a BatchNorm layer; the original discarded that
        # return value, so a child that was directly a BN layer was never
        # frozen.  Re-attach the converted module when it differs.
        if converted is not child:
            setattr(model, name, converted)
|
|
261
|
+
|
|
262
|
+
def freeze_bn(model, names2freeze=None):
    '''
    Put BatchNorm layers in eval() mode.

    names2freeze: str/list[str] names to freeze; None freezes every BN layer.
    Returns the model.
    '''
    if names2freeze is None:
        model.apply(__fix_bn)
        return model
    if isinstance(names2freeze, (str, bytes)):
        names2freeze = [names2freeze]
    return __freeze_bn(model, names2freeze)
|
|
274
|
+
|
|
275
|
+
def freeze_bn2(model, names2freeze=None):
    '''
    Permanently replace BatchNorm layers with FrozenBatchNorm2d.

    names2freeze: str/list[str] names to freeze; None converts the whole
    model.
    '''
    if names2freeze is None:
        # convert everything in one pass
        return FrozenBatchNorm2d.convert_frozen_batchnorm(model)
    if isinstance(names2freeze, (str, bytes)):
        names2freeze = [names2freeze]
    model = __freeze_bn2(model, names2freeze)
    return model
|
|
288
|
+
|
|
289
|
+
def set_bn_momentum(model, momentum):
    """Set `momentum` on every BatchNorm layer of *model*."""
    model.apply(partial(__set_bn_momentum, momentum=momentum))
|
|
292
|
+
|
|
293
|
+
def set_bn_eps(model, eps):
    """Set `eps` on every BatchNorm layer of *model*."""
    model.apply(partial(__set_bn_eps, eps=eps))
|
|
296
|
+
|
|
297
|
+
def get_gpus_str(gpus):
    """Render an iterable of GPU ids as a comma-separated string
    (e.g. for CUDA_VISIBLE_DEVICES).

    Returns "" for an empty iterable, matching the original behavior.
    """
    # IDIOM: str.join replaces the quadratic ``+=`` loop and the manual
    # trailing-comma slice.
    return ",".join(str(g) for g in gpus)
|
|
304
|
+
|
|
305
|
+
def show_model_parameters_info(net):
    """Print every parameter of *net* split into trainable / frozen groups,
    plus a count of frozen BatchNorm layers.

    Returns (freeze_parameter_names, unfreeze_parameter_names).
    """
    print("Training parameters.")
    trainable_total = 0
    frozen_names = []
    trainable_names = []
    for pname, weight in net.named_parameters():
        if weight.requires_grad:
            print(pname, list(weight.size()), weight.device, 'unfreeze')
            trainable_total += weight.numel()
            trainable_names.append(pname)
    print(f"Total train parameters {trainable_total:,}")

    print("Not training parameters.")
    frozen_total = 0
    for pname, weight in net.named_parameters():
        if not weight.requires_grad:
            print(pname, list(weight.size()), weight.device, 'freeze')
            frozen_total += weight.numel()
            frozen_names.append(pname)
    print(f"Total not train parameters {frozen_total:,}")

    # count BN layers that are effectively frozen (eval mode or converted)
    bn_frozen = 0
    bn_active = 0
    for _, mod in net.named_modules():
        if not isinstance(mod, (nn.BatchNorm2d, FrozenBatchNorm2d)):
            continue
        if not mod.training or isinstance(mod, FrozenBatchNorm2d):
            bn_frozen += 1
        else:
            bn_active += 1
    print(f"Total freeze {bn_frozen} batch normal layers, {bn_active} batch normal layer not freeze.")

    return frozen_names, trainable_names
|
|
337
|
+
|
|
338
|
+
def show_async_norm_states(module):
    """Print, for every normalization layer in *module*, its training flag and
    the requires_grad state of each of its parameters."""
    for name, child in module.named_modules():
        if not isinstance(child, _NORMS):
            continue
        info = ""
        for pname, p in child.named_parameters():
            if hasattr(p, "requires_grad"):
                info += f"{pname}:{p.requires_grad}, "
        print(f"{name}: {type(child)}: training: {child.training}, requires_grad: {info}")
|
|
346
|
+
|
|
347
|
+
def get_total_and_free_memory_in_Mb(cuda_device):
    """Return (total, used) memory in MiB for GPU *cuda_device*, parsed from
    nvidia-smi output (one CSV line per device)."""
    query = os.popen(
        "nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader"
    )
    per_device = query.read().strip().split("\n")
    total, used = per_device[int(cuda_device)].split(",")
    return int(total), int(used)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def occupy_mem(cuda_device, mem_ratio=0.9):
    """Pre-allocate roughly *mem_ratio* of the GPU's memory for training to
    avoid memory fragmentation, then release it after a short pause."""
    total, used = get_total_and_free_memory_in_Mb(cuda_device)
    # a (256, 1024, 1) float32 block is exactly 1 MiB, so block_mem is in MiB
    block_mem = int(total * mem_ratio) - used
    filler = torch.cuda.FloatTensor(256, 1024, block_mem)
    del filler
    time.sleep(5)
|
|
366
|
+
|
|
367
|
+
def isfinite_hook(module, fea_in, fea_out):
    '''
    Forward hook that reports NaN/Inf module outputs; attach with
    register_forward_hook(net,isfinite_hook)
    '''
    if isinstance(fea_in, (tuple, list)):
        if len(fea_in) == 0:
            return None
        if len(fea_in) == 1:
            fea_in = fea_in[0]
    if torch.all(torch.isfinite(fea_out)):
        return
    print("Find NaN or infininite")
    traceback.print_exc(file=sys.stdout)
    print(f"Input : {torch.min(fea_in).item(),torch.max(fea_in).item(),torch.mean(fea_in).item()}")
    print(f"Output: {torch.min(fea_out).item(),torch.max(fea_out).item(),torch.mean(fea_out).item()}")
    for name, param in module.named_parameters():
        print(f"{name}: {torch.min(param).item(),torch.max(param).item(),torch.mean(param).item()}")
|
|
386
|
+
|
|
387
|
+
def islarge_hook(module, fea_in, fea_out, max_v=60000):
    '''
    Forward hook that reports outputs exceeding max_v in magnitude; attach
    with register_forward_hook(net,isfinite_hook)
    '''
    if isinstance(fea_in, (tuple, list)):
        if len(fea_in) == 0:
            return None
        if len(fea_in) == 1:
            fea_in = fea_in[0]
    if not islarge(fea_out, max_v=max_v):
        return
    print("Find Large value")
    traceback.print_exc(file=sys.stdout)
    print(f"Input : {torch.min(fea_in).item(),torch.max(fea_in).item(),torch.mean(fea_in).item()}")
    print(f"Output: {torch.min(fea_out).item(),torch.max(fea_out).item(),torch.mean(fea_out).item()}")
    for name, param in module.named_parameters():
        print(f"{name}: {torch.min(param).item(),torch.max(param).item(),torch.mean(param).item()}")
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def islarge(x, max_v=65535):
    """True when any element of *x* exceeds max_v in absolute value.

    x may be a tensor, None (treated as not-large), or a nested list/tuple of
    the same, which is checked recursively.
    """
    if x is None:
        return False
    if isinstance(x, (tuple, list)):
        return any(islarge(item, max_v=max_v) for item in x)
    return torch.any(torch.abs(x) > max_v)
|
|
417
|
+
|
|
418
|
+
def isfinite(x):
    """True when *x* contains only finite values.

    x may be a tensor, None (treated as finite), or a nested list/tuple of
    the same, which is checked recursively.
    """
    if x is None:
        return True
    if isinstance(x, (tuple, list)):
        return all(isfinite(item) for item in x)
    return torch.all(torch.isfinite(x))
|
|
427
|
+
|
|
428
|
+
def register_forward_hook(net, hook):
    """Recursively register *hook* as a forward hook on every LEAF module of
    *net* (modules with no children)."""
    children = list(net.children())
    for child in children:
        register_forward_hook(child, hook)
    if not children:
        net.register_forward_hook(hook=hook)
|
|
435
|
+
|
|
436
|
+
def register_backward_hook(net, hook):
    """Register *hook* as a (legacy) backward hook on *net* and, recursively,
    on every descendant module — note: every module, not only leaves."""
    for child in net.children():
        register_backward_hook(child, hook)
    net.register_backward_hook(hook=hook)
|
|
445
|
+
|
|
446
|
+
def tensor_fix_grad(grad):
    '''
    Gradient-sanitizing hook for tensor.register_hook: replaces non-finite
    gradients with zeros and clamps very large ones to +/-16000.
    '''
    max_v = 16000.0
    if not torch.all(torch.isfinite(grad)):
        # NaN/Inf gradients are zeroed wholesale rather than raising
        return torch.zeros_like(grad)
    if islarge(grad, max_v):
        return torch.clamp(grad, min=-max_v, max=max_v)
    return grad
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def tensor_isfinite_hook(grad):
    '''
    Diagnostic hook for tensor.register_hook: prints a report when the
    incoming gradient contains NaN/Inf values.
    '''
    if torch.all(torch.isfinite(grad)):
        return
    print(f"Find NaN or infininite grad, {grad.shape}")
    traceback.print_exc(file=sys.stdout)
    print(f"grad: {torch.min(grad).item(),torch.max(grad).item(),torch.mean(grad).item()}")
|
|
471
|
+
|
|
472
|
+
def tensor_islarge_hook(grad, max_v=60000):
    '''
    Diagnostic hook for tensor.register_hook: prints a report when the
    incoming gradient exceeds max_v in magnitude.
    '''
    if not islarge(grad, max_v=max_v):
        return
    print("Find Large value grad")
    traceback.print_exc(file=sys.stdout)
    print(f"Output: {torch.min(grad).item(),torch.max(grad).item(),torch.mean(grad).item()}")
|
|
481
|
+
|
|
482
|
+
def register_tensor_hook(model, hook):
    '''
    Attach *hook* to the gradient of every trainable parameter, e.g.
    register_tensor_hook(model,tensor_isfinite_hook)
    '''
    for weight in model.parameters():
        if weight.requires_grad:
            weight.register_hook(hook)
|
|
489
|
+
|
|
490
|
+
def is_any_grad_infinite(model):
    '''
    Return True if any trainable parameter of *model* carries a NaN/Inf or
    very large (>32768) gradient; prints an error line per offending name.
    '''
    found = False
    for name, weight in model.named_parameters():
        if not weight.requires_grad or weight.grad is None:
            continue
        bad = (not torch.all(torch.isfinite(weight.grad))
               or islarge(weight.grad, max_v=32768.0))
        if bad:
            print(f"ERROR: {name}: unnormal grad")
            found = True
    return found
|
|
502
|
+
|
|
503
|
+
def backward_grad_normal_hook(module, grad_input, grad_output):
    '''
    Backward hook that dumps the module plus gradient shapes/values when
    grad_output contains NaN/Inf or entries larger than 32768 in magnitude.
    '''
    if isfinite(grad_output) and not islarge(grad_output, max_v=32768.0):
        return
    print("Find NaN or infininite grad")
    print(module, _get_tensor_or_tensors_shape(grad_input),
          _get_tensor_or_tensors_shape(grad_output), grad_input, grad_output)
|
|
513
|
+
|
|
514
|
+
def finetune_model(model,names_not2train=None,names2train=None):
    """Convenience wrapper for fine-tuning: first freeze the parameter groups
    named in *names_not2train*, then re-enable the groups named in
    *names2train*.  Either argument may be None (that step is skipped).

    Delegates to finetune_model_nottrain() and finetune_model_train().
    """
    if names_not2train is not None:
        finetune_model_nottrain(model,names_not2train)
    if names2train is not None:
        finetune_model_train(model,names2train)
    # BUG FIX: removed ~25 lines of unreachable code that followed this
    # return (a local is_name_of() helper plus parameter/BatchNorm loops that
    # could never execute).
    return
|
|
546
|
+
|
|
547
|
+
def finetune_model_train(model, names2train=None):
    """Enable requires_grad for every parameter — and train() mode for every
    BatchNorm2d layer — whose qualified name starts with a prefix in
    *names2train* (with or without a leading "module." component)."""

    def _matches(name, prefixes):
        return any(
            name.startswith(p) or name.startswith("module." + p)
            for p in prefixes
        )

    for pname, weight in model.named_parameters():
        if _matches(pname, names2train):
            weight.requires_grad = True

    bn_count = 0
    for mname, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d) and _matches(mname, names2train):
            mod.train()
            bn_count += 1
|
|
566
|
+
|
|
567
|
+
def finetune_model_nottrain(model: torch.nn.Module, names_not2train):
    """Freeze parameters — and put matching BatchNorm2d layers in eval() mode —
    whose qualified name matches any entry in *names_not2train*.

    Entries match either as a plain prefix (with or without a leading
    "module." component) or as a regular expression anchored at the start of
    the name.  A single string is accepted in place of a list.
    """
    if not isinstance(names_not2train, (list, tuple)):
        names_not2train = [names_not2train]

    compiled = [re.compile(expr) for expr in names_not2train]

    def _matches(name):
        for prefix in names_not2train:
            if name.startswith(prefix) or name.startswith("module." + prefix):
                return True
        return any(pat.match(name) is not None for pat in compiled)

    for pname, weight in model.named_parameters():
        if _matches(pname):
            weight.requires_grad = False

    frozen_bn = 0
    for mname, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d) and _matches(mname):
            mod.eval()
            frozen_bn += 1
    sys.stdout.flush()
|
|
603
|
+
|