python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
wml/threadtoolkit.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
#coding=utf-8
|
|
2
|
+
from multiprocessing import cpu_count
|
|
3
|
+
from multiprocessing import Process,Queue,Pool
|
|
4
|
+
from functools import partial
|
|
5
|
+
import traceback
|
|
6
|
+
import time
|
|
7
|
+
import os
|
|
8
|
+
import wml.wml_utils as wmlu
|
|
9
|
+
|
|
10
|
+
# Default worker count: leave one core free on machines with >= 4 cores,
# otherwise use every core.
DEFAULT_THREAD_NR=cpu_count() if cpu_count()<4 else cpu_count()-1
|
|
11
|
+
|
|
12
|
+
def fn_wraper(datas, fn, is_factory_fn=False):
    """Apply fn to each (data, index) pair; collect (index, result) tuples.

    Runs inside a worker process of `par_for_each`.

    Args:
        datas: list of (data, original_index) tuples.
        fn: callable applied to each data item, or — when is_factory_fn is
            True — a zero-arg factory that builds the real callable (useful
            for non-picklable per-process state).
        is_factory_fn: see above.

    Returns:
        list of (original_index, result) for every item that succeeded;
        items whose fn call raised are skipped (traceback is printed).
    """
    if is_factory_fn:
        # Build the actual worker once per process.
        fn = fn()
    res_queue = []
    print(f"Process {os.getpid()}: data nr {len(datas)}.")
    for data, i in datas:
        try:
            res = fn(data)
            res_queue.append((i, res))
        except Exception:
            # Catch only real errors: the original bare `except:` also
            # swallowed KeyboardInterrupt/SystemExit, making workers
            # impossible to interrupt.
            traceback.print_exc()
    print(f"Process {os.getpid()} is finished.")
    return res_queue
|
|
25
|
+
|
|
26
|
+
def par_for_each(data, fn, thread_nr=DEFAULT_THREAD_NR, is_factory_fn=False, timeout=None):
    """Apply fn to every item of data in a process pool, preserving input order.

    Args:
        data: sequence of items; each is passed to fn individually.
        fn: callable (or factory when is_factory_fn=True, see fn_wraper).
        thread_nr: max number of worker processes (capped at len(data)).
        is_factory_fn: forwarded to fn_wraper.
        timeout: accepted for interface compatibility but currently unused.

    Returns:
        tuple of results ordered like the input; [] when data is empty or
        every item failed inside the workers.
    """
    if len(data) == 0:
        return []
    thread_nr = min(len(data), thread_nr)
    pool = Pool(thread_nr)
    try:
        # Tag each item with its original index so order can be restored.
        data = list(zip(data, range(len(data))))
        datas = wmlu.list_to_2dlistv2(data, thread_nr)
        raw_res = list(pool.map(partial(fn_wraper, fn=fn, is_factory_fn=is_factory_fn), datas))
    finally:
        # Always release the pool, even if map/chunking raised (the
        # original leaked worker processes on error).
        pool.close()
        pool.join()

    res_data = []
    for res in raw_res:
        res_data.extend(res)
    if not res_data:
        # Every item failed in the workers: zip(*[]) below would raise
        # "not enough values to unpack".
        return []
    res_data = sorted(res_data, key=lambda x: x[0])
    _, res_data = zip(*res_data)
    return res_data
|
|
43
|
+
|
|
44
|
+
def par_for_each_no_return(data, fn, thread_nr=DEFAULT_THREAD_NR):
    """Run fn over chunks of data in a process pool, discarding results.

    NOTE: unlike par_for_each, fn receives a whole chunk (a list of items),
    not a single item.

    Args:
        data: sequence of items to be chunked across workers.
        fn: callable invoked once per chunk.
        thread_nr: max number of worker processes (capped at len(data)).
    """
    if len(data) == 0:
        # Pool(0) raises ValueError; there is nothing to do anyway.
        return
    thread_nr = min(len(data), thread_nr)
    pool = Pool(thread_nr)
    try:
        datas = wmlu.list_to_2dlistv2(data, thread_nr)
        pool.map(fn, datas)
    finally:
        # Release workers even if chunking or map raised.
        pool.close()
        pool.join()
|
wml/walgorithm.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
#coding=utf-8
|
|
2
|
+
from multiprocessing import Pool
|
|
3
|
+
import numpy as np
|
|
4
|
+
import math
|
|
5
|
+
import cv2
|
|
6
|
+
|
|
7
|
+
def _edit_distance(v0,v1):
|
|
8
|
+
if v0 == v1:
|
|
9
|
+
return 0
|
|
10
|
+
if (len(v0)==0) or (len(v1)==0):
|
|
11
|
+
return max(len(v0),len(v1))
|
|
12
|
+
c0 = _edit_distance(v0[:-1],v1)+1
|
|
13
|
+
c1 = _edit_distance(v0,v1[:-1])+1
|
|
14
|
+
cr = 0
|
|
15
|
+
if v0[-1] != v1[-1]:
|
|
16
|
+
cr = 1
|
|
17
|
+
c2 = _edit_distance(v0[:-1],v1[:-1])+cr
|
|
18
|
+
return min(min(c0,c1),c2)
|
|
19
|
+
|
|
20
|
+
def mt_edit_distance(v0, v1, pool):
    """Levenshtein distance, evaluating the three branch subproblems in a pool.

    Args:
        v0, v1: sequences to compare.
        pool: multiprocessing.Pool. Previously this parameter was accepted
            but never used; the three `edit_distance` subcalls now run as
            concurrent pool tasks.

    Returns:
        int edit distance between v0 and v1.
    """
    if v0 == v1:
        return 0
    if len(v0) == 0 or len(v1) == 0:
        return max(len(v0), len(v1))
    # Dispatch the three classic branches to worker processes; edit_distance
    # is module-level, hence picklable.
    delete_job = pool.apply_async(edit_distance, (v0[:-1], v1))
    insert_job = pool.apply_async(edit_distance, (v0, v1[:-1]))
    replace_job = pool.apply_async(edit_distance, (v0[:-1], v1[:-1]))
    cr = 0 if v0[-1] == v1[-1] else 1
    c0 = delete_job.get() + 1
    c1 = insert_job.get() + 1
    c2 = replace_job.get() + cr
    return min(c0, c1, c2)
|
|
32
|
+
|
|
33
|
+
def edit_distance(sm, sn):
    """Levenshtein distance between sequences sm and sn via dynamic programming."""
    rows, cols = len(sm) + 1, len(sn) + 1

    dp = np.ndarray(shape=[rows, cols], dtype=np.int32)

    # Base cases: distance from each prefix to the empty sequence.
    dp[0][0] = 0
    for r in range(1, rows):
        dp[r][0] = dp[r - 1][0] + 1
    for c in range(1, cols):
        dp[0][c] = dp[0][c - 1] + 1

    # dp[r][c] = distance between sm[:r] and sn[:c].
    for r in range(1, rows):
        for c in range(1, cols):
            cost = 0 if sm[r - 1] == sn[c - 1] else 1
            dp[r][c] = min(dp[r - 1][c] + 1,
                           dp[r][c - 1] + 1,
                           dp[r - 1][c - 1] + cost)

    return dp[rows - 1][cols - 1]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def pearsonr(x, y):
    """Pearson correlation coefficient of x and y, in [-1, 1].

    A small epsilon (1e-8) in the denominator avoids division by zero for
    constant inputs.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Deviations from the mean.
    dev_x = x - np.mean(x)
    dev_y = y - np.mean(y)

    covariance = np.sum(dev_x * dev_y)
    var_x = np.sum(dev_x ** 2)
    var_y = np.sum(dev_y ** 2)
    denom = np.sqrt(var_x * var_y) + 1e-8

    return covariance / denom
|
|
75
|
+
|
|
76
|
+
def points_to_polygon(points):
    '''
    Order scattered 2D points into a polygon outline by sorting them by
    angle around an anchor point (angular sort as in a Graham scan, but
    without the convex-hull pruning step).

    Args:
        points: [N,2],(x,y)

    Returns:
        idxs,[N],sorted points[N,2]
    '''
    points = np.array(points)
    base_point = 0
    if points.shape[0]<=3:
        # 3 or fewer points form a (possibly degenerate) polygon in any order.
        return list(range(points.shape[0])),points
    # Pick the anchor: lowest y, ties broken by lowest x.
    for i in range(points.shape[0]):
        if points[i,1]<points[base_point,1]:
            base_point = i
        elif points[i, 1] == points[base_point, 1] and points[i,0]<points[base_point,0]:
            base_point = i

    angles = np.zeros([points.shape[0]],dtype=np.float32)

    # Angle of each point as seen from the anchor, folded into [0, pi).
    for i in range(points.shape[0]):
        y = points[i,1]-points[base_point,1]
        x = points[i,0]-points[base_point,0]
        angles[i] = math.atan2(y,x)
        if angles[i]<0:
            angles[i] += math.pi
    # Slightly negative angle forces the anchor itself to sort first.
    angles[base_point] = -1e-8
    idxs = np.argsort(angles)
    return idxs,points[idxs]
|
|
107
|
+
|
|
108
|
+
def left_shift_array(array, size=1):
    '''
    Rotate a 1-D array to the left by `size` positions.

    Args:
        array: [N]
        size: 1->N-1
    example:
        array = [1,2,3,4], size=1  ->  [2,3,4,1]
    Returns:
        [N] (np.ndarray)
    '''
    # The first `size` elements wrap around to the back.
    tail, wrapped = array[size:], array[:size]
    return np.concatenate([tail, wrapped], axis=0)
|
|
125
|
+
|
|
126
|
+
def right_shift_array(array, size=1):
    '''
    Rotate a 1-D array to the right by `size` positions.

    Args:
        array: [N]
        size: 1->N-1
    example:
        array = [1,2,3,4], size=1  ->  [4,1,2,3]
    Returns:
        [N] (np.ndarray)
    '''
    # The last `size` elements wrap around to the front.
    wrapped, head = array[-size:], array[:-size]
    return np.concatenate([wrapped, head], axis=0)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def sign_point_line(point, line):
    '''
    Tell which side of a directed line a point lies on.

    Args:
        point: [2] x,y
        line: np.array([2,2]) [(x0,y0),(x1,y1)]

    Returns:
        True or False — True when the 2D cross product of the line
        direction with the point offset is negative.
    '''
    line = np.array(line)
    direction = line[1] - line[0]
    to_point = point - line[0]
    cross = direction[0] * to_point[1] - direction[1] * to_point[0]
    return cross < 0
|
|
160
|
+
|
|
161
|
+
def in_range(v, *kargs):
    """Return True when v lies in the closed interval [min_v, max_v].

    Accepts either in_range(v, (min_v, max_v)) or in_range(v, min_v, max_v);
    any other argument count raises RuntimeError.
    """
    if len(kargs) == 1:
        min_v, max_v = kargs[0][0], kargs[0][1]
    elif len(kargs) == 2:
        min_v, max_v = kargs
    else:
        raise RuntimeError(f"in_range: ERROR args {kargs}")

    return min_v <= v <= max_v
|
|
172
|
+
|
|
173
|
+
def points_on_circle(center=None, r=None, points_nr=100):
    '''
    Discretize a circle into points_nr evenly spaced points.

    Starts from the unit circle at the origin; scaled by r and/or shifted
    by center when given. NOTE: returns a plain list of [x, y] pairs when
    both r and center are None, otherwise an np.ndarray [points_nr, 2].
    '''
    points = []
    for idx in range(points_nr):
        theta = math.pi * 2 * idx / points_nr
        points.append([math.cos(theta), math.sin(theta)])
    if r is not None:
        points = np.array(points) * r
    if center is not None:
        # Broadcasting add shifts every point by the center.
        points = points + np.reshape(np.array(center), [1, 2])

    return points
|
|
190
|
+
|
|
191
|
+
def getRotationMatrix2D(center, angle, scale,out_offset=None):
    """Build a 2x3 affine rotation matrix, extending cv2.getRotationMatrix2D.

    Args:
        center: (x, y) rotation center in the input image.
        angle: rotation angle (cv2 convention, degrees).
        scale: isotropic scale factor.
        out_offset: optional (x, y) translation applied in the output space;
            when None this is exactly cv2.getRotationMatrix2D.

    Returns:
        np.ndarray [2,3] affine transform.
    """
    if out_offset is None:
        '''
        cv2 first translates by -center, then scales and rotates, then
        translates back by +center:
        M(center)*M(rotate)*M(scale)*M(-center)*X
        '''
        return cv2.getRotationMatrix2D(center=center,angle=angle,scale=scale)
    # Compose: translate(-center) -> rotate/scale about the origin -> translate(out_offset).
    offset_in = np.array([[1,0,-center[0]],[0,1,-center[1]]],dtype=np.float32)
    rotate_m = cv2.getRotationMatrix2D(center=[0,0],angle=angle,scale=scale)
    offset_out = np.array([[1,0,out_offset[0]],[0,1,out_offset[1]]],dtype=np.float32)
    # Promote each 2x3 affine to 3x3 homogeneous form so they can be chained.
    line3 = np.array([[0,0,1]],dtype=np.float32)
    offset_in = np.concatenate([offset_in,line3],axis=0)
    rotate_m = np.concatenate([rotate_m,line3],axis=0)
    offset_out = np.concatenate([offset_out,line3],axis=0)
    m = np.dot(rotate_m,offset_in)
    m = np.dot(offset_out,m)
    # Drop the homogeneous row: return the standard 2x3 affine.
    return m[:2]
|
|
208
|
+
|
|
209
|
+
def lower_bound(datas, target):
    """
    For an ascending array, return the index of the first element that is
    >= target, or -1 when every element is smaller than target.
    """
    # Fast paths: target beyond either end of the array.
    if datas[-1] < target:
        return -1
    if datas[0] >= target:
        return 0
    lo, hi = 0, len(datas) - 1  # search the closed interval [lo, hi]
    while lo <= hi:
        mid = (lo + hi) // 2
        if datas[mid] < target:
            lo = mid + 1  # answer is in [mid + 1, hi]
        else:
            hi = mid - 1  # answer is in [lo, mid - 1] or mid itself
    return lo
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def remove_non_ascii(s):
    """Return s with every non-ASCII character removed."""
    return ''.join(ch for ch in s if ch.isascii())
|
wml/wcollections.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import random
|
|
3
|
+
import copy
|
|
4
|
+
from collections import abc
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ExperienceBuffer():
    """Fixed-capacity FIFO buffer of experiences (RL replay-buffer style)."""

    def __init__(self, buffer_size=100000):
        # Oldest entries are evicted from the front once capacity is reached.
        self.buffer = []
        self.buffer_size = buffer_size

    def add(self, experience):
        """Append a batch of experiences, evicting the oldest on overflow."""
        overflow = len(self.buffer) + len(experience) - self.buffer_size
        if overflow >= 0:
            del self.buffer[0:overflow]
        self.buffer.extend(experience)

    def sample(self, size):
        """Draw `size` random experiences.

        ndarray entries are stacked into one array [size, *entry_shape];
        tuple-like entries are transposed into one array per field.
        """
        picked = random.sample(self.buffer, size)
        if isinstance(self.buffer[0], np.ndarray):
            return np.reshape(np.array(picked), [size] + list(self.buffer[0].shape))
        columns = zip(*picked)
        return [np.array(list(col)) for col in columns]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class CycleBuffer:
    """Sliding window that keeps only the most recent `cap` appended values."""

    def __init__(self, cap=5):
        self.cap = cap
        self.buffer = []

    def append(self, v):
        """Add v, dropping the oldest entries if the window overflows."""
        self.buffer.append(v)
        excess = len(self.buffer) - self.cap
        if excess > 0:
            self.buffer = self.buffer[excess:]

    def __getitem__(self, idx):
        return self.buffer[idx]

    def __len__(self):
        return len(self.buffer)
|
|
41
|
+
|
|
42
|
+
class AlwaysNullObj(object):
    """Black-hole stand-in object.

    Every attribute read and every call returns this same instance;
    attribute writes and deletes are silently ignored. Useful as a no-op
    replacement for an optional collaborator.
    """

    def __init__(self, *args, **kwargs):
        print(f"Construct a always null object")

    def __getattr__(self, item):
        # Any attribute resolves back to the object itself.
        return self

    def __setattr__(self, key, value):
        # Ignore all attribute assignment.
        pass

    def __delattr__(self, item):
        # Ignore all attribute deletion.
        pass

    def __call__(self, *args, **kwargs):
        # Calling the object is also a no-op returning itself.
        return self
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class MDict(dict):
    def __init__(self, *args, **kwargs):
        '''
        Dict that materializes a default entry on first access of a missing
        key (like collections.defaultdict), also readable via attribute
        access and call syntax.

        Args:
            *args: forwarded to dict.
            **kwargs: forwarded to dict after extracting:
                dtype: zero-arg factory called for each missing key.
                dvalue: template value deep-copied for each missing key.

        example 1:
            x = MDict(dtype=list)
            x[1].append('a')
            x[2].append('b')
            x[1].append('c')
            print(x)
        output:
            {1: ['a', 'c'], 2: ['b']}

        example 2:
            x = MDict(dvalue=[])
            x[1].append('a')
            x[2].append('b')
            x[1].append('c')
            print(x)
        output:
            {1: ['a', 'c'], 2: ['b']}
        '''
        self.default_type = None
        self.default_value = None
        if "dtype" in kwargs:
            self.default_type = kwargs.pop("dtype")
        elif "dvalue" in kwargs:
            self.default_value = kwargs.pop("dvalue")
        super().__init__(*args, **kwargs)

    @classmethod
    def from_dict(cls, data: dict, auto_dtype=True):
        """Build an MDict from a plain dict, inferring dtype from the type
        of the first value when auto_dtype is True."""
        x = data.values()
        assert len(x) > 0, "error dict data"
        if auto_dtype:
            dtype = type(next(iter(x)))
            ret = cls(dtype=dtype)
        else:
            ret = cls(dvalue=None)
        for k, v in data.items():
            ret[k] = v
        return ret

    def __getattr__(self, key):
        # Instance attributes win; everything else falls back to item access.
        if key in self.__dict__:
            return self.__dict__[key]
        return self.__getitem__(key)

    def __call__(self, key):
        return self.__getitem__(key)

    def __getitem__(self, key):
        if key in self.__dict__:
            return self.__dict__[key]
        if key in self:
            return super().__getitem__(key)
        elif self.default_type is not None:
            super().__setitem__(key, self.default_type())
            return super().__getitem__(key)
        elif self.default_value is not None:
            # BUGFIX: deep-copy the template so different keys don't share
            # one mutable object, and return the freshly stored value
            # (this branch previously returned None, breaking the dvalue
            # example in the class docstring).
            super().__setitem__(key, copy.deepcopy(self.default_value))
            return super().__getitem__(key)
        # No default configured: missing key yields None (kept for
        # backward compatibility with existing callers).

    def __delattr__(self, key):
        del self[key]
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class Counter(dict):
    '''
    Minimal occurrence counter backed by a plain dict.

    Example:
        counter = Counter()
        counter.add("a")
        counter.add("a")
        counter.add("b")
        print(counter):
        {
            a: 2
            b: 1
        }
    '''
    def add(self, key, nr=1):
        """Increase key's count by nr (creating it if absent); return the new count."""
        self[key] = self.get(key, 0) + nr
        return self[key]

    def total_size(self):
        """Sum of all counts across keys."""
        return np.sum(list(self.values()))
|
|
153
|
+
|
|
154
|
+
class EDict(dict):
    '''
    Append-only dict: a key may be inserted once and never reassigned.
    '''
    def __setitem__(self, item, value):
        # Writing to an existing key is an error, not an update.
        if item in self:
            raise RuntimeError(f"ERROR: key {item} already exists.")
        super().__setitem__(item, value)
|
|
162
|
+
|
|
163
|
+
def safe_update_dict(target_dict, src_dict, do_raise=True):
    """Merge src_dict into target_dict, flagging conflicting keys.

    A key conflicts when it exists in both dicts with different values.
    On conflict: raise RuntimeError when do_raise is True, otherwise print
    an error; in the non-raising case src_dict's values still win.
    """
    duplicate_keys = [k for k, v in src_dict.items()
                      if k in target_dict and target_dict[k] != v]

    if len(duplicate_keys) > 0:
        if do_raise:
            raise RuntimeError(f"key {duplicate_keys} already in target dict, target dict keys {list(target_dict.keys())}")
        else:
            print(f"ERROR: key {duplicate_keys} already in target dict, target dict keys {list(target_dict.keys())}")

    target_dict.update(src_dict)
|
|
177
|
+
|
|
178
|
+
def trans_dict_key2lower(data):
    """Return a mapping of the same type as data with every key lower-cased."""
    out = type(data)()
    for key, value in data.items():
        out[key.lower()] = value
    return out
|
|
183
|
+
|
|
184
|
+
def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type; defaults to
            any collections.abc.Sequence.

    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is None:
        container_type = abc.Sequence
    else:
        assert isinstance(seq_type, type)
        container_type = seq_type
    if not isinstance(seq, container_type):
        return False
    return all(isinstance(item, expected_type) for item in seq)
|
|
206
|
+
|
|
207
|
+
def is_list_of(seq, expected_type):
    """Check whether it is a list of some type.

    A partial method of :func:`is_seq_of` with the container type pinned
    to ``list``.
    """
    return is_seq_of(seq, expected_type, seq_type=list)
|