python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
wml/wtorch/utils.py
ADDED
|
@@ -0,0 +1,719 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from collections import Iterable
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
import random
|
|
6
|
+
import sys
|
|
7
|
+
from functools import wraps
|
|
8
|
+
from collections.abc import Mapping, Sequence
|
|
9
|
+
import wml.wml_utils as wmlu
|
|
10
|
+
import wml.img_utils as wmli
|
|
11
|
+
import cv2
|
|
12
|
+
from wml.thirdparty.config import CfgNode
|
|
13
|
+
from wml.wstructures import WPolygonMasks,WBitmapMasks, WMCKeypoints
|
|
14
|
+
from wml.semantic.basic_toolkit import *
|
|
15
|
+
from itertools import repeat
|
|
16
|
+
import collections.abc
|
|
17
|
+
import math
|
|
18
|
+
import onnx
|
|
19
|
+
import pickle
|
|
20
|
+
import types
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
from mmcv.parallel import DataContainer as DC
|
|
24
|
+
from mmcv.utils.config import ConfigDict
|
|
25
|
+
except:
|
|
26
|
+
DC = None
|
|
27
|
+
ConfigDict = wmlu.AlwaysNullObj
|
|
28
|
+
|
|
29
|
+
def _ntuple(n):
|
|
30
|
+
def parse(x):
|
|
31
|
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
|
32
|
+
return tuple(x)
|
|
33
|
+
return tuple(repeat(x, n))
|
|
34
|
+
return parse
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Fixed-arity converters (scalar -> tuple of length k); same convention as the
# helpers used by timm/torchvision for kernel/stride/padding arguments.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
|
|
42
|
+
|
|
43
|
+
def unnormalize(x: torch.Tensor, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]):
    """Invert a channel-wise normalization, i.e. return ``x * std + mean``.

    Args:
        x: image tensor of rank 3 ([C,H,W]), 4 ([B,C,H,W]) or 5 ([B,T,C,H,W]).
        mean: per-channel means used by the forward normalization.
        std: per-channel standard deviations used by the forward normalization.

    Returns:
        Tensor of the same shape/device with the normalization removed.

    Raises:
        ValueError: if ``x`` is not rank 3, 4 or 5 (the original fell through
            and crashed with an UnboundLocalError).
    """
    ndim = len(x.size())
    if ndim == 4:
        C = x.shape[1]
        shape = [1, C, 1, 1]
    elif ndim == 5:
        C = x.shape[2]
        shape = [1, 1, C, 1, 1]
    elif ndim == 3:
        C = x.shape[0]
        shape = [C, 1, 1]
    else:
        raise ValueError(f"unnormalize: unsupported tensor rank {ndim}, expected 3/4/5")

    scale = torch.from_numpy(np.reshape(np.array(std, dtype=np.float32), shape)).to(x.device)
    offset = torch.from_numpy(np.reshape(np.array(mean, dtype=np.float32), shape)).to(x.device)
    return x * scale + offset
|
|
61
|
+
|
|
62
|
+
def normalize(x: torch.Tensor, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]):
    """Channel-wise normalize a tensor: return ``(x - mean) / std``.

    Args:
        x: image tensor of rank 3 ([C,H,W]), 4 ([B,C,H,W]) or 5 ([B,T,C,H,W]).
        mean: per-channel means; its length determines the channel count.
        std: per-channel standard deviations.

    Returns:
        Normalized tensor with the same shape/device as ``x``.

    Raises:
        ValueError: if ``x`` is not rank 3, 4 or 5 (the original fell through
            and crashed with an UnboundLocalError).
    """
    channel = len(mean)
    ndim = len(x.size())
    if ndim == 4:
        shape = [1, channel, 1, 1]
    elif ndim == 5:
        shape = [1, 1, channel, 1, 1]
    elif ndim == 3:
        shape = [channel, 1, 1]
    else:
        raise ValueError(f"normalize: unsupported tensor rank {ndim}, expected 3/4/5")

    scale = torch.from_numpy(np.reshape(np.array(std, dtype=np.float32), shape)).to(x.device)
    offset = torch.from_numpy(np.reshape(np.array(mean, dtype=np.float32), shape)).to(x.device)
    return (x - offset) / scale
|
|
78
|
+
|
|
79
|
+
def npnormalize(x: np.ndarray, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]):
    """Channel-wise normalize a numpy array: return ``(x - mean) / std`` as float32.

    Generalized: the channel count is taken from ``len(mean)`` instead of being
    hard-coded to 3, so single-channel (and other) inputs now work; the default
    remains 3-channel and unchanged for existing callers.

    Args:
        x: array of rank 3 ([C,H,W]), 4 ([B,C,H,W]) or 5 ([B,T,C,H,W]).
        mean: per-channel means; its length determines the channel count.
        std: per-channel standard deviations.

    Raises:
        ValueError: if ``x`` is not rank 3, 4 or 5 (the original fell through
            and crashed with an UnboundLocalError).
    """
    channel = len(mean)
    if x.ndim == 4:
        shape = [1, channel, 1, 1]
    elif x.ndim == 5:
        shape = [1, 1, channel, 1, 1]
    elif x.ndim == 3:
        shape = [channel, 1, 1]
    else:
        raise ValueError(f"npnormalize: unsupported array rank {x.ndim}, expected 3/4/5")

    scale = np.reshape(np.array(std, dtype=np.float32), shape)
    offset = np.reshape(np.array(mean, dtype=np.float32), shape)
    return (x.astype(np.float32) - offset) / scale
|
|
93
|
+
|
|
94
|
+
def rgb2gray(img):
    '''
    Convert an RGB tensor to grayscale with BT.601 luma weights.
    img: [B,3,H,W]/[3,H,W] (R,G,B) order
    Returns a tensor whose channel dimension is reduced to size 1.
    '''
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    if len(img.shape) == 3:
        w = img.new_tensor(weights.reshape(3, 1, 1))
        return torch.sum(img * w, dim=0, keepdim=True)
    w = img.new_tensor(weights.reshape(1, 3, 1, 1))
    return torch.sum(img * w, dim=1, keepdim=True)
|
|
110
|
+
|
|
111
|
+
def remove_prefix_from_state_dict(state_dict, prefix="module."):
    """Return a new dict with ``prefix`` stripped from every key that has it.

    Keys without the prefix are copied through unchanged; values are untouched.
    """
    plen = len(prefix)
    return {
        (key[plen:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
|
|
118
|
+
|
|
119
|
+
def forgiving_state_restore(net, loaded_dict,verbose=False):
    """
    Handle partial loading when some tensors don't match up in size.
    Because we want to use models that were trained off a different
    number of classes.

    Returns:
        (net, loaded_keys, unloaded_net_state_key): the (unwrapped) model,
        the model keys that received weights, and model keys that matched
        nothing in the checkpoint.
    """
    # Checkpoint keys containing any of these substrings are not reported
    # when unmatched (BN bookkeeping buffers are noise).
    ignore_key = ['num_batches_tracked']
    def _is_ignore_key(k):
        for v in ignore_key:
            if v in k:
                return True
        return False

    # Some trainers wrap the weights under a 'state_dict' entry.
    if 'state_dict' in loaded_dict:
        loaded_dict = loaded_dict['state_dict']
    # Unwrap DataParallel/DistributedDataParallel before loading.
    if hasattr(net,'module'):
        net = net.module
    net_state_dict = net.state_dict()
    new_loaded_dict = {}
    used_loaded_dict_key = []
    unloaded_net_state_key = []
    # Matching cascade, per model key:
    #   1) exact name + exact shape;
    #   2) checkpoint saved with a 'module.' prefix;
    #   3) 'BN' in the model renamed to 'bn' in the checkpoint;
    #   otherwise the parameter is skipped (and reported, except BN counters).
    for k in net_state_dict:
        new_k = k
        if new_k in loaded_dict and net_state_dict[k].size() == loaded_dict[new_k].size():
            new_loaded_dict[k] = loaded_dict[new_k]
        elif (not k.startswith('module.')) and 'module.'+k in loaded_dict and net_state_dict[k].size() == loaded_dict['module.'+new_k].size():
            new_loaded_dict[k] = loaded_dict['module.'+new_k]
            used_loaded_dict_key.append('module.'+new_k)
        elif 'BN' in k and new_k.replace("BN","bn") in loaded_dict:
            new_k = new_k.replace("BN","bn")
            if net_state_dict[k].size() == loaded_dict[new_k].size():
                new_loaded_dict[k] = loaded_dict[new_k]
                used_loaded_dict_key.append(new_k)
            # NOTE(review): a 'bn' key with a mismatched shape is silently
            # dropped without being recorded in unloaded_net_state_key.
        elif ".num_batches_tracked" not in k:
            print(f"Skipped loading parameter {k} {net_state_dict[k].shape}")
            unloaded_net_state_key.append(k)

    print(f"---------------------------------------------------")
    # Report checkpoint entries that matched nothing in the model.
    for k in loaded_dict:
        if k not in new_loaded_dict and k not in used_loaded_dict_key and not _is_ignore_key(k):
            if k in net_state_dict:
                print(f"Skip {k} in loaded dict, shape={loaded_dict[k].shape} vs {net_state_dict[k].shape} in model")
            else:
                print(f"Skip {k} in loaded dict, shape={loaded_dict[k].shape}")
    if verbose:
        print(f"---------------------------------------------------")
        for k in new_loaded_dict:
            print(f"Load {k}, shape={new_loaded_dict[k].shape}")
    # Merge matched weights into the model's current state and load strictly.
    net_state_dict.update(new_loaded_dict)
    net.load_state_dict(net_state_dict)
    sys.stdout.flush()
    print(f"Load checkpoint finish.")
    return net,list(new_loaded_dict.keys()),unloaded_net_state_key
|
|
172
|
+
|
|
173
|
+
def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    """Build a boolean mask where ``mask[i, j] = j < lengths[i]``.

    Mirrors TensorFlow's ``tf.sequence_mask``.

    Args:
        lengths: 1-D (or [N,1]) tensor or array-like of sequence lengths.
        maxlen: mask width; defaults to ``lengths.max()``.
        dtype: kept for signature compatibility; the result of ``<`` is always
            a bool tensor, this parameter is currently unused.

    Returns:
        Bool tensor of shape [N, maxlen].
    """
    if not isinstance(lengths, torch.Tensor):
        lengths = torch.from_numpy(np.array(lengths))
    if maxlen is None:
        maxlen = lengths.max()
    if len(lengths.shape) == 1:
        lengths = torch.unsqueeze(lengths, axis=-1)
    # Fix: allocate the comparison range on lengths' device; previously
    # torch.arange defaulted to CPU and failed against CUDA lengths.
    matrix = torch.arange(maxlen, dtype=lengths.dtype, device=lengths.device)[None, :]
    return matrix < lengths
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class TraceAmpWrape(torch.nn.Module):
    """Wrap a model so its forward runs under no_grad + CUDA autocast.

    Intended for tracing/exporting mixed-precision inference graphs.
    """

    def __init__(self, model) -> None:
        super().__init__()
        self.model = model

    def forward(self, x):
        # Single `with` statement combining both context managers; identical
        # nesting order to the original (no_grad outermost).
        with torch.no_grad(), torch.cuda.amp.autocast():
            return self.model(x)
|
|
194
|
+
|
|
195
|
+
def get_tensor_info(tensor):
    """Return (mean, min, max, std) of `tensor` as Python floats.

    The tensor is detached and moved to CPU/float32 first, so any dtype and
    device is accepted.
    """
    t = tensor.detach().cpu().to(torch.float32)
    return t.mean().item(), t.min().item(), t.max().item(), t.std().item()
|
|
198
|
+
|
|
199
|
+
def merge_imgs_heatmap(imgs, heat_map, scale=1.0, alpha=0.4, channel=None, min=None, max=None):
    """Alpha-blend a (normalized) heatmap over images.

    Args:
        imgs: image tensor/ndarray.
        heat_map: heatmap tensor/ndarray broadcastable against ``imgs``.
        scale: upper bound the heatmap is rescaled to after normalization.
        alpha: blend weight of the heatmap (``imgs*(1-alpha)+heat*alpha``).
        channel: if given and the heatmap has a single channel on that axis,
            it is expanded to 3 channels as (heat, 0, 0) (red overlay).
        min/max: optional clipping range; each defaults to the heatmap's own
            extreme when None.

    Returns:
        The blended tensor.
    """
    if not isinstance(heat_map, torch.Tensor):
        heat_map = torch.from_numpy(heat_map)
    if not isinstance(imgs, torch.Tensor):
        imgs = torch.from_numpy(imgs)
    if min is None:
        min = torch.min(heat_map)
    else:
        heat_map = torch.maximum(heat_map, torch.Tensor([min]))
    if max is None:
        max = torch.max(heat_map)
    else:
        heat_map = torch.minimum(heat_map, torch.Tensor([max]))
    # Normalize the heatmap to [0, scale]; epsilon guards a constant heatmap.
    heat_map = (heat_map - min) * scale / (max - min + 1e-8)
    if channel is not None and heat_map.shape[channel] == 1:
        t_zeros = torch.zeros_like(heat_map)
        heat_map = torch.cat([heat_map, t_zeros, t_zeros], dim=channel)
    # NOTE: an earlier revision restricted the blend to pixels where the
    # heatmap was significant (see history); currently the whole image is
    # blended. The dead mask computation has been removed.
    return imgs * (1 - alpha) + heat_map * alpha
|
|
222
|
+
|
|
223
|
+
def module_parameters_numel(net, only_training=False):
    """Total number of elements in `net`'s parameters.

    When ``only_training`` is True, only parameters with ``requires_grad``
    set are counted.
    """
    params = net.parameters()
    if only_training:
        params = (p for p in params if p.requires_grad)
    return sum(torch.numel(p) for p in params)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def concat_datas(datas, dim=0):
    """Recursively concatenate a sequence of batch samples along ``dim``.

    Supported element types: Mapping (merged key-wise), torch.Tensor
    (torch.cat), mmcv DataContainer (when mmcv is available), and generic
    iterables (zipped element-wise).

    Returns:
        A structure mirroring ``datas[0]`` with tensors concatenated.
    """
    if isinstance(datas[0], Mapping):
        # Merge dicts key-wise, then recurse on each key's value list.
        new_data = {}
        for k, v in datas[0].items():
            new_data[k] = [v]
        for data in datas[1:]:
            for k, v in data.items():
                new_data[k].append(v)
        for k in list(new_data.keys()):
            new_data[k] = concat_datas(new_data[k], dim=dim)
        return new_data

    if torch.is_tensor(datas[0]):
        return torch.cat(datas, dim=dim)
    # Fix: guard against DC being None (mmcv not installed); isinstance with
    # None as the second argument raises TypeError.
    elif DC is not None and isinstance(datas[0], DC):
        return concat_dc_datas(datas, dim)
    elif isinstance(datas[0], Iterable):
        res = []
        try:
            for x in zip(*datas):
                if torch.is_tensor(x[0]):
                    res.append(torch.cat(x, dim=dim))
                else:
                    # NOTE: `dim` is intentionally not forwarded here, matching
                    # the historical behavior (nested levels concat on dim 0).
                    res.append(concat_datas(x))
        except Exception as e:
            # Dump the offending inputs before re-raising to ease debugging.
            print(e)
            for i, x in enumerate(datas):
                print(i, type(x), x)
            print(f"--------------------------")
            for i, x in enumerate(datas):
                print(i, type(x))
            sys.stdout.flush()
            raise e
        return res
    else:
        return torch.cat(datas, dim=dim)
|
|
268
|
+
|
|
269
|
+
def concat_dc_datas(datas,cat_dim=0):
    """Merge a list of mmcv ``DataContainer`` objects into one.

    Mirrors mmcv's collate logic, but concatenates along ``cat_dim`` instead
    of stacking a new batch dimension. Raises RuntimeError for non-DC input.
    """
    if isinstance(datas[0], DC):
        stacked = []
        if datas[0].cpu_only:
            # cpu-only containers (e.g. img_metas): flatten every sample of
            # every container into one list.
            for i in range(0, len(datas)):
                for sample in datas[i].data:
                    stacked.extend(sample)
            return DC(
                [stacked], datas[0].stack, datas[0].padding_value, cpu_only=True)
        elif datas[0].stack:
            # Stackable tensors: pad the trailing `pad_dims` dimensions of all
            # samples to a common max shape, then concatenate.
            batch = []
            for d in datas:
                batch.extend(d.data)
            pad_dims = datas[0].pad_dims
            padding_value =datas[0].padding_value
            max_shape = [0 for _ in range(pad_dims)]
            # First pass: per trailing dimension, find the max extent.
            for sample in batch:
                for dim in range(1, pad_dims + 1):
                    max_shape[dim - 1] = max(max_shape[dim - 1],
                                             sample.size(-dim))

            for i in range(0, len(batch)):
                assert isinstance(batch[i], torch.Tensor)

                if pad_dims is not None:
                    # F.pad expects [left, right] pairs from the last dim
                    # backwards; only the trailing side is padded here.
                    pad = [0 for _ in range(pad_dims * 2)]
                    sample = batch[i]
                    for dim in range(1, pad_dims + 1):
                        pad[2 * dim - 1] = max_shape[dim - 1] - sample.size(-dim)
                    stacked.append(
                        F.pad(sample, pad, value=padding_value))
                elif pad_dims is None:
                    # NOTE(review): this appends the whole `batch` list, not
                    # the current sample, and torch.cat below would then fail;
                    # also unreachable anyway since `range(pad_dims)` above
                    # already requires an int. Looks like dead/buggy code —
                    # confirm before relying on pad_dims=None.
                    stacked.append(batch)
                else:
                    raise ValueError(
                        'pad_dims should be either None or integers (1-3)')
            stacked = torch.cat(stacked,dim=cat_dim)
            return DC([stacked], datas[0].stack, datas[0].padding_value)
        else:
            # Neither cpu_only nor stack: flatten samples like the cpu path.
            for i in range(0, len(datas)):
                for sample in datas[i].data:
                    stacked.extend(sample)
            return DC([stacked], datas[0].stack, datas[0].padding_value)
    else:
        raise RuntimeError(f"ERROR concat dc type {type(datas[0])}")
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def get_model(model):
    """Unwrap (Distributed)DataParallel-style wrappers.

    Returns ``model.module`` when the attribute exists, otherwise ``model``.
    """
    return model.module if hasattr(model, "module") else model
|
|
320
|
+
|
|
321
|
+
# (major, minor) of the running torch build, e.g. (1, 13).
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])

'''
fea:[B,C,H,W]
size:(w,h)
'''
# Padding placement modes accepted by pad_feature() below.
CENTER_PAD = 0
RANDOM_PAD = 1
TOPLEFT_PAD = 2
|
|
330
|
+
def _pad_split(deficit, pad_type):
    """Split a size deficit into (leading, trailing) pad amounts; (0,0) if none needed."""
    if deficit <= 0:
        return 0, 0
    if pad_type == 0:        # CENTER_PAD
        p0 = deficit // 2
    elif pad_type == 1:      # RANDOM_PAD
        p0 = random.randint(0, deficit)
    elif pad_type == 2:      # TOPLEFT_PAD
        p0 = 0
    else:
        # Fix: an unknown pad_type previously fell through and crashed with
        # an UnboundLocalError; fail explicitly instead.
        raise ValueError(f"pad_feature: unknown pad_type {pad_type}")
    return p0, deficit - p0


def pad_feature(fea, size, pad_value=0, pad_type=TOPLEFT_PAD, return_pad_value=False):
    '''
    Pad fea ([...,H,W]) up to `size`; dims already >= target are untouched.

    size: (w, h) target spatial size.
    pad_type: 0, center pad
    pad_type: 1, random pad
    pad_type: 2, topleft_pad
    pad_value: constant fill value; an iterable uses its first element.
    return_pad_value: also return (px0, px1, py0, py1).
    '''
    w = fea.shape[-1]
    h = fea.shape[-2]
    # y before x to preserve the original RANDOM_PAD RNG call order.
    py0, py1 = _pad_split(size[1] - h, pad_type)
    px0, px1 = _pad_split(size[0] - w, pad_type)

    if isinstance(pad_value, Iterable):
        pad_value = pad_value[0]
    fea = F.pad(fea, [px0, px1, py0, py1], "constant", pad_value)

    if return_pad_value:
        return fea, px0, px1, py0, py1
    return fea
|
|
385
|
+
|
|
386
|
+
def split_forward_batch32(func):
    """Decorator for a ``(self, data)`` forward: process ``data`` in chunks of
    32 along dim 0 and concatenate the per-chunk results (torch.cat for
    tensors, np.concatenate otherwise). A single chunk is returned as-is."""
    @wraps(func)
    def wrapper(self, data):
        chunk = 32
        outputs = [
            func(self, data[start:start + chunk])
            for start in range(0, data.shape[0], chunk)
        ]
        if len(outputs) == 1:
            return outputs[0]
        if torch.is_tensor(outputs[0]):
            return torch.cat(outputs, dim=0)
        return np.concatenate(outputs, axis=0)
    return wrapper
|
|
403
|
+
|
|
404
|
+
def to(data, device=torch.device("cpu")):
    """Recursively move every tensor inside ``data`` to ``device``.

    Handles tensors, dicts, lists/tuples (type preserved), objects exposing a
    ``.to`` method, and passes through configs, ndarrays, strings and bytes
    unchanged.
    """
    if torch.is_tensor(data):
        return data.to(device)
    elif isinstance(data, (CfgNode, ConfigDict)):
        # Config objects carry no tensors; return unchanged.
        return data
    elif isinstance(data, dict):
        new_data = {}
        for k in list(data.keys()):
            new_data[k] = to(data[k], device)
    elif isinstance(data, (list, tuple)):
        converted = [to(v, device) for v in data]
        new_data = type(data)(converted)
    elif not isinstance(data, Iterable):
        # Scalars and arbitrary objects: use their .to when available
        # (but never call .to on an nn.Module — that would move the model).
        if not isinstance(data, torch.nn.Module) and hasattr(data, "to"):
            data = data.to(device)
        return data
    elif isinstance(data, (np.ndarray, str, bytes)):
        return data
    else:
        # Fix: unsupported iterables previously fell through to an unbound
        # `new_data` and raised NameError; warn and return them unchanged.
        print(f"Unsupport type {type(data)}")
        return data

    return new_data
|
|
429
|
+
|
|
430
|
+
def cpu(data):
    """Recursively move all tensors in ``data`` to the CPU (see ``to``)."""
    return to(data,device=torch.device("cpu"))
|
|
432
|
+
|
|
433
|
+
def cuda(data):
    """Recursively move all tensors in ``data`` to CUDA (see ``to``)."""
    return to(data,device=torch.device("cuda"))
|
|
435
|
+
|
|
436
|
+
def cpu_wraps(func):
    """Decorator: run ``func`` on CPU copies of all arguments, then move the
    result back to CUDA. Useful for ops that only support CPU execution."""
    @wraps(func)
    def wraps_func(*args,**kwargs):
        args = cpu(args)
        kwargs = cpu(kwargs)
        res = func(*args,**kwargs)
        # Result is pushed back to the GPU unconditionally.
        res = cuda(res)
        return res
    return wraps_func
|
|
445
|
+
|
|
446
|
+
def cpu_cpu_wraps(func):
    """Decorator: run ``func`` on CPU copies of all arguments and leave the
    result on the CPU (unlike ``cpu_wraps``, which moves it back to CUDA)."""
    @wraps(func)
    def wraps_func(*args,**kwargs):
        args = cpu(args)
        kwargs = cpu(kwargs)
        res = func(*args,**kwargs)
        return res
    return wraps_func
|
|
454
|
+
|
|
455
|
+
def numpy(data):
    """Recursively convert every tensor inside ``data`` to a numpy array.

    Handles tensors, dicts, lists/tuples (type preserved); scalars and
    ndarrays pass through unchanged.
    """
    if torch.is_tensor(data):
        return data.cpu().numpy()
    elif isinstance(data, dict):
        new_data = {}
        for k in list(data.keys()):
            new_data[k] = numpy(data[k])
    elif isinstance(data, (list, tuple)):
        converted = [numpy(v) for v in data]
        new_data = type(data)(converted)
    elif not isinstance(data, Iterable):
        return data
    elif isinstance(data, np.ndarray):
        return data
    else:
        # Fix: unsupported iterables (e.g. str, set) previously fell through
        # to an unbound `new_data` and raised NameError; warn and return
        # them unchanged.
        print(f"Unsupport type {type(data)}")
        return data

    return new_data
|
|
476
|
+
|
|
477
|
+
def sparse_gather(data, index, return_tensor=True):
    '''
    data: list of tensor (mybe different length)
    index: per-row index into each tensor of `data`
    return_tensor: stack the gathered rows into one tensor when True,
        otherwise return them as a list.
    '''
    picked = [row[index[i]] for i, row in enumerate(data)]
    if return_tensor:
        return torch.stack(picked, dim=0)
    return picked
|
|
488
|
+
|
|
489
|
+
def simple_model_device(model):
    """Device of the model's first parameter (assumes all parameters share
    one device). Raises StopIteration for a parameterless model."""
    first_param = next(model.parameters())
    return first_param.device
|
|
491
|
+
|
|
492
|
+
def resize_mask(mask, size=None, r=None):
    '''
    mask: [N,H,W]
    size: (new_w,new_h); when None it is derived from the scale factor r.
    Returns the nearest-neighbour-resized masks, shape [N,new_h,new_w].
    '''
    if size is None:
        size = (int(mask.shape[2] * r), int(mask.shape[1] * r))
    if mask.numel() == 0:
        # Nothing to resize; keep N and dtype/device via new_zeros.
        return mask.new_zeros([mask.shape[0], size[1], size[0]])

    batched = torch.unsqueeze(mask, dim=0)
    batched = F.interpolate(batched, size=(size[1], size[0]), mode='nearest')
    return torch.squeeze(batched, dim=0)
|
|
506
|
+
|
|
507
|
+
def npresize_mask(mask, size=None, r=None):
    '''
    mask: [N,H,W]
    size: (new_w,new_h); when None it is derived from the scale factor r.
    Returns nearest-neighbour-resized masks as a numpy array [N,new_h,new_w].
    '''
    # Fix: derive size from r up front. Previously only resize_mask handled
    # r, so the N==0 and N==1 fast paths below crashed with size=None.
    if size is None:
        size = (int(mask.shape[2] * r), int(mask.shape[1] * r))
    if mask.shape[0] == 0:
        return np.zeros([0, size[1], size[0]], dtype=mask.dtype)
    if mask.shape[0] == 1:
        # Single mask: cv2.resize is cheaper than the torch round-trip.
        cur_m = cv2.resize(mask[0], dsize=(size[0], size[1]), interpolation=cv2.INTER_NEAREST)
        return np.expand_dims(cur_m, axis=0)
    resized = resize_mask(torch.from_numpy(mask), size, r)
    return resized.numpy()
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def __correct_bboxes(bboxes, h, w):
    """Clip [N,4] (x0,y0,x1,y1) boxes to the image rect [0,w]x[0,h],
    preserving the input dtype."""
    dtype = bboxes.dtype
    clipped = np.clip(bboxes, 0, np.array([[w, h, w, h]]))
    return clipped.astype(dtype)
|
|
527
|
+
|
|
528
|
+
def npresize_mask_in_bboxes(mask,bboxes,size=None,r=None):
    '''
    Resize instance masks by resizing only the crop inside each box, which is
    much cheaper than resizing the full-frame masks.

    mask: [N,H,W]
    bboxes: [N,4](x0,y0,x1,y1)
    size: (new_w,new_h)
    r: unused on the numpy path; forwarded to WPolygonMasks-style objects.

    Returns:
        (res_mask, resized_bboxes): resized [N,new_h,new_w] masks and the
        boxes scaled into the new resolution (int32).
    '''
    # Structured mask types implement this themselves.
    if isinstance(mask,(WPolygonMasks,WBitmapMasks,WMCKeypoints)):
        return mask.resize_mask_in_bboxes(bboxes,size=size,r=r)
    if mask.shape[0]==0:
        return np.zeros([0,size[1],size[0]],dtype=mask.dtype),np.zeros([0,4],dtype=bboxes.dtype)
    x_scale = size[0]/mask.shape[2]
    y_scale = size[1]/mask.shape[1]
    # Clip boxes to the source frame, scale them into the target frame, then
    # clip again to guard rounding overshoot.
    bboxes = __correct_bboxes(bboxes,h=mask.shape[1],w=mask.shape[2])
    resized_bboxes = (bboxes*np.array([[x_scale,y_scale,x_scale,y_scale]])).astype(np.int32)
    resized_bboxes = __correct_bboxes(resized_bboxes,h=size[1],w=size[0])
    bboxes = np.array(bboxes).astype(np.int32)
    res_mask = np.zeros([mask.shape[0],size[1],size[0]],dtype=mask.dtype)
    for i in range(mask.shape[0]):
        dbbox = resized_bboxes[i]
        dsize = (dbbox[2]-dbbox[0],dbbox[3]-dbbox[1])
        # Degenerate target boxes stay all-zero in the output.
        if dsize[0]<=1 or dsize[1]<=1:
            continue
        # Crop the source box, resize the crop, and paste it at the scaled
        # box's top-left corner.
        sub_mask = wmli.crop_img_absolute_xy(mask[i],bboxes[i])
        cur_m = cv2.resize(sub_mask,dsize=dsize,interpolation=cv2.INTER_NEAREST)
        wmli.set_subimg(res_mask[i],cur_m,dbbox[:2])
    return res_mask,resized_bboxes
|
|
554
|
+
|
|
555
|
+
def __time_npresize_mask_in_bboxes(mask,bboxes,size=None,r=None):
    # Debug/benchmark helper: compares the runtime of the resize strategies
    # and returns the in-bboxes result. Not used by production paths.
    t = wmlu.TimeThis()
    b = npresize_mask(mask,size,r)
    t0 = t.time(reset=True)
    a = npresize_mask_in_bboxes(mask,bboxes,size,r)
    t1 = t.time(reset=True)
    # NOTE(review): __npresize_mask is not defined anywhere in this module;
    # calling this helper will raise NameError on the next line.
    c = __npresize_mask(mask,size,r)
    t2 = t.time(reset=True)
    print(f"RM,{t0},{t1},{t2}")
    return a
|
|
565
|
+
|
|
566
|
+
def clone_tensors(x):
    """Clone a tensor, or each tensor of a list/tuple (returned as a list)."""
    if not isinstance(x, (list, tuple)):
        return x.clone()
    return [item.clone() for item in x]
|
|
570
|
+
|
|
571
|
+
def _trunc_normal_(tensor, mean, std, a, b):
    """In-place fill of ``tensor`` from a normal(mean, std) truncated to [a, b]."""
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
              "The distribution of values may be incorrect.")

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor
|
|
603
|
+
|
|
604
|
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
    applied while sampling the normal with mean/std applied, therefore a, b args
    should be adjusted to match the range of mean, std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # Fill happens in-place under no_grad; the tensor itself is returned.
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
def embedding_version2scores(scores, version, exponent=2):
    """Embed an integer version (0..99) into the low decimal digits of scores.

    Scores are truncated to ``exponent`` decimal digits and ``version/100`` is
    stored below that precision; e.g. 0.5 with version=1, exponent=2 becomes
    0.5001.

    Args:
        scores: float tensor of scores in a [0,1]-like range.
        version: integer in [0, 100).
        exponent: number of decimal digits of scores to preserve.

    Raises:
        ValueError: if ``version`` is outside [0, 100). (Was an ``assert``,
            which would be stripped under ``python -O``.)
    """
    if not (version >= 0 and version < 100):
        raise ValueError(f"ERROR: version need in range [0,100)")
    scale = math.pow(10, exponent)
    # Truncate to `exponent` digits, then tuck the version below them.
    scores = (scores * scale).to(torch.int32).to(torch.float32)
    version = version / 100
    scores = (scores + version) / scale
    return scores
|
|
638
|
+
|
|
639
|
+
def embedding_version2coord(coord, version, exponent=0):
    """Embed an integer version (0..99) into the fractional part of coords.

    Coordinates are truncated to ``exponent`` decimal digits and ``version/100``
    is stored below that precision; with the default exponent=0, 5.7 with
    version=3 becomes 5.03.

    Args:
        coord: float tensor of coordinates.
        version: integer in [0, 100).
        exponent: number of decimal digits of coords to preserve.

    Raises:
        ValueError: if ``version`` is outside [0, 100). (Was an ``assert``,
            which would be stripped under ``python -O``.)
    """
    if not (version >= 0 and version < 100):
        raise ValueError(f"ERROR: version need in range [0,100)")
    scale = math.pow(10, exponent)
    # Truncate to `exponent` digits, then tuck the version below them.
    coord = (coord * scale).to(torch.int32).to(torch.float32)
    version = version / 100
    coord = (coord + version) / scale
    return coord
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
def add_version2onnx(onnx_path, save_path, version):
    """Store ``version`` in an ONNX model's metadata under 'model_version'.

    Args:
        onnx_path: path of the model to annotate.
        save_path: destination path; when None the model is overwritten
            in place.
        version: version value; stored stringified.
    """
    model_proto = onnx.load(onnx_path)
    # Fix: onnx.helper.make_string_initializer does not exist (the original
    # call always raised AttributeError). metadata_props entries are plain
    # key/value StringStringEntryProto objects, added via .add() — same
    # pattern as add_metadta2onnx below.
    meta = model_proto.metadata_props.add()
    meta.key = 'model_version'
    meta.value = str(version)
    if save_path is None:
        save_path = onnx_path
    onnx.save(model_proto, save_path)
|
|
664
|
+
|
|
665
|
+
def add_metadta2onnx(model_onnx, metadata):
    """Append each metadata key/value pair to the ONNX model's
    metadata_props (values stringified) and return the same model object."""
    for key, value in metadata.items():
        entry = model_onnx.metadata_props.add()
        entry.key = key
        entry.value = str(value)

    return model_onnx
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
class SafeClass:
    """Inert stand-in substituted for unknown classes during unpickling."""

    def __init__(self, *args, **kwargs):
        """Accept and discard any constructor arguments."""

    def __call__(self, *args, **kwargs):
        """Accept and discard any call arguments; returns None."""
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
class SafeUnpickler(pickle.Unpickler):
    """Unpickler that only resolves classes from an allow-list of modules;
    anything else deserializes to the inert SafeClass."""

    # Exact module names considered safe to resolve.
    _SAFE_MODULES = (
        "torch",
        "collections",
        "collections.abc",
        "builtins",
        "math",
        "numpy",
        # Add other modules considered safe
    )

    def find_class(self, module, name):
        """Resolve ``module.name`` normally when the module is allow-listed;
        otherwise return SafeClass."""
        if module in self._SAFE_MODULES:
            return super().find_class(module, name)
        return SafeClass
|
|
704
|
+
|
|
705
|
+
def safe_load(file, *args, **kwargs):
    """torch.load through a pickle module whose Unpickler is SafeUnpickler,
    so classes outside the allow-list deserialize to SafeClass instead of
    importing arbitrary code. Extra args are forwarded to torch.load."""
    # Build a minimal module object exposing the pickle API torch.load uses.
    pickle_shim = types.ModuleType("safe_pickle")
    pickle_shim.Unpickler = SafeUnpickler
    pickle_shim.load = lambda file_obj: SafeUnpickler(file_obj).load()
    with open(file, "rb") as fh:
        return torch.load(fh, pickle_module=pickle_shim, *args, **kwargs)
|
|
713
|
+
|
|
714
|
+
def load(file, *args, **kwargs):
    """torch.load with a fallback to ``safe_load`` on failure.

    Extra args/kwargs are forwarded to torch.load on the first attempt only;
    the safe fallback is called with the file path alone.
    """
    try:
        return torch.load(file, *args, **kwargs)
    except Exception as e:  # broad on purpose: any load failure triggers the fallback
        # Fix: typo "faild" -> "failed" in the warning message.
        print(f"WARNING: load ckpt {file} failed, info: {e}, try safe load...")
        return safe_load(file)
|