python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (164)
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
wml/wtorch/utils.py ADDED
@@ -0,0 +1,719 @@
1
import torch
import numpy as np
from collections.abc import Iterable
import torch.nn.functional as F
import random
import sys
from functools import wraps
from collections.abc import Mapping, Sequence
import wml.wml_utils as wmlu
import wml.img_utils as wmli
import cv2
from wml.thirdparty.config import CfgNode
from wml.wstructures import WPolygonMasks, WBitmapMasks, WMCKeypoints
from wml.semantic.basic_toolkit import *
from itertools import repeat
import collections.abc
import math
import onnx
import pickle
import types

try:
    from mmcv.parallel import DataContainer as DC
    from mmcv.utils.config import ConfigDict
except Exception:
    # mmcv is optional; fall back to placeholders when it is missing or broken.
    DC = None
    ConfigDict = wmlu.AlwaysNullObj
28
+
29
def _ntuple(n):
    """Build a converter that turns its argument into an ``n``-tuple.

    Non-string iterables are converted with ``tuple(...)`` as-is (their length is
    not checked against ``n``); any other value, including a ``str``, is repeated
    ``n`` times.
    """

    def parse(x):
        if not isinstance(x, collections.abc.Iterable) or isinstance(x, str):
            return (x,) * n
        return tuple(x)

    return parse


# Ready-made converters for the common arities.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
42
+
43
def unnormalize(x: torch.Tensor, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]):
    """Undo channel-wise normalization: return ``x * std + mean``.

    The channel axis is chosen from the rank of ``x``: axis 0 for 3-D input
    (``[C,H,W]``), axis 1 for 4-D (``[B,C,H,W]``) and axis 2 for 5-D input.
    The channel count is read from ``x``; ``mean``/``std`` need one float32
    entry per channel and are moved to ``x``'s device.

    Raises:
        ValueError: if ``x`` is not 3-, 4- or 5-dimensional (previously an
            UnboundLocalError).
    """
    channel_axis = {3: 0, 4: 1, 5: 2}.get(x.dim())
    if channel_axis is None:
        raise ValueError(f"unnormalize expects a 3-D, 4-D or 5-D tensor, got {x.dim()}-D")
    shape = [1] * x.dim()
    shape[channel_axis] = x.shape[channel_axis]
    scale = np.reshape(np.array(std, dtype=np.float32), shape)
    offset = np.reshape(np.array(mean, dtype=np.float32), shape)

    offset = torch.from_numpy(offset).to(x.device)
    scale = torch.from_numpy(scale).to(x.device)
    return x * scale + offset
61
+
62
def normalize(x: torch.Tensor, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]):
    """Channel-wise normalize a tensor: return ``(x - mean) / std``.

    The channel axis is chosen from the rank of ``x``: axis 0 for 3-D input
    (``[C,H,W]``), axis 1 for 4-D (``[B,C,H,W]``) and axis 2 for 5-D input.
    The channel count is ``len(mean)``; ``mean``/``std`` are converted to
    float32 and moved to ``x``'s device.

    Raises:
        ValueError: if ``x`` is not 3-, 4- or 5-dimensional (previously an
            UnboundLocalError).
    """
    channel = len(mean)
    channel_axis = {3: 0, 4: 1, 5: 2}.get(x.dim())
    if channel_axis is None:
        raise ValueError(f"normalize expects a 3-D, 4-D or 5-D tensor, got {x.dim()}-D")
    shape = [1] * x.dim()
    shape[channel_axis] = channel
    scale = np.reshape(np.array(std, dtype=np.float32), shape)
    offset = np.reshape(np.array(mean, dtype=np.float32), shape)

    offset = torch.from_numpy(offset).to(x.device)
    scale = torch.from_numpy(scale).to(x.device)
    return (x - offset) / scale
78
+
79
def npnormalize(x: np.ndarray, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]):
    """Channel-wise normalize a numpy array: return ``(x - mean) / std`` as float32.

    The channel axis is chosen from the rank of ``x``: axis 0 for 3-D input,
    axis 1 for 4-D and axis 2 for 5-D. The channel count is ``len(mean)``, so any
    number of channels is supported; 3-channel calls behave as before.

    Raises:
        ValueError: if ``x`` is not 3-, 4- or 5-dimensional (previously an
            UnboundLocalError).
    """
    channel_axis = {3: 0, 4: 1, 5: 2}.get(x.ndim)
    if channel_axis is None:
        raise ValueError(f"npnormalize expects a 3-D, 4-D or 5-D array, got {x.ndim}-D")
    shape = [1] * x.ndim
    shape[channel_axis] = len(mean)
    scale = np.reshape(np.array(std, dtype=np.float32), shape)
    offset = np.reshape(np.array(mean, dtype=np.float32), shape)
    return (x.astype(np.float32) - offset) / scale
93
+
94
def rgb2gray(img):
    '''
    Convert RGB tensors to one luma channel with weights (0.299, 0.587, 0.114).

    img: [3,H,W] or [B,3,H,W] tensor in (R,G,B) order; any rank other than 3 is
         treated as batched with channels on dim 1.
    Returns a [1,H,W] / [B,1,H,W] tensor.
    '''
    luma = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    channel_dim = 0 if len(img.shape) == 3 else 1
    broadcast_shape = [3, 1, 1] if channel_dim == 0 else [1, 3, 1, 1]
    weights = img.new_tensor(luma.reshape(broadcast_shape))
    return (img * weights).sum(dim=channel_dim, keepdim=True)
110
+
111
def remove_prefix_from_state_dict(state_dict, prefix="module."):
    """Return a new dict with one leading ``prefix`` stripped from each key.

    Keys without the prefix are kept unchanged; values are not copied. If two keys
    collapse to the same name, the later one wins.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
118
+
119
def forgiving_state_restore(net, loaded_dict, verbose=False):
    """
    Load every checkpoint tensor whose key and shape match ``net``; skip the rest.

    Handles partial loading when some tensors don't match up in size, e.g. when
    reusing a checkpoint trained with a different number of classes.

    For each key ``k`` of the model's state dict, these checkpoint keys are tried
    in order, and the first one present with a matching shape is used:
      1. ``k``
      2. ``"module." + k`` (only if ``k`` does not already start with it)
      3. ``k`` with ``"BN"`` replaced by ``"bn"`` (only if ``"BN"`` occurs in ``k``)
    Model keys without a usable checkpoint entry keep their current values and are
    reported (a ``bn`` key with the wrong shape is now reported too instead of being
    silently dropped).

    Args:
        net: module to load into; an object with a ``.module`` attribute is unwrapped.
        loaded_dict: a state dict, or a dict holding one under ``"state_dict"``.
        verbose: also print every key that was loaded.

    Returns:
        ``(net, loaded_keys, unloaded_keys)``: the (unwrapped) module, the model keys
        that received checkpoint values, and the model keys that did not (keys
        containing ``.num_batches_tracked`` are never reported as unloaded).
    """
    ignore_key = ['num_batches_tracked']

    def _is_ignore_key(k):
        return any(v in k for v in ignore_key)

    def _candidate_keys(k):
        # Checkpoint keys that may hold the value of model key ``k``, by priority.
        yield k
        if not k.startswith('module.'):
            yield 'module.' + k
        if 'BN' in k:
            yield k.replace("BN", "bn")

    if 'state_dict' in loaded_dict:
        loaded_dict = loaded_dict['state_dict']
    if hasattr(net, 'module'):
        net = net.module
    net_state_dict = net.state_dict()

    new_loaded_dict = {}
    used_loaded_dict_key = set()
    unloaded_net_state_key = []
    for k, target in net_state_dict.items():
        src_key = next(
            (c for c in _candidate_keys(k)
             if c in loaded_dict and target.size() == loaded_dict[c].size()),
            None,
        )
        if src_key is not None:
            new_loaded_dict[k] = loaded_dict[src_key]
            used_loaded_dict_key.add(src_key)
        elif ".num_batches_tracked" not in k:
            print(f"Skipped loading parameter {k} {target.shape}")
            unloaded_net_state_key.append(k)

    print("---------------------------------------------------")
    for k in loaded_dict:
        if k in new_loaded_dict or k in used_loaded_dict_key or _is_ignore_key(k):
            continue
        if k in net_state_dict:
            print(f"Skip {k} in loaded dict, shape={loaded_dict[k].shape} vs {net_state_dict[k].shape} in model")
        else:
            print(f"Skip {k} in loaded dict, shape={loaded_dict[k].shape}")
    if verbose:
        print("---------------------------------------------------")
        for k, v in new_loaded_dict.items():
            print(f"Load {k}, shape={v.shape}")
    net_state_dict.update(new_loaded_dict)
    net.load_state_dict(net_state_dict)
    sys.stdout.flush()
    print("Load checkpoint finish.")
    return net, list(new_loaded_dict.keys()), unloaded_net_state_key
172
+
173
def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    """Build a padding mask from sequence lengths.

    Args:
        lengths: tensor or array-like of lengths; a 1-D input of size N yields an
            ``[N, maxlen]`` mask (other ranks are broadcast against ``arange``).
        maxlen: mask width; defaults to ``max(lengths)``.
        dtype: dtype of the returned mask (previously ignored; default bool).

    Returns:
        ``mask[i, j] = j < lengths[i]``, on the same device as ``lengths``.
    """
    if not isinstance(lengths, torch.Tensor):
        lengths = torch.from_numpy(np.array(lengths))
    if maxlen is None:
        maxlen = int(lengths.max())
    if lengths.dim() == 1:
        lengths = lengths.unsqueeze(-1)
    # Create the positions on the same device so CUDA lengths work.
    positions = torch.arange(maxlen, dtype=lengths.dtype, device=lengths.device)[None, :]
    mask = positions < lengths
    return mask.to(dtype)
183
+
184
+
185
class TraceAmpWrape(torch.nn.Module):
    """Wrap ``model`` so its forward pass runs without autograd under CUDA autocast."""

    def __init__(self, model) -> None:
        super().__init__()
        self.model = model

    def forward(self, x):
        # Both contexts are entered in the same order as nested ``with`` blocks.
        with torch.no_grad(), torch.cuda.amp.autocast():
            return self.model(x)
194
+
195
def get_tensor_info(tensor):
    """Return ``(mean, min, max, std)`` of ``tensor`` as Python floats.

    The tensor is detached, copied to CPU and cast to float32 before the
    statistics are computed.
    """
    values = tensor.detach().cpu().to(torch.float32)
    stats = (values.mean(), values.min(), values.max(), values.std())
    return tuple(s.item() for s in stats)
198
+
199
def merge_imgs_heatmap(imgs, heat_map, scale=1.0, alpha=0.4, channel=None, min=None, max=None):
    """Alpha-blend a normalized heat map onto images.

    ``heat_map`` is clipped to ``[min, max]`` (defaults: its own min/max), rescaled
    to ``[0, scale]`` and blended as ``imgs * (1 - alpha) + heat_map * alpha``.
    When ``channel`` is given and ``heat_map`` has size 1 along that dim, it is
    expanded to three channels ``[heat, 0, 0]`` before blending.

    Args:
        imgs, heat_map: tensors or numpy arrays; they must broadcast together.
        min, max: optional clipping bounds. These names shadow the builtins inside
            this function and are kept for backward compatibility.

    Returns:
        The blended tensor.
    """
    if not isinstance(heat_map, torch.Tensor):
        heat_map = torch.from_numpy(heat_map)
    if not isinstance(imgs, torch.Tensor):
        imgs = torch.from_numpy(imgs)

    def _bound(value):
        # float32 as before (torch.Tensor([v])), but on heat_map's device so CUDA
        # heat maps no longer hit a device mismatch.
        return torch.as_tensor(value, dtype=torch.float32, device=heat_map.device).reshape(1)

    if min is None:
        min = torch.min(heat_map)
    else:
        heat_map = torch.maximum(heat_map, _bound(min))
    if max is None:
        max = torch.max(heat_map)
    else:
        heat_map = torch.minimum(heat_map, _bound(max))

    heat_map = (heat_map - min) * scale / (max - min + 1e-8)
    if channel is not None and heat_map.shape[channel] == 1:
        zeros = torch.zeros_like(heat_map)
        heat_map = torch.cat([heat_map, zeros, zeros], dim=channel)
    return imgs * (1 - alpha) + heat_map * alpha
222
+
223
def module_parameters_numel(net, only_training=False):
    """Count the scalar elements in ``net``'s parameters.

    With ``only_training=True`` only parameters that require grad are counted.
    """
    return sum(
        p.numel()
        for p in net.parameters()
        if p.requires_grad or not only_training
    )
229
+
230
+
231
def concat_datas(datas, dim=0):
    """Concatenate a non-empty list of same-structured samples along ``dim``.

    Dispatch is on ``datas[0]``:
      * mappings: values are grouped by key (keys from ``datas[0]``) and recursed into;
      * tensors: ``torch.cat``;
      * mmcv ``DataContainer`` (only when mmcv is importable): ``concat_dc_datas``;
      * other iterables: zipped position-wise, each position concatenated along
        ``dim`` (nested levels now keep ``dim`` instead of falling back to 0);
      * anything else: ``torch.cat``.
    On failure in the iterable branch, the inputs are printed and the error re-raised.
    """
    first = datas[0]
    if isinstance(first, Mapping):
        grouped = {k: [v] for k, v in first.items()}
        for data in datas[1:]:
            for k, v in data.items():
                grouped[k].append(v)
        return {k: concat_datas(vs, dim=dim) for k, vs in grouped.items()}

    if torch.is_tensor(first):
        return torch.cat(datas, dim=dim)
    # DC is None without mmcv; isinstance(x, None) would raise TypeError.
    if DC is not None and isinstance(first, DC):
        return concat_dc_datas(datas, dim)
    if isinstance(first, Iterable):
        res = []
        try:
            for x in zip(*datas):
                if torch.is_tensor(x[0]):
                    res.append(torch.cat(x, dim=dim))
                else:
                    res.append(concat_datas(x, dim=dim))
        except Exception as e:
            print(e)
            for i, x in enumerate(datas):
                print(i, type(x), x)
            print("--------------------------")
            for i, x in enumerate(datas):
                print(i, type(x))
            sys.stdout.flush()
            raise
        return res
    return torch.cat(datas, dim=dim)
268
+
269
def concat_dc_datas(datas, cat_dim=0):
    """Merge a list of mmcv ``DataContainer`` objects into one DataContainer.

    The settings of ``datas[0]`` decide the behaviour:
      * ``cpu_only``: every sample in each ``.data`` is flattened into one list;
      * ``stack``: all tensors in the ``.data`` lists are padded at the end of their
        last ``pad_dims`` dims (with ``padding_value``) to the largest size in the
        batch, then concatenated along ``cat_dim``; with ``pad_dims=None`` they are
        concatenated without padding (this path previously crashed on ``range(None)``);
      * otherwise: samples are flattened into one list as for ``cpu_only``.

    Raises:
        RuntimeError: if ``datas[0]`` is not a DataContainer or mmcv is unavailable.
    """
    if DC is None or not isinstance(datas[0], DC):
        raise RuntimeError(f"ERROR concat dc type {type(datas[0])}")
    first = datas[0]

    if first.cpu_only:
        stacked = []
        for dc in datas:
            for sample in dc.data:
                stacked.extend(sample)
        return DC(
            [stacked], first.stack, first.padding_value, cpu_only=True)

    if first.stack:
        batch = []
        for dc in datas:
            batch.extend(dc.data)
        pad_dims = first.pad_dims
        padding_value = first.padding_value
        if pad_dims is None:
            for sample in batch:
                assert isinstance(sample, torch.Tensor)
            stacked = list(batch)
        else:
            # Largest size of each padded trailing dim across the whole batch.
            max_shape = [0] * pad_dims
            for sample in batch:
                for dim in range(1, pad_dims + 1):
                    max_shape[dim - 1] = max(max_shape[dim - 1], sample.size(-dim))
            stacked = []
            for sample in batch:
                assert isinstance(sample, torch.Tensor)
                # F.pad takes (before, after) pairs starting from the last dim; pad "after" only.
                pad = [0] * (pad_dims * 2)
                for dim in range(1, pad_dims + 1):
                    pad[2 * dim - 1] = max_shape[dim - 1] - sample.size(-dim)
                stacked.append(F.pad(sample, pad, value=padding_value))
        stacked = torch.cat(stacked, dim=cat_dim)
        return DC([stacked], first.stack, first.padding_value)

    stacked = []
    for dc in datas:
        for sample in dc.data:
            stacked.extend(sample)
    return DC([stacked], first.stack, first.padding_value)
314
+
315
+
316
def get_model(model):
    """Unwrap one level of a wrapper exposing ``.module`` (e.g. DataParallel); else return ``model``."""
    return model.module if hasattr(model, "module") else model
320
+
321
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])

# Padding strategies accepted by pad_feature(pad_type=...).
CENTER_PAD = 0
RANDOM_PAD = 1
TOPLEFT_PAD = 2


def _pad_amounts(cur, target, pad_type):
    """Return ``(before, after)`` padding growing ``cur`` to ``target``; never shrinks."""
    if cur >= target:
        return 0, 0
    total = target - cur
    if pad_type == CENTER_PAD:
        before = total // 2
    elif pad_type == RANDOM_PAD:
        before = random.randint(0, total)
    else:  # TOPLEFT_PAD: all padding goes after the content
        before = 0
    return before, total - before


def pad_feature(fea, size, pad_value=0, pad_type=TOPLEFT_PAD, return_pad_value=False):
    '''
    Pad the last two dims of ``fea`` (e.g. [B,C,H,W]) up to ``size``.

    size: (w, h) target; a dim already at least that large is left untouched (no cropping).
    pad_value: constant fill value; if iterable, only its first element is used.
    pad_type: CENTER_PAD (0) centers the content, RANDOM_PAD (1) uses a random offset,
              TOPLEFT_PAD (2) keeps the content at the top-left.
    return_pad_value: also return (px0, px1, py0, py1) = left, right, top, bottom padding.
    Raises ValueError for an unknown pad_type (previously an UnboundLocalError).
    '''
    if pad_type not in (CENTER_PAD, RANDOM_PAD, TOPLEFT_PAD):
        raise ValueError(f"Unsupported pad_type {pad_type}")
    h, w = fea.shape[-2], fea.shape[-1]
    # Height first, then width: keeps the random draw order of RANDOM_PAD unchanged.
    py0, py1 = _pad_amounts(h, size[1], pad_type)
    px0, px1 = _pad_amounts(w, size[0], pad_type)

    if isinstance(pad_value, Iterable):
        pad_value = pad_value[0]
    fea = F.pad(fea, [px0, px1, py0, py1], "constant", pad_value)

    if return_pad_value:
        return fea, px0, px1, py0, py1
    return fea
385
+
386
def split_forward_batch32(func):
    """Decorator: run ``func(self, data)`` on chunks of at most 32 rows and join the results.

    Chunk results are concatenated along axis 0 with ``torch.cat`` when they are
    tensors and ``np.concatenate`` otherwise; a single chunk is returned as-is.
    Empty input (0 rows) is passed straight to ``func`` instead of crashing.
    """
    step = 32

    @wraps(func)
    def wrapper(self, data):
        total = data.shape[0]
        if total == 0:
            # No chunk would run; let func define what an empty result looks like.
            return func(self, data)
        res = [func(self, data[start:start + step]) for start in range(0, total, step)]
        if len(res) == 1:
            return res[0]
        if torch.is_tensor(res[0]):
            return torch.cat(res, dim=0)
        return np.concatenate(res, axis=0)

    return wrapper
403
+
404
def to(data, device=torch.device("cpu")):
    """Recursively move the tensors inside ``data`` to ``device``.

    * tensors: ``tensor.to(device)``;
    * CfgNode / ConfigDict: returned unchanged;
    * dict / list / tuple (including namedtuple): rebuilt with converted items;
    * non-iterable objects with a ``.to`` method (except ``nn.Module``): moved;
      other non-iterables are returned unchanged;
    * numpy arrays, str, bytes: returned unchanged;
    * any other iterable: a warning is printed and it is returned unchanged
      (this previously raised UnboundLocalError).
    """
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, (CfgNode, ConfigDict)):
        return data
    if isinstance(data, dict):
        return {k: to(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        items = [to(v, device) for v in data]
        if hasattr(data, "_fields"):
            # namedtuple constructors take fields positionally.
            return type(data)(*items)
        return type(data)(items)
    if not isinstance(data, Iterable):
        if not isinstance(data, torch.nn.Module) and hasattr(data, "to"):
            data = data.to(device)
        return data
    if isinstance(data, (np.ndarray, str, bytes)):
        return data
    print(f"Unsupport type {type(data)}")
    return data
429
+
430
def cpu(data):
    """Recursively move the tensors inside ``data`` to the CPU via ``to``."""
    target = torch.device("cpu")
    return to(data, device=target)
432
+
433
def cuda(data):
    """Recursively move the tensors inside ``data`` to the default CUDA device via ``to``."""
    target = torch.device("cuda")
    return to(data, device=target)
435
+
436
def cpu_wraps(func):
    """Decorator: call ``func`` with all arguments moved to CPU, then move its result to CUDA."""

    @wraps(func)
    def wraps_func(*args, **kwargs):
        # Positional arguments are converted before keyword arguments, as before.
        return cuda(func(*cpu(args), **cpu(kwargs)))

    return wraps_func
445
+
446
def cpu_cpu_wraps(func):
    """Decorator: call ``func`` with all arguments moved to CPU; the result is returned as-is."""

    @wraps(func)
    def wraps_func(*args, **kwargs):
        cpu_args = cpu(args)
        cpu_kwargs = cpu(kwargs)
        return func(*cpu_args, **cpu_kwargs)

    return wraps_func
454
+
455
def numpy(data):
    """Recursively convert the tensors inside ``data`` to numpy arrays.

    * tensors: detached and copied to CPU first, so tensors requiring grad work;
    * dict / list / tuple (including namedtuple): rebuilt with converted items;
    * non-iterables, numpy arrays, str and bytes: returned unchanged;
    * any other iterable (e.g. a set): a warning is printed and it is returned
      unchanged (this, and str/bytes, previously raised UnboundLocalError).
    """
    if torch.is_tensor(data):
        return data.detach().cpu().numpy()
    if isinstance(data, dict):
        return {k: numpy(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        items = [numpy(v) for v in data]
        if hasattr(data, "_fields"):
            # namedtuple constructors take fields positionally.
            return type(data)(*items)
        return type(data)(items)
    if not isinstance(data, Iterable) or isinstance(data, (np.ndarray, str, bytes)):
        return data
    print(f"Unsupport type {type(data)}")
    return data
476
+
477
def sparse_gather(data, index, return_tensor=True):
    '''
    Pick one element from each tensor of a list.

    data: list of tensors (possibly of different lengths)
    index: one index per tensor; result element i is data[i][index[i]]
    return_tensor: stack the picked elements with torch.stack (True) or return them as a list
    '''
    picked = [d[index[i]] for i, d in enumerate(data)]
    return torch.stack(picked, dim=0) if return_tensor else picked
488
+
489
def simple_model_device(model):
    """Return the device of ``model``'s first parameter (raises StopIteration if it has none)."""
    first_param = next(iter(model.parameters()))
    return first_param.device
491
+
492
def resize_mask(mask, size=None, r=None):
    '''
    Nearest-neighbour resize of a stack of masks.

    mask: [N,H,W] tensor
    size: (new_w,new_h); when None it is derived from the ratio r
    Returns an [N,new_h,new_w] tensor (zeros of that shape for an empty mask).
    '''
    if size is None:
        new_w, new_h = int(mask.shape[2] * r), int(mask.shape[1] * r)
    else:
        new_w, new_h = size[0], size[1]
    if mask.numel() == 0:
        return mask.new_zeros([mask.shape[0], new_h, new_w])

    # interpolate works on [B,C,H,W]; treat the N masks as channels of one image.
    resized = torch.nn.functional.interpolate(mask[None], size=(new_h, new_w), mode='nearest')
    return resized[0]
506
+
507
def npresize_mask(mask, size=None, r=None):
    '''
    Nearest-neighbour resize of a stack of numpy masks.

    mask: [N,H,W] array
    size: (new_w,new_h); when None it is derived from the ratio r. Previously
          only the N>1 path honoured r; N==0 and N==1 crashed on size[...].
    Returns an [N,new_h,new_w] array. N==0 and N==1 keep the input dtype;
    N>1 goes through the torch resize_mask.
    '''
    if size is None:
        size = (int(mask.shape[2] * r), int(mask.shape[1] * r))
    if mask.shape[0] == 0:
        return np.zeros([0, size[1], size[0]], dtype=mask.dtype)
    if mask.shape[0] == 1:
        cur_m = cv2.resize(mask[0], dsize=(size[0], size[1]), interpolation=cv2.INTER_NEAREST)
        return np.expand_dims(cur_m, axis=0)
    return resize_mask(torch.from_numpy(mask), size, r).numpy()
519
+
520
+
521
+
522
def __correct_bboxes(bboxes, h, w):
    """Clip ``[N,4]`` (x0,y0,x1,y1) boxes into ``[0,w] x [0,h]``, keeping the input dtype."""
    upper = np.array([[w, h, w, h]])
    # np.clip(a, lo, hi) == np.minimum(hi, np.maximum(a, lo))
    return np.clip(bboxes, 0, upper).astype(bboxes.dtype)
527
+
528
def npresize_mask_in_bboxes(mask, bboxes, size=None, r=None):
    '''
    Resize masks by resizing only the content inside each mask's box.

    mask: [N,H,W] array, or a WPolygonMasks / WBitmapMasks / WMCKeypoints object
          (delegated to its own resize_mask_in_bboxes)
    bboxes: [N,4] (x0,y0,x1,y1) boxes in mask coordinates (array-like is accepted)
    size: (new_w,new_h); when None it is derived from the ratio r (previously a crash)
    Returns (res_mask, resized_bboxes): [N,new_h,new_w] masks and the scaled boxes,
    truncated to int32 and clipped to the new size. A mask whose scaled box is at
    most 1 pixel wide or tall is left empty.
    '''
    if isinstance(mask, (WPolygonMasks, WBitmapMasks, WMCKeypoints)):
        return mask.resize_mask_in_bboxes(bboxes, size=size, r=r)
    if size is None:
        size = (int(mask.shape[2] * r), int(mask.shape[1] * r))
    bboxes = np.asarray(bboxes)
    if mask.shape[0] == 0:
        return np.zeros([0, size[1], size[0]], dtype=mask.dtype), np.zeros([0, 4], dtype=bboxes.dtype)

    x_scale = size[0] / mask.shape[2]
    y_scale = size[1] / mask.shape[1]
    bboxes = __correct_bboxes(bboxes, h=mask.shape[1], w=mask.shape[2])
    resized_bboxes = (bboxes * np.array([[x_scale, y_scale, x_scale, y_scale]])).astype(np.int32)
    resized_bboxes = __correct_bboxes(resized_bboxes, h=size[1], w=size[0])
    bboxes = bboxes.astype(np.int32)

    res_mask = np.zeros([mask.shape[0], size[1], size[0]], dtype=mask.dtype)
    for i in range(mask.shape[0]):
        dbbox = resized_bboxes[i]
        dsize = (dbbox[2] - dbbox[0], dbbox[3] - dbbox[1])
        if dsize[0] <= 1 or dsize[1] <= 1:
            continue
        sub_mask = wmli.crop_img_absolute_xy(mask[i], bboxes[i])
        cur_m = cv2.resize(sub_mask, dsize=dsize, interpolation=cv2.INTER_NEAREST)
        wmli.set_subimg(res_mask[i], cur_m, dbbox[:2])
    return res_mask, resized_bboxes
554
+
555
def __time_npresize_mask_in_bboxes(mask, bboxes, size=None, r=None):
    """Debug helper: time three mask-resize paths and print ``RM,<t0>,<t1>,<t2>``.

    t0 times npresize_mask, t1 npresize_mask_in_bboxes, and t2 the torch
    resize_mask path (it replaces a call to ``__npresize_mask``, which is not
    defined in this module).
    Returns the npresize_mask_in_bboxes result.
    """
    t = wmlu.TimeThis()
    npresize_mask(mask, size, r)
    t0 = t.time(reset=True)
    a = npresize_mask_in_bboxes(mask, bboxes, size, r)
    t1 = t.time(reset=True)
    resize_mask(torch.from_numpy(mask), size, r)
    t2 = t.time(reset=True)
    print(f"RM,{t0},{t1},{t2}")
    return a
565
+
566
def clone_tensors(x):
    """Clone a tensor, or each tensor of a list/tuple (always returned as a list)."""
    if not isinstance(x, (list, tuple)):
        return x.clone()
    return list(map(lambda t: t.clone(), x))
570
+
571
def _trunc_normal_(tensor, mean, std, a, b):
    # Inverse-CDF sampling of a truncated normal, adapted from PyTorch's
    # nn.init.trunc_normal_; method described in
    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def _std_normal_cdf(v):
        # CDF of the standard normal distribution.
        return (1. + math.erf(v / math.sqrt(2.))) / 2.

    too_far = mean < a - 2 * std or mean > b + 2 * std
    if too_far:
        print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
              "The distribution of values may be incorrect.")

    lower_cdf = _std_normal_cdf((a - mean) / std)
    upper_cdf = _std_normal_cdf((b - mean) / std)

    # Uniform samples in [2l-1, 2u-1]; erfinv maps them onto a truncated
    # standard normal.
    tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)
    tensor.erfinv_()

    # Rescale to the requested mean/std, then clamp away numerical overshoot.
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill ``tensor`` in place with samples from a truncated normal distribution.

    Values are drawn from :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted
    to :math:`[a, b]` (sampled by inverse CDF, then clamped). Works best when
    :math:`a \leq \text{mean} \leq b`.

    NOTE: like PyTorch's ``trunc_normal_``, ``a`` and ``b`` bound the final values
    (after mean/std are applied), so choose them relative to ``mean`` and ``std``.

    Runs under ``torch.no_grad()``, so it can initialise parameters in place.

    Args:
        tensor: an n-dimensional ``torch.Tensor``
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Returns:
        ``tensor`` itself.
    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)
629
+
630
+
631
def _embed_version(values, version, exponent):
    """Keep ``exponent`` decimal digits of ``values`` and store ``version`` in the next two.

    ``values * 10**exponent`` is truncated toward zero (via int32), ``version / 100``
    is added, and the result is scaled back down.

    Raises:
        ValueError: if ``version`` is outside ``[0, 100)``. This check used to be an
            ``assert``, which is skipped under ``python -O``.
    """
    if not 0 <= version < 100:
        raise ValueError("ERROR: version need in range [0,100)")
    scale = math.pow(10, exponent)
    truncated = (values * scale).to(torch.int32).to(torch.float32)
    return (truncated + version / 100) / scale


def embedding_version2scores(scores, version, exponent=2):
    """Embed ``version`` in the decimal digits after the first ``exponent`` digits of ``scores``."""
    return _embed_version(scores, version, exponent)


def embedding_version2coord(coord, version, exponent=0):
    """Embed ``version`` in the two decimal digits after ``coord``'s kept precision."""
    return _embed_version(coord, version, exponent)
646
+
647
+
648
def add_version2onnx(onnx_path, save_path, version):
    """Store ``version`` as the ``model_version`` metadata entry of an ONNX model file.

    ONNX ``metadata_props`` holds string key/value pairs, so the value is written as
    ``str(version)``. An existing ``model_version`` entry is overwritten rather than
    duplicated.

    Args:
        onnx_path: path of the model to read.
        save_path: output path; ``None`` overwrites ``onnx_path``.
        version: any value; stored via ``str``.
    """
    model_proto = onnx.load(onnx_path)
    for entry in model_proto.metadata_props:
        if entry.key == 'model_version':
            entry.value = str(version)
            break
    else:
        entry = model_proto.metadata_props.add()
        entry.key = 'model_version'
        entry.value = str(version)
    if save_path is None:
        save_path = onnx_path
    onnx.save(model_proto, save_path)
664
+
665
def add_metadta2onnx(model_onnx, metadata):
    """Write ``metadata`` (key -> value) into ``model_onnx.metadata_props`` in place.

    Values are stored as ``str(v)``. Keys that already exist are overwritten instead
    of appended again, because ONNX expects metadata keys to be distinct.

    Returns:
        The same ``model_onnx`` object.
    """
    existing = {entry.key: entry for entry in model_onnx.metadata_props}
    for k, v in metadata.items():
        entry = existing.get(k)
        if entry is None:
            entry = model_onnx.metadata_props.add()
            entry.key = k
            existing[k] = entry
        entry.value = str(v)
    return model_onnx
672
+
673
+
674
class SafeClass:
    """A placeholder class to replace unknown classes during unpickling."""

    def __init__(self, *args, **kwargs):
        """Initialize SafeClass instance, ignoring all arguments."""
        pass

    def __call__(self, *args, **kwargs):
        """Run SafeClass instance, ignoring all arguments."""
        pass


class SafeUnpickler(pickle.Unpickler):
    """Unpickler that replaces every global outside an allowlist with ``SafeClass``.

    * Whole modules in ``_SAFE_MODULES`` are allowed.
    * ``builtins`` is limited to plain data types. The previous whole-module
      allowance let a crafted pickle call ``eval``/``exec``.
    * A few individual helpers are allowed so that tensors, parameters and numpy
      arrays are rebuilt instead of becoming ``SafeClass``.

    NOTE(review): allowing entire modules such as ``torch`` and ``numpy`` is not a
    hard security boundary; prefer ``torch.load(..., weights_only=True)`` for
    untrusted files where the installed torch supports it.
    """

    _SAFE_MODULES = frozenset({"torch", "collections", "collections.abc", "math", "numpy"})
    _SAFE_BUILTINS = frozenset({
        "bool", "bytearray", "bytes", "complex", "dict", "float", "frozenset",
        "int", "list", "range", "set", "slice", "str", "tuple",
    })
    _SAFE_GLOBALS = frozenset({
        ("_codecs", "encode"),  # how protocol-2 pickles (torch.save) store bytes
        ("torch._utils", "_rebuild_tensor_v2"),
        ("torch._utils", "_rebuild_parameter"),
        ("torch._utils", "_rebuild_parameter_with_state"),
        ("numpy.core.multiarray", "_reconstruct"),
        ("numpy.core.multiarray", "scalar"),
        ("numpy._core.multiarray", "_reconstruct"),
        ("numpy._core.multiarray", "scalar"),
    })

    def find_class(self, module, name):
        """Return the real global when allowlisted, otherwise ``SafeClass``."""
        allowed = (
            module in self._SAFE_MODULES
            or (module == "builtins" and name in self._SAFE_BUILTINS)
            or (module, name) in self._SAFE_GLOBALS
        )
        if allowed:
            return super().find_class(module, name)
        return SafeClass


def safe_load(file, *args, **kwargs):
    """``torch.load`` through ``SafeUnpickler``: non-allowlisted globals become ``SafeClass``.

    ``file`` may be a path or a readable binary file object. Extra arguments
    (e.g. ``map_location``) are forwarded to ``torch.load``. ``weights_only`` is
    dropped because torch refuses it together with a custom ``pickle_module``.
    """
    kwargs.pop("weights_only", None)
    safe_pickle = types.ModuleType("safe_pickle")
    safe_pickle.Unpickler = SafeUnpickler
    # torch may call pickle_module.load(f, encoding=...) for legacy-format files.
    safe_pickle.load = lambda file_obj, **load_kwargs: SafeUnpickler(file_obj, **load_kwargs).load()
    if hasattr(file, "read"):
        return torch.load(file, *args, pickle_module=safe_pickle, **kwargs)
    with open(file, "rb") as f:
        return torch.load(f, *args, pickle_module=safe_pickle, **kwargs)


def load(file, *args, **kwargs):
    """``torch.load`` with a fallback to ``safe_load`` when the regular load fails.

    The same arguments (e.g. ``map_location``) are forwarded to the fallback; they
    were previously dropped. File objects are rewound to their starting position
    before the retry.
    """
    start = file.tell() if hasattr(file, "seek") and hasattr(file, "tell") else None
    try:
        return torch.load(file, *args, **kwargs)
    except Exception as e:
        print(f"WARNING: load ckpt {file} faild, info: {e}, try safe load...")
        if start is not None:
            file.seek(start)
        return safe_load(file, *args, **kwargs)