python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,603 @@
1
+ import os
2
+ import torch
3
+ import math
4
+ from functools import partial
5
+ import torch.nn as nn
6
+ import time
7
+ import inspect
8
+ import sys
9
+ from .wlr_scheduler import *
10
+ from collections import OrderedDict
11
+ from .nn import LayerNorm,LayerNorm2d,EvoNormS0,EvoNormS01D,FrozenBatchNorm2d
12
+ import traceback
13
+ from typing import Union, Iterable
14
+ import re
15
+
16
+
17
# All normalization-layer types this module recognizes.  Used by is_norm()
# and simple_split_parameters() to route norm-layer weights into the
# "no weight decay" parameter group.
_NORMS = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
    nn.SyncBatchNorm,
    nn.GroupNorm,
    # project-local norm layers imported from .nn
    LayerNorm,
    LayerNorm2d,
    EvoNormS0,
    EvoNormS01D,
    FrozenBatchNorm2d,
)
32
+
33
def is_norm(model):
    # True if *model* is an instance of any known normalization layer (see _NORMS).
    return isinstance(model,_NORMS)
35
+
36
def __is_name_of(name, names):
    # True if *name* starts with any prefix in *names*, with or without the
    # leading "module." added by DataParallel/DistributedDataParallel wrappers.
    for x in names:
        if name.startswith(x) or name.startswith("module."+x):
            return True
    return False
41
+
42
def is_in_scope(name, scopes):
    """Return True when *name* falls under any prefix in *scopes*.

    A dotted name matches a scope ``s`` when it starts with ``s`` itself or
    with ``"module." + s`` (models wrapped by DataParallel/DDP).
    """
    return any(
        name.startswith(scope) or name.startswith("module." + scope)
        for scope in scopes
    )
47
+
48
def _get_tensor_or_tensors_shape(x):
    # Return the shape of a tensor, or the list of shapes for a list/tuple of
    # tensors (None entries are skipped); returns None when x is None.
    # Debug helper used by backward_grad_normal_hook().
    if isinstance(x,(list,tuple)):
        res = []
        for v in x:
            if v is not None:
                res.append(v.shape)
        return res
    if x is not None:
        return x.shape
    else:
        return None
59
+
60
def grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total *norm_type*-norm of the gradients of *parameters*.

    Mirrors the norm computation of ``torch.nn.utils.clip_grad_norm_`` but
    without clipping.  Parameters without a gradient are ignored; with no
    gradients at all a zero scalar tensor is returned.  ``math.inf`` selects
    the max-abs norm.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    norm_type = float(norm_type)
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    if norm_type == math.inf:
        return max(g.abs().max().to(device) for g in grads)
    per_grad = [torch.norm(g, norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_grad), norm_type)
73
def _add_to_dict(v,dicts):
    # Record name *v* in dicts[0], first warning if it is already present in
    # any of the given collections (a duplicate indicates a bookkeeping error
    # in simple_split_parameters).  NOTE: despite the name, *dicts* is a list
    # of sets.
    for i,c in enumerate(dicts):
        if v in c:
            print(f"ERROR: {v} already in dict {i}")
    dicts[0].add(v)
78
+
79
def simple_split_parameters(model,filter=None,return_unused=False,silent=False):
    '''
    Split the parameters of *model* into three optimizer groups:
    norm-layer weights (no weight decay), ordinary weights (weight decay)
    and biases (no weight decay).

    Args:
        model: module whose named_modules()/named_parameters() are scanned.
        filter: optional callable ``(name, module) -> bool``; modules for
            which it returns False are skipped entirely.
        return_unused: when True, additionally return the parameters skipped
            because ``requires_grad`` is False.
        silent: when False, report any trainable parameter that ended up in
            no group.

    Returns:
        ``(bn_weights, weights, biases)`` or, with ``return_unused``,
        ``(bn_weights, weights, biases, unbn_weights, unweights, unbiases)``.

    Example:
        bn_weights,weights,biases = simple_split_parameters(model)
        optimizer = optim.AdamW(weights, lr=lr,weight_decay=1e-4)
        optimizer.add_param_group(
            {"params": biases, "weight_decay": 0.0}
        ) # add pg1 without weight_decay
        optimizer.add_param_group({"params": bn_weights,"weight_decay":0.0})
    '''
    bn_weights, weights, biases = [], [], []
    # un* lists collect parameters skipped because requires_grad is False
    unbn_weights, unweights, unbiases = [], [], []
    parameters_set = set()          # names assigned to some group
    unused_parameters_set = set()   # names of skipped (frozen) parameters
    print(f"Split model parameters")
    print(f"------------------------------------------")
    total_skip = 0
    for k, v in model.named_modules():
        if len(k)==0:
            # skip the root module itself (empty name)
            continue
        if filter is not None and not(filter(k,v)):
            continue
        # --- bias of this module --------------------------------------
        if hasattr(v, "bias") and isinstance(v.bias, (torch.Tensor,nn.Parameter)):
            if v.bias.requires_grad is False:
                print(f"{k}.bias requires grad == False, skip.")
                unbiases.append(v.bias)
                _add_to_dict(k+".bias",[unused_parameters_set,parameters_set])
                total_skip += 1
            else:
                biases.append(v.bias) # biases
                parameters_set.add(k+".bias")
        # --- weight: norm layers (or modules whose name contains "bn")
        # go into the no-decay group ----------------------------------
        if (isinstance(v, _NORMS) or "bn" in k) and hasattr(v,'weight'):
            if v.weight is None:
                continue
            elif v.weight.requires_grad is False:
                print(f"{k}.weight requires grad == False, skip.")
                unbn_weights.append(v.weight)
                _add_to_dict(k+".weight",[unused_parameters_set,parameters_set])
                total_skip += 1
            else:
                bn_weights.append(v.weight) # no decay
                parameters_set.add(k+".weight")
        elif hasattr(v, "weight") and isinstance(v.weight, (torch.Tensor,nn.Parameter)):
            if v.weight.requires_grad is False:
                print(f"{k}.weight requires grad == False, skip.")
                unweights.append(v.weight)
                _add_to_dict(k+".weight",[unused_parameters_set,parameters_set])
                total_skip += 1
            else:
                weights.append(v.weight) # apply decay
                parameters_set.add(k+".weight")
        # --- any other direct parameters of this module ---------------
        for k1,p in v.named_parameters(recurse=False):
            if k1 in ["weight","bias"]:
                continue
            if p.requires_grad == False:
                print(f"{k}.{k1} requires grad == False, skip.")
                total_skip += 1
                # classify frozen params by module-name heuristic, then ndim
                if "weight" in k:
                    unweights.append(p)
                    _add_to_dict(k+f".{k1}",[unused_parameters_set,parameters_set])
                elif "bias" in k:
                    unbiases.append(p)
                    _add_to_dict(k+f".{k1}",[unused_parameters_set,parameters_set])
                else:
                    if p.ndim>1:
                        unweights.append(p)
                        _add_to_dict(k+f".{k1}",[unused_parameters_set,parameters_set])
                    else:
                        unbiases.append(p)
                        _add_to_dict(k+f".{k1}",[unused_parameters_set,parameters_set])
                continue
            if "weight" in k:
                weights.append(p)
                parameters_set.add(k+f".{k1}")
            elif "bias" in k:
                biases.append(p)
                parameters_set.add(k+f".{k1}")
            else:
                # fall back on dimensionality: matrices decay, vectors don't
                if p.ndim>1:
                    weights.append(p)
                else:
                    biases.append(p)
                parameters_set.add(k+f".{k1}")

    print(f"------------------------------------------")
    if not silent:
        # sanity check: every trainable parameter should be in some group
        for k,p in model.named_parameters():
            if p.requires_grad == False:
                continue
            if k not in parameters_set:
                print(f"ERROR: {k} not in any parameters set.")
    #batch norm weight, weight, bias
    print(f"Total have {len(list(model.named_parameters()))} parameters.")
    print(f"Finaly find {len(bn_weights)} bn weights, {len(weights)} weights, {len(biases)} biases, total {len(bn_weights)+len(weights)+len(biases)}, total skip {total_skip}.")
    if not return_unused:
        return bn_weights,weights,biases
    else:
        return bn_weights,weights,biases,unbn_weights,unweights,unbiases
177
+
178
def freeze_model(model,freeze_bn=True):
    """Freeze every parameter of *model* (requires_grad=False).

    When *freeze_bn* is True the whole model is also switched to eval()
    mode, so BatchNorm running statistics stop updating.
    Each frozen parameter is logged to stdout.
    """
    if freeze_bn:
        model.eval()
    for pname, p in model.named_parameters():
        print(pname, p.size(), "freeze")
        p.requires_grad = False
184
+
185
def defrost_model(model,defrost_bn=True,silent=False):
    """Unfreeze every parameter of *model* (requires_grad=True).

    When *defrost_bn* is True the model is also switched back to train()
    mode.  Unless *silent*, each defrosted parameter is logged.
    """
    if defrost_bn:
        model.train()
    for pname, p in model.named_parameters():
        if not silent:
            print(pname, p.size(), "defrost")
        p.requires_grad = True
192
+
193
def defrost_scope(model,scope,defrost_bn=True,silent=False):
    """Unfreeze the parameters of *model* whose names fall under *scope*.

    Args:
        model: module to partially defrost.
        scope: iterable of dotted-name prefixes selecting sub-modules
            (matched via is_in_scope, i.e. with an optional "module." prefix).
        defrost_bn: when True, also put the matching BatchNorm2d layers back
            into train mode via the module-level defrost_bn() helper.
        silent: suppress per-parameter logging.

    Bug fix: the boolean parameter ``defrost_bn`` shadows the module-level
    function of the same name, so the original ``defrost_bn(model,scope)``
    raised ``TypeError: 'bool' object is not callable``.  The function is now
    fetched from the module globals instead.
    """
    if defrost_bn:
        # the parameter shadows the module-level defrost_bn() function
        globals()["defrost_bn"](model, scope)
    for name, param in model.named_parameters():
        if not is_in_scope(name,scope):
            continue
        if not silent:
            print(name, param.size(), "defrost")
        param.requires_grad = True
202
+
203
def __set_bn_momentum(m,momentum=0.1):
    # model.apply() helper: set *momentum* on any BatchNorm layer (matched by
    # class name, so SyncBatchNorm etc. match too); other modules untouched.
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.momentum = momentum
207
+
208
def __set_bn_eps(m,eps=1e-3):
    # model.apply() helper: set *eps* on any BatchNorm layer (matched by
    # class name); other modules untouched.
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eps = eps
212
+
213
def __fix_bn(m):
    # model.apply() helper: switch any BatchNorm layer to eval() so its
    # running statistics stop updating; other modules untouched.
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
217
+
218
def defrost_bn(model:torch.nn.Module,scopes=None):
    # Put the BatchNorm2d layers of *model* whose names fall under *scopes*
    # back into train() mode (running statistics update again); returns model.
    # NOTE(review): scopes=None would raise inside __is_name_of (iterating
    # None) -- callers are expected to pass a list of name prefixes.

    _nr = 0       # BN layers defrosted
    _nr_skip = 0  # BN layers left untouched
    for name, ms in model.named_modules():
        if not isinstance(ms, nn.BatchNorm2d):
            continue
        if __is_name_of(name, scopes):
            ms.train()
            print(f"defrost bn {name}")
            _nr += 1
        else:
            _nr_skip += 1
            continue
    print(f"Total defrost {_nr} bn, total {_nr_skip} bn not defrost.")
    sys.stdout.flush()
    return model
235
+
236
def __freeze_bn(model:torch.nn.Module,names2freeze=None):
    # Switch to eval() every BatchNorm2d of *model* whose name starts with one
    # of the prefixes in *names2freeze*; returns the (mutated) model.
    # Backend for freeze_bn() when explicit names are given.

    _nr = 0       # BN layers frozen
    _nr_skip = 0  # BN layers left in train mode
    for name, ms in model.named_modules():
        if not isinstance(ms, nn.BatchNorm2d):
            continue
        if __is_name_of(name, names2freeze):
            ms.apply(__fix_bn)
            print(f"Freeze bn {name}")
            _nr += 1
        else:
            _nr_skip += 1
            continue
    print(f"Total freeze {_nr} bn, total {_nr_skip} bn not freeze.")
    sys.stdout.flush()
    return model
253
+
254
def __freeze_bn2(model,names2freeze=None):
    '''
    Convert the BatchNorm layers inside the named direct children of *model*
    to FrozenBatchNorm2d.

    names2freeze: list[str] attribute names of the children to convert.

    Bug fix: ``FrozenBatchNorm2d.convert_frozen_batchnorm`` returns the
    converted module (it does not necessarily convert in place), so the
    result must be assigned back onto *model*; the original discarded it.
    The model is now also returned, so callers such as freeze_bn2 no longer
    end up with None.
    '''
    for name in names2freeze:
        child = getattr(model,name)
        # assumes convert_frozen_batchnorm follows the detectron2 convention
        # of returning the converted module -- TODO confirm against
        # wml.wtorch.nn.FrozenBatchNorm2d
        setattr(model, name, FrozenBatchNorm2d.convert_frozen_batchnorm(child))
    return model
261
+
262
def freeze_bn(model,names2freeze=None):
    '''
    Freeze BatchNorm layers by switching them to eval() mode (running
    statistics stop updating; affine parameters keep their requires_grad).

    names2freeze: str/list[str] name prefixes of the sub-modules whose BN
        layers should be frozen; None freezes every BN in the model.

    Returns the model.
    '''
    if names2freeze is None:
        model.apply(__fix_bn)
    else:
        if isinstance(names2freeze,(str,bytes)):
            names2freeze = [names2freeze]
        model = __freeze_bn(model,names2freeze)

    return model
274
+
275
def freeze_bn2(model,names2freeze=None):
    '''
    Freeze batch-norm statistics structurally by converting BN layers to
    FrozenBatchNorm2d.

    names2freeze: str/list[str] attribute names of the direct children to
        convert; None converts the whole model.

    Returns the (converted) model.

    Bug fix: the original reassigned ``model = __freeze_bn2(...)`` although
    __freeze_bn2 returned None, so this function returned None whenever
    names2freeze was given.  __freeze_bn2 mutates *model* in place, so we
    keep our own reference and always return a valid module.
    '''
    if names2freeze is None:
        #model.apply(__fix_bn)
        model = FrozenBatchNorm2d.convert_frozen_batchnorm(model)
    else:
        if isinstance(names2freeze,(str,bytes)):
            names2freeze = [names2freeze]
        __freeze_bn2(model,names2freeze)

    return model
288
+
289
def set_bn_momentum(model,momentum):
    # Set the momentum of every BatchNorm layer in *model*.
    fn = partial(__set_bn_momentum,momentum=momentum)
    model.apply(fn)
292
+
293
def set_bn_eps(model,eps):
    # Set the eps of every BatchNorm layer in *model*.
    fn = partial(__set_bn_eps,eps=eps)
    model.apply(fn)
296
+
297
def get_gpus_str(gpus):
    """Join GPU ids into the comma-separated form used by CUDA_VISIBLE_DEVICES.

    Args:
        gpus: iterable of GPU ids (ints or strings).

    Returns:
        e.g. [0, 1, 3] -> "0,1,3"; an empty iterable yields "".

    Replaces the original manual ``+=``/trailing-comma-slice loop with the
    idiomatic ``str.join``.
    """
    return ",".join(str(g) for g in gpus)
304
+
305
def show_model_parameters_info(net):
    # Print every parameter of *net* grouped into trainable/frozen, with total
    # element counts, plus a summary of how many BatchNorm2d /
    # FrozenBatchNorm2d layers are frozen.
    # Returns (freeze_parameters, unfreeze_parameters): the two name lists.
    print("Training parameters.")
    total_train_parameters = 0
    freeze_parameters = []
    unfreeze_parameters = []
    for name, param in net.named_parameters():
        if param.requires_grad:
            print(name, list(param.size()), param.device,'unfreeze')
            total_train_parameters += param.numel()
            unfreeze_parameters.append(name)
    print(f"Total train parameters {total_train_parameters:,}")
    print("Not training parameters.")
    total_not_train_parameters = 0
    for name, param in net.named_parameters():
        if not param.requires_grad:
            print(name, list(param.size()), param.device,'freeze')
            total_not_train_parameters += param.numel()
            freeze_parameters.append(name)
    print(f"Total not train parameters {total_not_train_parameters:,}")

    _nr = 0           # frozen BN layers
    not_freeze_nr =0  # BN layers still training
    for name, ms in net.named_modules():
        if not isinstance(ms, (nn.BatchNorm2d,FrozenBatchNorm2d)):
            continue
        # a BN layer counts as frozen when in eval mode or structurally frozen
        if not ms.training or isinstance(ms,FrozenBatchNorm2d):
            _nr += 1
        else:
            not_freeze_nr += 1
    print(f"Total freeze {_nr} batch normal layers, {not_freeze_nr} batch normal layer not freeze.")

    return freeze_parameters,unfreeze_parameters
337
+
338
def show_async_norm_states(module):
    # Debug helper: print, for every normalization layer in *module*, its
    # training flag and the requires_grad state of each of its parameters.
    for name, child in module.named_modules():
        if isinstance(child, _NORMS):
            info = ""
            for k,v in child.named_parameters():
                if hasattr(v,"requires_grad"):
                    info += f"{k}:{v.requires_grad}, "
            print(f"{name}: {type(child)}: training: {child.training}, requires_grad: {info}")
346
+
347
def get_total_and_free_memory_in_Mb(cuda_device):
    """Query total and used memory (in MiB) of one GPU via nvidia-smi.

    Args:
        cuda_device: GPU index (int or digit string) in nvidia-smi order.

    Returns:
        (total_mb, used_mb) tuple of ints.  Note: despite the function name,
        the second value is the *used* memory, not the free memory.

    Fix: the os.popen handle was never closed; it is now used as a context
    manager so the pipe is released deterministically.
    """
    with os.popen(
        "nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader"
    ) as pipe:
        devices_info = pipe.read().strip().split("\n")
    total, used = devices_info[int(cuda_device)].split(",")
    return int(total), int(used)
354
+
355
+
356
def occupy_mem(cuda_device, mem_ratio=0.9):
    """
    pre-allocate gpu memory for training to avoid memory Fragmentation.

    Allocates a throwaway CUDA float tensor sized to bring the device's used
    memory up to *mem_ratio* of its total, then frees it so the CUDA caching
    allocator retains the block.
    """
    total, used = get_total_and_free_memory_in_Mb(cuda_device)
    max_mem = int(total * mem_ratio)
    block_mem = max_mem - used  # MiB still to claim
    # 256*1024 float32 elements == 1 MiB, so this tensor is ~block_mem MiB.
    # NOTE(review): allocates on the CURRENT cuda device, which is assumed to
    # be cuda_device -- confirm callers set the device first.
    x = torch.cuda.FloatTensor(256, 1024, block_mem)
    del x
    time.sleep(5)  # give the allocator time to settle before training starts
366
+
367
def isfinite_hook(module,fea_in,fea_out):
    '''
    Forward hook that reports NaN/Inf in a module's output.

    Usage: register_forward_hook(net,isfinite_hook)

    On detection it prints input/output min/max/mean and every parameter's
    range.  Always returns None, so the forward result is unchanged.
    '''
    if isinstance(fea_in,(tuple,list)):
        if len(fea_in)==1:
            fea_in = fea_in[0]
        elif len(fea_in)==0:
            return None
    #if not torch.all(torch.isfinite(fea_in)):
    #return None
    if not torch.all(torch.isfinite(fea_out)):
        print("Find NaN or infininite")
        #print(f"{inspect.stack()}")
        traceback.print_exc(file=sys.stdout)
        # NOTE(review): with multiple inputs fea_in is still a tuple here and
        # torch.min/max would raise -- confirm hooked modules are single-input.
        print(f"Input : {torch.min(fea_in).item(),torch.max(fea_in).item(),torch.mean(fea_in).item()}")
        print(f"Output: {torch.min(fea_out).item(),torch.max(fea_out).item(),torch.mean(fea_out).item()}")
        for name, param in module.named_parameters():
            print(f"{name}: {torch.min(param).item(),torch.max(param).item(),torch.mean(param).item()}")
386
+
387
def islarge_hook(module,fea_in,fea_out,max_v=60000):
    '''
    Forward hook that reports output values exceeding *max_v* in magnitude
    (e.g. to catch pre-overflow activations in fp16 training).

    Usage: register_forward_hook(net, islarge_hook)

    On detection it prints input/output min/max/mean and every parameter's
    range.  Always returns None, so the forward result is unchanged.
    '''
    if isinstance(fea_in,(tuple,list)):
        if len(fea_in)==1:
            fea_in = fea_in[0]
        elif len(fea_in)==0:
            return None
    #if not torch.all(torch.isfinite(fea_in)):
    #return None
    if islarge(fea_out,max_v=max_v):
        print("Find Large value")
        #print(f"{inspect.stack()}")
        traceback.print_exc(file=sys.stdout)
        # NOTE(review): with multiple inputs fea_in is still a tuple here and
        # torch.min/max would raise -- confirm hooked modules are single-input.
        print(f"Input : {torch.min(fea_in).item(),torch.max(fea_in).item(),torch.mean(fea_in).item()}")
        print(f"Output: {torch.min(fea_out).item(),torch.max(fea_out).item(),torch.mean(fea_out).item()}")
        for name, param in module.named_parameters():
            print(f"{name}: {torch.min(param).item(),torch.max(param).item(),torch.mean(param).item()}")
406
+
407
+
408
def islarge(x,max_v=65535):
    """Return whether any element of *x* exceeds *max_v* in absolute value.

    *x* may be None (-> False), a tensor, or an arbitrarily nested
    list/tuple of such values, which is checked recursively.
    """
    if x is None:
        return False
    if isinstance(x,(tuple,list)):
        return any(islarge(item, max_v=max_v) for item in x)
    return torch.any(torch.abs(x)>max_v)
417
+
418
def isfinite(x):
    """Return whether every element of *x* is finite (no NaN/Inf).

    *x* may be None (-> True), a tensor, or an arbitrarily nested
    list/tuple of such values, which is checked recursively.
    """
    if x is None:
        return True
    if isinstance(x,(tuple,list)):
        return all(isfinite(item) for item in x)
    return torch.all(torch.isfinite(x))
427
+
428
def register_forward_hook(net,hook):
    # Recursively register *hook* as a forward hook on every LEAF module of
    # *net* (modules with no children), e.g.
    #   register_forward_hook(net, isfinite_hook)
    nr = 0
    for module in net.children():
        register_forward_hook(module,hook)
        nr += 1
    if nr == 0:
        # no children -> leaf module: hook it
        net.register_forward_hook(hook=hook)
435
+
436
def register_backward_hook(net,hook):
    # Recursively register *hook* as a (legacy) backward hook.
    # NOTE(review): because of the `if True:` below, the hook is registered on
    # EVERY module in the tree, not only on leaves as in
    # register_forward_hook; the leaf-only variant (and the
    # register_full_backward_hook call) are left commented out -- presumably
    # intentional, confirm before changing.
    nr = 0
    for module in net.children():
        register_backward_hook(module,hook)
        nr += 1
    if True:
    #if nr == 0:
        #net.register_full_backward_hook(hook=hook)
        net.register_backward_hook(hook=hook)
445
+
446
def tensor_fix_grad(grad):
    '''
    Gradient hook that sanitizes a tensor's gradient:
      - non-finite gradients are replaced by zeros,
      - gradients exceeding +-16000 are clamped into that range.

    Usage: tensor.register_hook(tensor_fix_grad)
    '''
    max_v = 16000.0  # clamp threshold; keeps fp16-scaled grads well in range
    if not torch.all(torch.isfinite(grad)):
        #print(f"infinite grad:",grad.shape,grad)
        #raise RuntimeError(f"infinite grad")
        return torch.zeros_like(grad)
    elif islarge(grad,max_v):
        #print(f"large grad:",grad.shape,torch.min(grad),torch.max(grad))
        return torch.clamp(grad,min=-max_v,max=max_v)
    return grad
459
+
460
+
461
def tensor_isfinite_hook(grad):
    '''
    Gradient hook that reports (but does not modify) NaN/Inf gradients.

    Usage: tensor.register_hook(tensor_isfinite_hook)
    '''
    if not torch.all(torch.isfinite(grad)):
        print(f"Find NaN or infininite grad, {grad.shape}")
        #print(f"{inspect.stack()}")
        traceback.print_exc(file=sys.stdout)
        print(f"grad: {torch.min(grad).item(),torch.max(grad).item(),torch.mean(grad).item()}")
        #print("value:",grad)
471
+
472
def tensor_islarge_hook(grad,max_v=60000):
    '''
    Gradient hook that reports (but does not modify) gradients whose
    magnitude exceeds *max_v*.

    Usage: tensor.register_hook(tensor_islarge_hook)
    '''
    if islarge(grad,max_v=max_v):
        print("Find Large value grad")
        #print(f"{inspect.stack()}")
        traceback.print_exc(file=sys.stdout)
        print(f"Output: {torch.min(grad).item(),torch.max(grad).item(),torch.mean(grad).item()}")
481
+
482
def register_tensor_hook(model,hook):
    '''
    Attach *hook* to the gradient of every trainable parameter of *model*.

    Example:
        register_tensor_hook(model, tensor_isfinite_hook)
    '''
    trainable = (p for p in model.parameters() if p.requires_grad)
    for p in trainable:
        p.register_hook(hook)
489
+
490
def is_any_grad_infinite(model):
    '''
    Return True if any trainable parameter of *model* currently holds a
    non-finite, or very large (|g| > 32768), gradient.  Offending parameter
    names are printed.  Intended to be called between backward() and step().
    '''
    res = False
    for name,param in model.named_parameters():
        if param.requires_grad and param.grad is not None and \
            (not torch.all(torch.isfinite(param.grad)) or islarge(param.grad,max_v=32768.0)):
            print(f"ERROR: {name}: unnormal grad")
            res = True

    return res
502
+
503
def backward_grad_normal_hook(module,grad_input,grad_output):
    '''
    Module backward hook that dumps the module, gradient shapes and raw
    gradients when grad_output is non-finite or exceeds 32768 in magnitude.

    Usage: register_backward_hook(net, backward_grad_normal_hook)
    '''
    if not isfinite(grad_output) or islarge(grad_output,max_v=32768.0):
        print("Find NaN or infininite grad")
        #print(f"{inspect.stack()}")
        print(module,_get_tensor_or_tensors_shape(grad_input),_get_tensor_or_tensors_shape(grad_output),grad_input,grad_output)
        #traceback.print_exc(file=sys.stdout)
        #print(f"grad_output: {torch.min(grad_output).item(),torch.max(grad_output).item(),torch.mean(grad_output).item()}")
513
+
514
def finetune_model(model,names_not2train=None,names2train=None):
    # Convenience wrapper for fine-tuning: freeze the sub-modules named in
    # *names_not2train* and/or unfreeze those in *names2train*
    # (see finetune_model_nottrain / finetune_model_train).
    if names_not2train is not None:
        finetune_model_nottrain(model,names_not2train)
    if names2train is not None:
        finetune_model_train(model,names2train)
    return

    # NOTE(review): everything below is unreachable -- dead code after the
    # unconditional return above; it looks like an older inline
    # implementation of the same logic, kept verbatim.
    def is_name_of(name, names):
        for x in names:
            if name.startswith(x) or name.startswith("module."+x):
                return True
        return False

    for name, param in model.named_parameters():
        if is_name_of(name, names2train):
            continue
        param.requires_grad = False

    param_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            param_to_update.append(param)

    _nr = 0
    for name, ms in model.named_modules():
        if not isinstance(ms, nn.BatchNorm2d):
            continue
        if is_name_of(name, names2train):
            continue
        else:
            ms.eval()
            _nr += 1
546
+
547
def finetune_model_train(model,names2train=None):
    """Enable training for the parts of *model* selected by *names2train*.

    A parameter/module matches when its dotted name starts with one of the
    prefixes in *names2train*, with or without a leading "module."
    (DataParallel/DDP).  Matching parameters get requires_grad=True and
    matching BatchNorm2d layers are switched back to train mode.
    """

    def _matches(name):
        return any(
            name.startswith(prefix) or name.startswith("module." + prefix)
            for prefix in names2train
        )

    for pname, p in model.named_parameters():
        if _matches(pname):
            p.requires_grad = True

    bn_count = 0  # BatchNorm2d layers switched back to train mode
    for mname, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d) and _matches(mname):
            mod.train()
            bn_count += 1
566
+
567
def finetune_model_nottrain(model:torch.nn.Module,names_not2train):
    """Freeze the parts of *model* named by *names_not2train*.

    Each entry is matched against full dotted names both as a plain prefix
    (with an optional leading "module." for DataParallel/DDP) and as a
    regular expression.  Matching parameters get requires_grad=False and
    matching BatchNorm2d layers are put in eval() mode.

    names_not2train: str or list/tuple of str (prefixes or regexes).
    """
    if not isinstance(names_not2train,(list,tuple)):
        names_not2train = [names_not2train]

    compiled = [re.compile(expr) for expr in names_not2train]

    def _matches(name):
        for prefix in names_not2train:
            if name.startswith(prefix) or name.startswith("module." + prefix):
                return True
        return any(pat.match(name) is not None for pat in compiled)

    for pname, p in model.named_parameters():
        if _matches(pname):
            p.requires_grad = False

    # kept for parity with the original implementation; not used afterwards
    param_to_update = [p for _, p in model.named_parameters() if p.requires_grad]

    frozen_bn = 0  # BatchNorm2d layers switched to eval mode
    for mname, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d) and _matches(mname):
            mod.eval()
            frozen_bn += 1
    sys.stdout.flush()