python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
wml/threadtoolkit.py ADDED
@@ -0,0 +1,50 @@
1
+ #coding=utf-8
2
+ from multiprocessing import cpu_count
3
+ from multiprocessing import Process,Queue,Pool
4
+ from functools import partial
5
+ import traceback
6
+ import time
7
+ import os
8
+ import wml.wml_utils as wmlu
9
+
10
# Default worker count: every core on small machines (fewer than four cores),
# otherwise all cores but one.
DEFAULT_THREAD_NR = cpu_count() - 1 if cpu_count() >= 4 else cpu_count()
11
+
12
def fn_wraper(datas, fn, is_factory_fn=False):
    '''
    Apply ``fn`` to every item of one chunk and collect the indexed results.

    Args:
        datas: iterable of ``(item, index)`` pairs.
        fn: callable applied to each item; when ``is_factory_fn`` is true it is
            instead called once with no arguments to build the real callable.
        is_factory_fn: treat ``fn`` as a zero-argument factory.

    Returns:
        list of ``(index, result)`` pairs for the items that succeeded. Items
        whose call raised are reported with a traceback and left out.
    '''
    if is_factory_fn:
        fn = fn()
    res_queue = []
    print(f"Process {os.getpid()}: data nr {len(datas)}.")
    for data, i in datas:
        try:
            res = fn(data)
        except Exception:
            # Best effort per item: report and continue. Only ``Exception`` is
            # caught so KeyboardInterrupt/SystemExit can still stop the worker.
            traceback.print_exc()
        else:
            res_queue.append((i, res))
    print(f"Process {os.getpid()} is finished.")
    return res_queue
25
+
26
def par_for_each(data, fn, thread_nr=DEFAULT_THREAD_NR, is_factory_fn=False, timeout=None):
    '''
    Run ``fn`` over ``data`` in a process pool and return the results in input order.

    Args:
        data: sized, iterable collection of items.
        fn: callable applied to each item (or a factory, see ``is_factory_fn``).
        thread_nr: upper bound on worker processes; capped at ``len(data)``.
        is_factory_fn: forwarded to ``fn_wraper``.
        timeout: accepted for interface compatibility; not used by this function.

    Returns:
        ``[]`` for empty input, otherwise a tuple of results ordered by the
        original item index. Items whose call failed inside ``fn_wraper`` are
        missing from the tuple, so it can be shorter than ``data`` (and empty
        when every item failed).
    '''
    if len(data) == 0:
        return []
    thread_nr = min(len(data), thread_nr)
    indexed = list(zip(data, range(len(data))))
    # presumably splits the list into ``thread_nr`` sub-lists; verify against wml_utils
    datas = wmlu.list_to_2dlistv2(indexed, thread_nr)
    worker = partial(fn_wraper, fn=fn, is_factory_fn=is_factory_fn)

    # The pool is created only once the work is prepared, and is always shut
    # down, so no worker processes are left behind when something raises.
    pool = Pool(thread_nr)
    try:
        raw_res = list(pool.map(worker, datas))
        pool.close()
    except BaseException:
        pool.terminate()
        raise
    finally:
        pool.join()

    res_data = []
    for res in raw_res:
        res_data.extend(res)
    res_data = sorted(res_data, key=lambda x: x[0])
    # Built directly rather than with ``zip(*res_data)``, which cannot be
    # unpacked into two names when every item failed and the list is empty.
    return tuple(res for _, res in res_data)
43
+
44
def par_for_each_no_return(data, fn, thread_nr=DEFAULT_THREAD_NR):
    '''
    Run ``fn`` over chunks of ``data`` in a process pool, discarding the results.

    Note that ``fn`` is called with each chunk produced by
    ``wmlu.list_to_2dlistv2``, not with individual items.

    Args:
        data: sized collection of items.
        fn: callable invoked once per chunk.
        thread_nr: upper bound on worker processes; capped at ``len(data)``.
    '''
    if len(data) == 0:
        # Nothing to do; also avoids Pool(0), which raises ValueError.
        return
    thread_nr = min(len(data), thread_nr)
    # presumably splits the list into ``thread_nr`` sub-lists; verify against wml_utils
    datas = wmlu.list_to_2dlistv2(data, thread_nr)
    pool = Pool(thread_nr)
    try:
        pool.map(fn, datas)
        pool.close()
    except BaseException:
        # Do not leave worker processes behind when the map fails.
        pool.terminate()
        raise
    finally:
        pool.join()
wml/walgorithm.py ADDED
@@ -0,0 +1,228 @@
1
+ #coding=utf-8
2
+ from multiprocessing import Pool
3
+ import numpy as np
4
+ import math
5
+ import cv2
6
+
7
+ def _edit_distance(v0,v1):
8
+ if v0 == v1:
9
+ return 0
10
+ if (len(v0)==0) or (len(v1)==0):
11
+ return max(len(v0),len(v1))
12
+ c0 = _edit_distance(v0[:-1],v1)+1
13
+ c1 = _edit_distance(v0,v1[:-1])+1
14
+ cr = 0
15
+ if v0[-1] != v1[-1]:
16
+ cr = 1
17
+ c2 = _edit_distance(v0[:-1],v1[:-1])+cr
18
+ return min(min(c0,c1),c2)
19
+
20
def mt_edit_distance(v0, v1, pool):
    '''
    Edit distance of two sequences, combining three sub-problems that are each
    solved by ``edit_distance``.

    Args:
        v0: first sequence.
        v1: second sequence.
        pool: accepted but not used by this function.

    Returns:
        the edit distance between ``v0`` and ``v1``.
    '''
    if v0 == v1:
        return 0
    len0, len1 = len(v0), len(v1)
    if len0 == 0 or len1 == 0:
        return max(len0, len1)

    # Cost of aligning the two last elements with each other.
    last_cost = 1 if v0[-1] != v1[-1] else 0
    drop_from_v0 = edit_distance(v0[:-1], v1) + 1
    drop_from_v1 = edit_distance(v0, v1[:-1]) + 1
    align_last = edit_distance(v0[:-1], v1[:-1]) + last_cost
    return min(min(drop_from_v0, drop_from_v1), align_last)
32
+
33
def edit_distance(sm, sn):
    '''
    Levenshtein distance between two sequences via a full DP table.

    Args:
        sm: first sequence.
        sn: second sequence.

    Returns:
        the bottom-right cell of the int32 table, i.e. the edit distance.
    '''
    rows = len(sm) + 1
    cols = len(sn) + 1
    table = np.empty((rows, cols), dtype=np.int32)

    # Turning a prefix into the empty sequence costs its length.
    table[:, 0] = np.arange(rows)
    table[0, :] = np.arange(cols)

    for r in range(1, rows):
        for c in range(1, cols):
            cost = 0 if sm[r - 1] == sn[c - 1] else 1
            table[r, c] = min(table[r - 1, c] + 1,
                              table[r, c - 1] + 1,
                              table[r - 1, c - 1] + cost)

    return table[rows - 1, cols - 1]
55
+
56
+
57
def pearsonr(x, y):
    '''
    Pearson correlation coefficient of two samples, in [-1, 1].

    Args:
        x: array-like of values.
        y: array-like of values, same shape as ``x``.

    Returns:
        covariance term divided by the product of the deviations; a 1e-8
        offset in the denominator avoids division by zero.
    '''
    x = np.asanyarray(x)
    y = np.asanyarray(y)

    x_centered = x - np.mean(x)
    y_centered = y - np.mean(y)

    numerator = np.sum(x_centered * y_centered)
    spread_x = np.sum(x_centered ** 2)
    spread_y = np.sum(y_centered ** 2)
    denominator = np.sqrt(spread_x * spread_y) + 1e-8

    return numerator / denominator
75
+
76
def points_to_polygon(points):
    '''
    Order points by their angle around the lowest (then left-most) point.

    Args:
        points: [N,2],(x,y)

    Returns:
        idxs,[N],sorted points[N,2]. With three points or fewer the input
        order is kept and idxs is a plain list.
    '''
    pts = np.array(points)
    count = pts.shape[0]
    if count <= 3:
        return list(range(count)), pts

    # Anchor: smallest y, ties broken by smallest x (first such point wins).
    anchor = 0
    for i in range(count):
        lower = pts[i, 1] < pts[anchor, 1]
        same_row_and_left = pts[i, 1] == pts[anchor, 1] and pts[i, 0] < pts[anchor, 0]
        if lower or same_row_and_left:
            anchor = i

    angles = np.zeros([count], dtype=np.float32)
    for i in range(count):
        dy = pts[i, 1] - pts[anchor, 1]
        dx = pts[i, 0] - pts[anchor, 0]
        angles[i] = math.atan2(dy, dx)
        if angles[i] < 0:
            angles[i] += math.pi
    # Force the anchor itself to sort first.
    angles[anchor] = -1e-8

    order = np.argsort(angles)
    return order, pts[order]
107
+
108
def left_shift_array(array, size=1):
    '''
    Rotate an array to the left along its first axis.

    Args:
        array: [N]
        size: 1->N-1, number of positions to rotate by
    example:
        array = [1,2,3,4]
        size=1
        return:
        [2,3,4,1]
    Returns:
        [N]
    '''
    head, tail = array[:size], array[size:]
    return np.concatenate([tail, head], axis=0)
125
+
126
def right_shift_array(array, size=1):
    '''
    Rotate an array to the right along its first axis.

    Args:
        array: [N]
        size: 1->N-1, number of positions to rotate by
    example:
        array = [1,2,3,4]
        size=1
        return:
        [4,1,2,3]
    Returns:
        [N]
    '''
    body, wrapped = array[:-size], array[-size:]
    return np.concatenate([wrapped, body], axis=0)
143
+
144
+
145
def sign_point_line(point, line):
    '''
    Tell on which side of a directed line a point lies.

    Args:
        point: [2] x,y
        line: np.array([2,2]) [(x0,y0),(x1,y1)]

    Returns:
        True when the 2D cross product of (line end - line start) and
        (point - line start) is negative, otherwise False.
    '''
    segment = np.array(line)
    start = segment[0]
    direction = segment[1] - start
    to_point = point - start
    cross = direction[0] * to_point[1] - direction[1] * to_point[0]
    return cross < 0
160
+
161
def in_range(v, *kargs):
    '''
    Check that ``v`` lies inside a closed interval.

    Args:
        v: value to test.
        *kargs: either one ``(min, max)`` pair or two separate bounds.

    Returns:
        True when ``min <= v <= max``.

    Raises:
        RuntimeError: when neither one nor two bound arguments are given.
    '''
    arg_count = len(kargs)
    if arg_count == 1:
        bounds = kargs[0]
        min_v, max_v = bounds[0], bounds[1]
    elif arg_count == 2:
        min_v, max_v = kargs[0], kargs[1]
    else:
        raise RuntimeError(f"in_range: ERROR args {kargs}")

    return v >= min_v and v <= max_v
172
+
173
def points_on_circle(center=None, r=None, points_nr=100):
    '''
    Discretise a circle into evenly spaced points.

    Args:
        center: optional (x,y) added to every point.
        r: optional radius; the unit circle is used when omitted.
        points_nr: number of points to generate.

    Returns:
        list of [x,y] pairs when neither ``r`` nor ``center`` is given,
        otherwise an array of shape [points_nr,2].
    '''
    points = [
        [math.cos(angle), math.sin(angle)]
        for angle in (math.pi * 2 * i / points_nr for i in range(points_nr))
    ]
    if r is not None:
        points = np.array(points) * r
    if center is not None:
        center = np.reshape(np.array(center), [1, 2])
        points = points + center

    return points
190
+
191
def getRotationMatrix2D(center, angle, scale, out_offset=None):
    '''
    Build a 2x3 affine rotation/scale matrix.

    Args:
        center: (x,y) that is moved to the origin before rotating.
        angle: rotation angle, forwarded to cv2.getRotationMatrix2D.
        scale: scale factor, forwarded to cv2.getRotationMatrix2D.
        out_offset: optional (x,y) translation applied after the rotation.
            When omitted the call is delegated to cv2 unchanged.

    Returns:
        2x3 matrix.
    '''
    if out_offset is None:
        # cv2 translates by -center, scales, rotates, then translates back by center:
        # M(center)*M(rotate)*M(scale)*M(-center)*X
        return cv2.getRotationMatrix2D(center=center, angle=angle, scale=scale)

    bottom_row = np.array([[0, 0, 1]], dtype=np.float32)

    # Homogeneous 3x3 matrices: move center to the origin, rotate/scale about
    # the origin, then translate by out_offset.
    move_in = np.array([[1, 0, -center[0]],
                        [0, 1, -center[1]],
                        [0, 0, 1]], dtype=np.float32)
    rotation = np.concatenate(
        [cv2.getRotationMatrix2D(center=[0, 0], angle=angle, scale=scale), bottom_row],
        axis=0)
    move_out = np.array([[1, 0, out_offset[0]],
                         [0, 1, out_offset[1]],
                         [0, 0, 1]], dtype=np.float32)

    combined = np.dot(move_out, np.dot(rotation, move_in))
    return combined[:2]
208
+
209
def lower_bound(datas, target):
    """
    For an ascending sequence, find the position of the first element that is
    greater than or equal to ``target``.

    Args:
        datas: sequence sorted in ascending order.
        target: value to search for.

    Returns:
        index of the first element ``>= target``, or -1 when there is no such
        element (including when ``datas`` is empty).
    """
    from bisect import bisect_left

    # An empty sequence has no candidate; previously this raised IndexError.
    if len(datas) == 0 or datas[-1] < target:
        return -1
    return bisect_left(datas, target)
225
+
226
+
227
def remove_non_ascii(s):
    '''Return ``s`` with every non-ASCII character dropped.'''
    kept = [ch for ch in s if str.isascii(ch)]
    return ''.join(kept)
wml/wcollections.py ADDED
@@ -0,0 +1,212 @@
1
+ import numpy as np
2
+ import random
3
+ import copy
4
+ from collections import abc
5
+
6
+
7
class ExperienceBuffer():
    '''
    Bounded FIFO store of experiences with random sampling.
    '''
    def __init__(self, buffer_size = 100000):
        self.buffer = []
        self.buffer_size = buffer_size

    def add(self, experience):
        '''
        Append a batch of experiences, dropping the oldest entries first when
        the combined length reaches ``buffer_size``.
        '''
        combined = len(self.buffer) + len(experience)
        if combined >= self.buffer_size:
            del self.buffer[0:combined - self.buffer_size]
        self.buffer.extend(experience)

    def sample(self, size):
        '''
        Draw ``size`` distinct entries at random.

        Returns:
            for ndarray entries, one array of shape [size]+entry shape;
            otherwise a list with one array per field of the entries.
        '''
        holds_arrays = isinstance(self.buffer[0], np.ndarray)
        picked = random.sample(self.buffer, size)
        if holds_arrays:
            target_shape = [size] + list(self.buffer[0].shape)
            return np.reshape(np.array(picked), target_shape)
        # Transpose the picked records into per-field columns.
        return [np.array(list(column)) for column in zip(*picked)]
24
+
25
+
26
class CycleBuffer:
    '''
    List-backed buffer that keeps only the most recent ``cap`` values.
    '''
    def __init__(self,cap=5):
        self.cap = cap
        self.buffer = []

    def append(self,v):
        '''Add ``v`` and discard the oldest values beyond the capacity.'''
        self.buffer.append(v)
        overflow = len(self.buffer) - self.cap
        if overflow > 0:
            self.buffer = self.buffer[overflow:]

    def __getitem__(self, slice):
        return self.buffer[slice]

    def __len__(self):
        return len(self.buffer)
41
+
42
class AlwaysNullObj(object):
    '''
    Null object: any attribute access or call returns the object itself,
    and attribute assignment/deletion is ignored.
    '''
    def __init__(self,*args,**kwargs):
        print("Construct a always null object")

    def __getattr__(self, item):
        # Only reached for attributes that do not exist.
        return self

    def __setattr__(self, key, value):
        return None

    def __delattr__(self, item):
        return None

    def __call__(self, *args, **kwargs):
        return self
58
+
59
+
60
class MDict(dict):
    '''
    dict with attribute-style access and optional defaults for missing keys.

    Missing keys are created on first access from either a factory (``dtype``)
    or a prototype value (``dvalue``). Without either, a missing key yields
    ``None`` and nothing is stored.

    example 1:
    x = MDict(dtype=list)
    x[1].append('a')
    x[2].append('b')
    x[1].append('c')
    print(x)
    output:
    {1: ['a', 'c'], 2: ['b']}

    example 2:
    x = MDict(dvalue=[])
    x[1].append('a')
    x[2].append('b')
    x[1].append('c')
    print(x)
    output:
    {1: ['a', 'c'], 2: ['b']}
    '''
    def __init__(self, *args, **kwargs):
        '''
        Args:
            *args: forwarded to ``dict``.
            **kwargs: ``dtype`` (zero-argument factory) or ``dvalue``
                (prototype value) configure the default; everything else is
                forwarded to ``dict``.
        '''
        self.default_type = None
        self.default_value = None
        if "dtype" in kwargs:
            self.default_type = kwargs.pop("dtype")
        elif "dvalue" in kwargs:
            self.default_value = kwargs.pop("dvalue")
        super().__init__(*args,**kwargs)

    @classmethod
    def from_dict(cls,data:dict,auto_dtype=True):
        '''
        Build an MDict from a non-empty dict.

        Args:
            data: source mapping.
            auto_dtype: when true, the type of the first value becomes the
                default factory; otherwise no default is configured.
        '''
        x = data.values()
        assert len(x)>0, "error dict data"
        if auto_dtype:
            ret = cls(dtype=type(next(iter(x))))
        else:
            ret = cls(dvalue=None)
        for k,v in data.items():
            ret[k] = v
        return ret

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails. Special-method probes
        # (copy, pickle, hasattr checks) must see AttributeError rather than a
        # freshly created default entry.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        return self.__getitem__(key)

    def __call__(self,key):
        return self.__getitem__(key)

    def __getitem__(self, key):
        # Instance attributes (default_type/default_value) take precedence.
        attrs = self.__dict__
        if key in attrs:
            return attrs[key]
        if key in self:
            return super().__getitem__(key)

        # Read the configuration through __dict__ so a partially constructed
        # instance (e.g. during unpickling) cannot recurse via __getattr__.
        default_type = attrs.get("default_type")
        if default_type is not None:
            value = default_type()
        else:
            default_value = attrs.get("default_value")
            if default_value is None:
                return None
            # Each key gets its own copy so mutable defaults are not shared.
            value = copy.deepcopy(default_value)
        super().__setitem__(key,value)
        return value

    def __delattr__(self, key):
        del self[key]
129
+
130
+
131
class Counter(dict):
    '''
    dict that accumulates a count per key.

    Example:
    counter = Counter()
    counter.add("a")
    counter.add("a")
    counter.add("b")
    print(counter):
    {
    a: 2
    b: 1
    }
    '''
    def add(self,key,nr=1):
        '''Increase the count of ``key`` by ``nr`` and return the new count.'''
        if key not in self:
            self[key] = nr
        else:
            self[key] += nr
        return self[key]

    def total_size(self):
        '''Sum of all counts.'''
        counts = list(self.values())
        return np.sum(counts)
153
+
154
class EDict(dict):
    '''
    dict whose keys can be added through item assignment but never overwritten.
    '''
    def __setitem__(self,item,value):
        if item not in self:
            super().__setitem__(item,value)
            return
        raise RuntimeError(f"ERROR: key {item} already exists.")
162
+
163
def safe_update_dict(target_dict,src_dict,do_raise=True):
    '''
    Update ``target_dict`` with ``src_dict`` while guarding against conflicts.

    Args:
        target_dict: dict updated in place.
        src_dict: dict whose entries are merged in.
        do_raise: raise on conflicting keys when true; otherwise print an
            error and overwrite.

    Raises:
        RuntimeError: when a key exists in both dicts with different values
            and ``do_raise`` is true (nothing is updated in that case).
    '''
    conflicting = [
        key for key in src_dict.keys()
        if key in target_dict and target_dict[key] != src_dict[key]
    ]

    if conflicting:
        if do_raise:
            raise RuntimeError(f"key {conflicting} already in target dict, target dict keys {list(target_dict.keys())}")
        print(f"ERROR: key {conflicting} already in target dict, target dict keys {list(target_dict.keys())}")

    target_dict.update(src_dict)
177
+
178
def trans_dict_key2lower(data):
    '''
    Return a new mapping of the same type as ``data`` with lower-cased keys.
    '''
    lowered = type(data)()
    for key, value in data.items():
        lowered[key.lower()] = value
    return lowered
183
+
184
def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type. Any
            ``collections.abc.Sequence`` is accepted when omitted.

    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
        required_container = seq_type
    else:
        required_container = abc.Sequence

    if not isinstance(seq, required_container):
        return False
    return all(isinstance(item, expected_type) for item in seq)
206
+
207
def is_list_of(seq, expected_type):
    """Check whether it is a list of some type.

    Shortcut for :func:`is_seq_of` with the sequence type fixed to ``list``.
    """
    required_container = list
    return is_seq_of(seq, expected_type, seq_type=required_container)