kevin-toolbox-dev 1.2.7__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kevin_toolbox/__init__.py +2 -2
- kevin_toolbox/computer_science/algorithm/registration/registry.py +35 -11
- kevin_toolbox/computer_science/algorithm/statistician/__init__.py +1 -0
- kevin_toolbox/computer_science/algorithm/statistician/accumulator_base.py +146 -0
- kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +9 -36
- kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +10 -18
- kevin_toolbox/computer_science/algorithm/statistician/init_var/__init__.py +2 -0
- kevin_toolbox/computer_science/algorithm/statistician/init_var/init_by_data_format.py +24 -0
- kevin_toolbox/computer_science/algorithm/statistician/init_var/init_by_like.py +20 -0
- kevin_toolbox/data_flow/file/json_/read_json.py +1 -1
- kevin_toolbox/data_flow/file/markdown/generate_table.py +117 -13
- kevin_toolbox/nested_dict_list/copy_.py +7 -7
- kevin_toolbox/nested_dict_list/serializer/__init__.py +1 -0
- kevin_toolbox/nested_dict_list/serializer/enum_variable.py +12 -0
- kevin_toolbox/nested_dict_list/serializer/read.py +2 -1
- kevin_toolbox/nested_dict_list/serializer/variable.py +0 -12
- kevin_toolbox/nested_dict_list/serializer/write.py +2 -1
- kevin_toolbox/nested_dict_list/set_value.py +8 -2
- kevin_toolbox/patches/for_numpy/__init__.py +1 -0
- kevin_toolbox/patches/for_numpy/linalg/softmax.py +40 -4
- kevin_toolbox/patches/for_optuna/sample_from_feasible_domain.py +39 -13
- kevin_toolbox/patches/for_os/__init__.py +1 -0
- kevin_toolbox/patches/for_os/remove.py +4 -2
- kevin_toolbox/patches/for_os/walk.py +167 -0
- kevin_toolbox_dev-1.3.0.dist-info/METADATA +73 -0
- {kevin_toolbox_dev-1.2.7.dist-info → kevin_toolbox_dev-1.3.0.dist-info}/RECORD +28 -23
- kevin_toolbox/computer_science/algorithm/statistician/_init_var.py +0 -27
- kevin_toolbox_dev-1.2.7.dist-info/METADATA +0 -69
- {kevin_toolbox_dev-1.2.7.dist-info → kevin_toolbox_dev-1.3.0.dist-info}/WHEEL +0 -0
- {kevin_toolbox_dev-1.2.7.dist-info → kevin_toolbox_dev-1.3.0.dist-info}/top_level.txt +0 -0
kevin_toolbox/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
__version__ = "1.
|
1
|
+
__version__ = "1.3.0"
|
2
2
|
|
3
3
|
|
4
4
|
import os
|
@@ -12,5 +12,5 @@ os.system(
|
|
12
12
|
os.system(
|
13
13
|
f'python {os.path.split(__file__)[0]}/env_info/check_validity_and_uninstall.py '
|
14
14
|
f'--package_name kevin-toolbox-dev '
|
15
|
-
f'--expiration_timestamp
|
15
|
+
f'--expiration_timestamp 1718018144 --verbose 0'
|
16
16
|
)
|
@@ -4,6 +4,8 @@ import inspect
|
|
4
4
|
import pkgutil
|
5
5
|
import weakref
|
6
6
|
import kevin_toolbox.nested_dict_list as ndl
|
7
|
+
from kevin_toolbox.patches import for_os
|
8
|
+
from kevin_toolbox.patches.for_os import Path_Ignorer, Ignore_Scope
|
7
9
|
|
8
10
|
|
9
11
|
class Registry:
|
@@ -218,14 +220,22 @@ class Registry:
|
|
218
220
|
|
219
221
|
# -------------------- 通过路径添加 --------------------- #
|
220
222
|
|
221
|
-
def collect_from_paths(self, path_ls=None,
|
223
|
+
def collect_from_paths(self, path_ls=None, ignore_s=None, b_execute_now=False):
|
222
224
|
"""
|
223
225
|
遍历 path_ls 下的所有模块,并自动导入其中主要被注册的部分
|
224
226
|
比如被 register() 装饰器包裹或者通过 add() 添加的部分
|
225
227
|
|
226
228
|
参数:
|
227
229
|
path_ls: <list of paths> 需要搜索的目录
|
228
|
-
|
230
|
+
ignore_s: <list/tuple of dict> 在搜索遍历目录时候要执行的排除规则
|
231
|
+
具体设置方式参考 patches.for_os.walk() 中的 ignore_s 参数。
|
232
|
+
比如使用下面的规则就可以排除待搜索目录下的所有 temp/ 和 test/ 文件夹:
|
233
|
+
[
|
234
|
+
{
|
235
|
+
"func": lambda _, __, path: os.path.basename(path) in ["temp", "test"],
|
236
|
+
"scope": ["root", "dirs"]
|
237
|
+
},
|
238
|
+
]
|
229
239
|
b_execute_now: <boolean> 现在就执行导入
|
230
240
|
默认为 False,将等到第一次执行 get() 函数时才会真正尝试导入
|
231
241
|
|
@@ -243,10 +253,13 @@ class Registry:
|
|
243
253
|
f'calling Registry.collect_from_paths() in __init__.py is forbidden, file: {caller_frame.filename}.\n' \
|
244
254
|
f'you can call it in other files, and then import the result of the call in __init__.py'
|
245
255
|
|
256
|
+
# 根据 ignore_s 构建 Path_Ignorer
|
257
|
+
path_ignorer = ignore_s if isinstance(ignore_s, (Path_Ignorer,)) else Path_Ignorer(ignore_s=ignore_s)
|
258
|
+
|
246
259
|
#
|
247
260
|
if not b_execute_now:
|
248
261
|
self._path_to_collect.append(
|
249
|
-
dict(path_ls=path_ls,
|
262
|
+
dict(path_ls=path_ls, ignore_s=path_ignorer, b_execute_now=True))
|
250
263
|
return
|
251
264
|
|
252
265
|
#
|
@@ -254,27 +267,38 @@ class Registry:
|
|
254
267
|
temp = []
|
255
268
|
for path in filter(lambda x: os.path.isdir(x), path_ls):
|
256
269
|
temp.append(path)
|
257
|
-
for root, dirs, _ in
|
270
|
+
for root, dirs, _ in for_os.walk(path, topdown=False, ignore_s=path_ignorer, followlinks=True):
|
258
271
|
temp.extend([os.path.join(root, i) for i in dirs])
|
259
|
-
if path_ls_to_exclude is not None:
|
260
|
-
for path_ex in path_ls_to_exclude:
|
261
|
-
if not os.path.exists(path_ex):
|
262
|
-
continue
|
263
|
-
for i in reversed(range(len(temp))):
|
264
|
-
if os.path.samefile(os.path.commonpath([path_ex, temp[i]]), path_ex):
|
265
|
-
temp.pop(i)
|
266
272
|
# 从深到浅导入,可以避免继承引起的 TypeError: super(type, obj) 类型错误
|
267
273
|
path_ls = list(set(temp))
|
268
274
|
path_ls.sort(reverse=True)
|
275
|
+
path_set = set(path_ls)
|
269
276
|
|
270
277
|
temp = None
|
271
278
|
for loader, module_name, is_pkg in pkgutil.walk_packages(path_ls):
|
279
|
+
# 判断该模块是否需要导入
|
280
|
+
# (快速判断)判断该模块所在目录是否在 path_set 中
|
281
|
+
if loader.path not in path_set:
|
282
|
+
continue
|
283
|
+
if is_pkg:
|
284
|
+
# 若该模块是 package,判断其所在目录是否满足 Path_Ignorer 中的 dirs 对应的规则
|
285
|
+
path = os.path.dirname(loader.find_module(module_name).path)
|
286
|
+
if path_ignorer(Ignore_Scope.DIRS, True, os.path.islink(path), path):
|
287
|
+
continue
|
288
|
+
else:
|
289
|
+
# 若不是 package,判断该模块的文件路径是否满足 Path_Ignorer 中的 files 对应的规则
|
290
|
+
path = loader.find_module(module_name).path
|
291
|
+
if path_ignorer(Ignore_Scope.FILES, False, os.path.islink(path), path):
|
292
|
+
continue
|
293
|
+
# 加载模块
|
272
294
|
module = loader.find_module(module_name).load_module(module_name)
|
295
|
+
# 选择遍历过程中第一次找到的 Registry 实例
|
273
296
|
if temp is None:
|
274
297
|
for name, obj in inspect.getmembers(module):
|
275
298
|
if getattr(obj, "name", None) == Registry.name and getattr(obj, "uid", None) == self.uid:
|
276
299
|
temp = obj
|
277
300
|
break
|
301
|
+
|
278
302
|
if temp is not None:
|
279
303
|
self.database = temp.database
|
280
304
|
|
@@ -0,0 +1,146 @@
|
|
1
|
+
from kevin_toolbox.computer_science.algorithm.statistician import init_var
|
2
|
+
import kevin_toolbox.nested_dict_list as ndl
|
3
|
+
|
4
|
+
|
5
|
+
class Accumulator_Base(object):
    """
        Abstract base class for accumulating statisticians.

        Attributes:
            self.var            <torch.tensor / np.ndarray / int / float> the accumulated value.
                                    It is built by self._init_var().
                                    There are three ways to initialise it (i.e. to describe the format of the incoming data):
                                        1. Describe shape, device, etc. explicitly when constructing the instance.
                                            data_format:        <dict of paras>
                                                    expected keys:
                                                        type_:              <str>
                                                                                "numpy":  np.ndarray
                                                                                "torch":  torch.tensor
                                                        shape:              <list of integers>
                                                        device:             <torch.device>
                                                        dtype:              <torch.dtype>
                                        2. Infer shape, device, etc. from an example value when constructing the instance.
                                            like:               <torch.tensor / np.ndarray / int / float>
                                        3. Give neither data_format nor like; initialisation is then deferred until
                                            the first call of add()/add_sequence(), which infers the format from its input.
                                            To support this, add() has to contain:
                                                if self.var is None:
                                                    self.var = self._init_var(like=var)
                                    The third way is the default.
                                    When both `like` and `data_format` are given, `like` takes precedence
                                    (see _init_var()).
            self.state          <dict> state of the accumulator
                                    fields:
                                        "total_nums":   <int> counter of how many times add() has accumulated a value
                                                            (this base class only initialises it to 0; implementations
                                                            of add() are expected to increment it)
            self.paras          <dict> the parameters the instance was built with

        Interface:
            add()               (*) add a single value
            add_sequence()      (*) add a series of values
            get()               (*) get the accumulated result
            clear()             clear the accumulated data (self.var) and the state (self.state)
            state_dict()        return the current status (a dict holding copies of self.var and self.state)
            load_state_dict()   restore the status from a state_dict

        Members a subclass may need to implement or override:
            the interfaces marked with (*) above
            _init_state()
            _init_var()
    """

    def __init__(self, **kwargs):
        """
            Accepts at least the following parameters:
                data_format:        describes the data format explicitly (way 1 in the class docstring)
                like:               example value to infer the data format from (way 2 in the class docstring)
        """
        # default parameters
        paras = {
            # shape / device of the incoming data
            "data_format": None,
            "like": None,
        }

        # collect parameters; unknown keys are kept as-is in self.paras
        paras.update(kwargs)

        self.paras = paras
        self.var = self._init_var(like=paras["like"], data_format=paras["data_format"])
        self.state = self._init_state()

    def add_sequence(self, var_ls, **kwargs):
        """
            Add a series of values. Must be implemented by the subclass.

            A typical implementation is:
                for var in var_ls:
                    self.add(var)
        """
        raise NotImplementedError

    def add(self, var, **kwargs):
        """
            Add a single value. Must be implemented by the subclass.

            Parameters:
                var:                the value to accumulate

            A typical implementation is:
                if self.var is None:
                    self.var = self._init_var(like=var)
                self.state["total_nums"] += 1
                # ... update self.var ...
        """
        raise NotImplementedError

    def get(self, **kwargs):
        """
            Return the accumulated result. Must be implemented by the subclass.

            A typical implementation returns None while nothing has been accumulated:
                if len(self) == 0:
                    return None
                # ... derive the result from self.var and return it ...
        """
        raise NotImplementedError

    @staticmethod
    def _init_state():
        """
            Build the initial state dict.
        """
        return dict(
            total_nums=0,
        )

    @staticmethod
    def _init_var(like=None, data_format=None):
        """
            Build the initial accumulated value.

            Parameters:
                like:               example value; when not None it is handed to init_var.by_like()
                data_format:        <dict> only consulted when `like` is None; its items are passed
                                        as keyword arguments to init_var.by_data_format()

            Returns None when neither argument is given (deferred initialisation).
        """
        if like is not None:
            var = init_var.by_like(var=like)
        elif data_format is not None:
            # fix: the original unpacked an undefined name `kwargs` here, which raised
            # NameError as soon as a data_format was supplied
            var = init_var.by_data_format(**data_format)
        else:
            var = None
        return var

    def clear(self):
        """
            Reset self.var to a fresh value built from the current self.var
            (stays None if nothing was initialised yet) and reset self.state.
        """
        self.var = self._init_var(like=self.var)
        self.state = self._init_state()

    def __len__(self):
        return self.state["total_nums"]

    # ---------------------- saving and loading the status ---------------------- #

    def load_state_dict(self, state_dict):
        """
            Restore the status from state_dict.

            The instance is cleared first, then the "state" entry (if any) is merged into self.state
            and the "var" entry (if not None) is written into self.var.
        """
        self.clear()
        self.state.update(state_dict.get("state", dict()))
        if state_dict.get("var", None) is not None:
            if self.var is None:
                self.var = state_dict["var"]
            else:
                # self.var was just rebuilt by clear(); zero it and add the loaded value
                # rather than rebinding, so the existing variable is the one that gets filled
                self.var *= 0
                self.var += state_dict["var"]

    def state_dict(self):
        """
            Return a deep copy of the current status as {"state": ..., "var": ...}.
        """
        return ndl.copy_(var={"state": self.state, "var": self.var}, b_deepcopy=True, b_keep_internal_references=True)
|
@@ -1,10 +1,7 @@
|
|
1
|
-
import
|
2
|
-
import torch
|
3
|
-
import numpy as np
|
4
|
-
from kevin_toolbox.computer_science.algorithm.statistician._init_var import _init_var
|
1
|
+
from kevin_toolbox.computer_science.algorithm.statistician import Accumulator_Base
|
5
2
|
|
6
3
|
|
7
|
-
class Average_Accumulator:
|
4
|
+
class Average_Accumulator(Accumulator_Base):
|
8
5
|
"""
|
9
6
|
用于计算平均值的累积器
|
10
7
|
"""
|
@@ -30,30 +27,13 @@ class Average_Accumulator:
|
|
30
27
|
以上三种方式,默认选用最后一种。
|
31
28
|
如果三种方式同时被指定,则优先级与对应方式在上面的排名相同。
|
32
29
|
"""
|
30
|
+
super(Average_Accumulator, self).__init__(**kwargs)
|
33
31
|
|
34
|
-
|
35
|
-
paras = {
|
36
|
-
# 指定输入数据的形状、设备
|
37
|
-
"data_format": None,
|
38
|
-
"like": None,
|
39
|
-
}
|
40
|
-
|
41
|
-
# 获取参数
|
42
|
-
paras.update(kwargs)
|
43
|
-
|
44
|
-
# 校验参数
|
45
|
-
#
|
46
|
-
self.paras = paras
|
47
|
-
self.var = _init_var(like=paras["like"], data_format=paras["data_format"])
|
48
|
-
self.state = dict(
|
49
|
-
total_nums=0,
|
50
|
-
)
|
51
|
-
|
52
|
-
def add_sequence(self, var_ls):
|
32
|
+
def add_sequence(self, var_ls, **kwargs):
|
53
33
|
for var in var_ls:
|
54
34
|
self.add(var)
|
55
35
|
|
56
|
-
def add(self, var):
|
36
|
+
def add(self, var, **kwargs):
|
57
37
|
"""
|
58
38
|
添加单个数据
|
59
39
|
|
@@ -61,12 +41,12 @@ class Average_Accumulator:
|
|
61
41
|
var: 数据
|
62
42
|
"""
|
63
43
|
if self.var is None:
|
64
|
-
self.var = _init_var(like=var)
|
44
|
+
self.var = self._init_var(like=var)
|
65
45
|
# 累积
|
66
46
|
self.var += var
|
67
47
|
self.state["total_nums"] += 1
|
68
48
|
|
69
|
-
def get(self):
|
49
|
+
def get(self, **kwargs):
|
70
50
|
"""
|
71
51
|
获取当前累加的平均值
|
72
52
|
当未有累积时,返回 None
|
@@ -75,17 +55,10 @@ class Average_Accumulator:
|
|
75
55
|
return None
|
76
56
|
return self.var / len(self)
|
77
57
|
|
78
|
-
def clear(self):
|
79
|
-
self.var = _init_var(like=self.var)
|
80
|
-
self.state = dict(
|
81
|
-
total_nums=0,
|
82
|
-
)
|
83
|
-
|
84
|
-
def __len__(self):
|
85
|
-
return self.state["total_nums"]
|
86
|
-
|
87
58
|
|
88
59
|
if __name__ == '__main__':
|
60
|
+
import torch
|
61
|
+
|
89
62
|
seq = list(torch.tensor(range(1, 10)))
|
90
63
|
avg = Average_Accumulator()
|
91
64
|
for i, v in enumerate(seq):
|
@@ -1,10 +1,7 @@
|
|
1
|
-
import
|
2
|
-
import torch
|
3
|
-
import numpy as np
|
4
|
-
from kevin_toolbox.computer_science.algorithm.statistician._init_var import _init_var
|
1
|
+
from kevin_toolbox.computer_science.algorithm.statistician import Accumulator_Base
|
5
2
|
|
6
3
|
|
7
|
-
class Exponential_Moving_Average:
|
4
|
+
class Exponential_Moving_Average(Accumulator_Base):
|
8
5
|
"""
|
9
6
|
滑动平均器
|
10
7
|
支持为每个输入数据配置不同的权重
|
@@ -59,12 +56,7 @@ class Exponential_Moving_Average:
|
|
59
56
|
# 校验参数
|
60
57
|
assert isinstance(paras["keep_ratio"], (int, float,)) and 0 <= paras["keep_ratio"] <= 1
|
61
58
|
#
|
62
|
-
self.paras
|
63
|
-
self.var = _init_var(like=paras["like"], data_format=paras["data_format"])
|
64
|
-
self.state = dict(
|
65
|
-
total_nums=0,
|
66
|
-
bias_fix=1,
|
67
|
-
)
|
59
|
+
super(Exponential_Moving_Average, self).__init__(**paras)
|
68
60
|
|
69
61
|
def add_sequence(self, var_ls, weight_ls=None):
|
70
62
|
if weight_ls is not None:
|
@@ -87,7 +79,7 @@ class Exponential_Moving_Average:
|
|
87
79
|
默认为 1
|
88
80
|
"""
|
89
81
|
if self.var is None:
|
90
|
-
self.var = _init_var(like=var)
|
82
|
+
self.var = self._init_var(like=var)
|
91
83
|
new_ratio = (1 - self.paras["keep_ratio"]) * weight
|
92
84
|
keep_ratio = (1 - new_ratio)
|
93
85
|
# 累积
|
@@ -113,18 +105,18 @@ class Exponential_Moving_Average:
|
|
113
105
|
else:
|
114
106
|
return self.var
|
115
107
|
|
116
|
-
|
117
|
-
|
118
|
-
|
108
|
+
@staticmethod
|
109
|
+
def _init_state():
|
110
|
+
return dict(
|
119
111
|
total_nums=0,
|
120
112
|
bias_fix=1,
|
121
113
|
)
|
122
114
|
|
123
|
-
def __len__(self):
|
124
|
-
return self.state["total_nums"]
|
125
|
-
|
126
115
|
|
127
116
|
if __name__ == '__main__':
|
117
|
+
import torch
|
118
|
+
import numpy as np
|
119
|
+
|
128
120
|
seq = list(torch.tensor(range(1, 10)))
|
129
121
|
wls = np.asarray([0.1] * 5 + [0.9] + [0.1] * 4) * 0.1
|
130
122
|
ema = Exponential_Moving_Average(keep_ratio=0.9, bias_correction=True)
|
@@ -0,0 +1,24 @@
|
|
1
|
+
import torch
|
2
|
+
import numpy as np
|
3
|
+
|
4
|
+
|
5
|
+
def init_by_data_format(type_, shape, **kwargs):
    """
        Build a zero-valued variable of the requested type and shape.

        Parameters:
            type_:              <str>
                                    "numpy":  np.ndarray
                                    "torch":  torch.tensor
                                    "number": float
                                any value other than "torch" / "numpy" yields the float 0.0
            shape:              <list of integers> ignored when a plain float is returned
            device:             <torch.device>
            dtype:              <torch.dtype / np.dtype>
                                device / dtype (and any other keyword argument) are forwarded
                                unchanged to torch.zeros() or np.zeros()
    """
    if type_ == "torch":
        return torch.zeros(size=shape, **kwargs)
    if type_ == "numpy":
        return np.zeros(shape=shape, **kwargs)
    # plain python number: shape and kwargs are not used
    return 0.0
|
@@ -0,0 +1,20 @@
|
|
1
|
+
import torch
|
2
|
+
import numpy as np
|
3
|
+
|
4
|
+
|
5
|
+
def init_by_like(var):
    """
        Build a zero-valued counterpart of var.

        Parameters:
            var:                <torch.tensor / np.ndarray / int / float>
                                    a tensor gives torch.zeros_like(var),
                                    an ndarray gives np.zeros_like(var),
                                    an int / float / numpy scalar gives the float 0.0

        Raises ValueError for any other input.
    """
    if torch.is_tensor(var):
        return torch.zeros_like(var)
    if isinstance(var, np.ndarray):
        return np.zeros_like(var)
    if isinstance(var, (int, float, np.number)):
        return 0.0
    raise ValueError("paras 'like' should be np.ndarray, torch.tensor or int/float")
|
@@ -19,7 +19,7 @@ def read_json(file_path, converters=None, b_use_suggested_converter=False):
|
|
19
19
|
默认为 False。
|
20
20
|
注意:当 converters 非 None,此参数失效,以 converters 中的具体设置为准
|
21
21
|
"""
|
22
|
-
assert os.path.isfile(file_path)
|
22
|
+
assert os.path.isfile(file_path), f'file {file_path} not found'
|
23
23
|
if converters is None and b_use_suggested_converter:
|
24
24
|
converters = [unescape_tuple, unescape_non_str_dict_key]
|
25
25
|
|
@@ -1,21 +1,106 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
from kevin_toolbox.math.utils import spilt_integer_most_evenly
|
2
|
+
|
3
|
+
|
4
|
+
def generate_table(content_s, orientation="vertical", chunk_nums=None, chunk_size=None, b_allow_misaligned_values=False,
                   f_gen_order_of_values=None):
    """
        Generate a markdown table.

        Parameters:
            content_s:              <dict> the content. The caller's dict and its value lists are not modified.
                                        Two input modes are supported:
                                            1. simple mode:
                                                content_s = {<title>: <list of value>, ...}
                                                Keys are the titles, values are the entries under each title.
                                                Titles appear in the iteration order of the dict; use the full mode
                                                below to specify the order explicitly.
                                            2. full mode:
                                                content_s = {<index>: {"title": <title>,"values":<list of value>}, ...}
                                                The entry at <index> provides the <index>-th title and its values.
                                                Some <index> may be omitted; the corresponding row/column is left empty.
                                        The mode is detected from the type of the first value of content_s.
            orientation:            <str> direction of the table
                                        "vertical" / "v":       titles in the first row
                                        "horizontal" / "h":     titles in the first column
            chunk_nums:             <int> split the table evenly into this many parts, shown side by side.
            chunk_size:             <int> split the table into parts of at most this length, shown side by side.
                                        Only one of chunk_nums / chunk_size may be given; giving both fails an assertion.
            b_allow_misaligned_values   <boolean> allow values of different lengths
                                        Default is False: values of unequal length fail an assertion.
                                        When True, shorter values are padded with "".
            f_gen_order_of_values:  <callable> produces the sort key for each line of values
                                        It receives a <dict> like {<title>: <value>, ...} and returns
                                        an int/float/tuple used for sorting.
    """
    # validate parameters
    assert chunk_nums is None or 1 <= chunk_nums
    assert chunk_size is None or 1 <= chunk_size
    assert orientation in ["vertical", "horizontal", "h", "v"]
    # fix: the documented "only one of them" rule was never enforced (chunk_nums silently won)
    assert chunk_nums is None or chunk_size is None, \
        'chunk_nums and chunk_size cannot be specified at the same time'

    # normalise to full mode
    # fix: always work on fresh dicts/lists so that the padding, sorting and filling below
    #   no longer leak into the caller's data
    if len(content_s) > 0 and not isinstance(next(iter(content_s.values())), (dict,)):
        content_s = {i: {"title": k, "values": list(v)} for i, (k, v) in enumerate(content_s.items())}
    else:
        content_s = {k: {"title": v["title"], "values": list(v["values"])} for k, v in content_s.items()}

    # align values
    len_ls = [len(v["values"]) for v in content_s.values()]
    max_length = max(len_ls)
    if min(len_ls) != max_length:
        assert b_allow_misaligned_values, \
            f'The lengths of the values under each title are not aligned. ' \
            f'The maximum length is {max_length}, but each length is {len_ls}'
        for v in content_s.values():
            v["values"].extend([""] * (max_length - len(v["values"])))

    # sort the lines of values
    if callable(f_gen_order_of_values):
        idx_ls = list(range(max_length))
        idx_ls.sort(key=lambda x: f_gen_order_of_values({v["title"]: v["values"][x] for v in content_s.values()}))
        for v in content_s.values():
            v["values"] = [v["values"][i] for i in idx_ls]

    # fill in the omitted indices with empty titles / values
    for i in range(max(content_s.keys()) + 1):
        if i not in content_s:
            content_s[i] = {"title": "", "values": [""] * max_length}

    # split the table according to chunk_nums or chunk_size
    if chunk_nums is not None or chunk_size is not None:
        if chunk_nums is not None:
            split_len_ls = spilt_integer_most_evenly(x=max_length, group_nums=chunk_nums)
        else:
            split_len_ls = [chunk_size] * (max_length // chunk_size)
            if max_length % chunk_size != 0:
                split_len_ls += [max_length % chunk_size]
        max_length = max(split_len_ls)
        col_nums = len(content_s)
        temp = dict()
        beg = 0
        for i, new_length in enumerate(split_len_ls):
            end = beg + new_length
            # the i-th chunk is appended after the previous ones: its entries get indices
            # shifted by i * col_nums, and short chunks are padded to the longest chunk
            temp.update({k + i * col_nums: {"title": v["title"],
                                            "values": v["values"][beg:end] + [""] * (max_length - new_length)} for
                         k, v in content_s.items()})
            beg = end
        content_s = temp

    # build the table
    return _show_table(content_s=content_s, orientation=orientation)
|
80
|
+
|
81
|
+
|
82
|
+
def _show_table(content_s, orientation="vertical"):
|
83
|
+
"""
|
84
|
+
生成表格
|
85
|
+
|
86
|
+
参数:
|
87
|
+
content_s: <dict> 内容
|
88
|
+
content_s = {<index>: {"title": <title>,"values":<list of value>}, ...}
|
89
|
+
此时将取第 <index> 个 "title" 的值来作为第 <index> 个标题的值。values 同理。
|
90
|
+
orientation: <str> 表格的方向
|
91
|
+
支持以下值:
|
92
|
+
"vertical" / "v": 纵向排列,亦即标题在第一行
|
93
|
+
"horizontal" / "h": 横向排列,亦即标题在第一列
|
94
|
+
"""
|
10
95
|
table = ""
|
11
96
|
if orientation in ["vertical", "v"]:
|
12
|
-
table += "| " + " | ".join([f'{i}' for i in
|
13
|
-
table += "| " + " | ".join(["---"] * len(
|
14
|
-
for row in zip(*[content_s[
|
97
|
+
table += "| " + " | ".join([f'{content_s[i]["title"]}' for i in range(len(content_s))]) + " |\n"
|
98
|
+
table += "| " + " | ".join(["---"] * len(content_s)) + " |\n"
|
99
|
+
for row in zip(*[content_s[i]["values"] for i in range(len(content_s))]):
|
15
100
|
table += "| " + " | ".join([f'{i}' for i in row]) + " |\n"
|
16
101
|
else:
|
17
|
-
for i
|
18
|
-
row = [f'{
|
102
|
+
for i in range(len(content_s)):
|
103
|
+
row = [f'{content_s[i]["title"]}'] + [f'{i}' for i in content_s[i]["values"]]
|
19
104
|
table += "| " + " | ".join(row) + " |\n"
|
20
105
|
if i == 0:
|
21
106
|
table += "| " + " | ".join(["---"] * len(row)) + " |\n"
|
@@ -23,4 +108,23 @@ def generate_table(content_s, ordered_keys=None, orientation="vertical"):
|
|
23
108
|
|
24
109
|
|
25
110
|
if __name__ == '__main__':
|
26
|
-
|
111
|
+
content_s = {0: dict(title="a", values=[1, 2, 3]), 2: dict(title="b", values=[4, 5, 6])}
|
112
|
+
doc = generate_table(content_s=content_s, orientation="h")
|
113
|
+
print(doc)
|
114
|
+
|
115
|
+
# from collections import OrderedDict
|
116
|
+
#
|
117
|
+
# content_s = OrderedDict({
|
118
|
+
# "y/n": [True] * 5 + [False] * 5,
|
119
|
+
# "a": list(range(10)),
|
120
|
+
# "b": [chr(i) for i in range(50, 60, 2)]
|
121
|
+
# })
|
122
|
+
# doc = generate_table(content_s=content_s, orientation="v", chunk_size=4, b_allow_misaligned_values=True,
|
123
|
+
# f_gen_order_of_values=lambda x: (-int(x["y/n"] is False), -(x["a"] % 3)))
|
124
|
+
# print(doc)
|
125
|
+
# import os
|
126
|
+
#
|
127
|
+
# with open(os.path.join(
|
128
|
+
# "/home/SENSETIME/xukaiming/Desktop/my_repos/python_projects/kevin_toolbox/kevin_toolbox/data_flow/file/markdown/test/test_data/for_generate_table",
|
129
|
+
# f"data_5.md"), "w") as f:
|
130
|
+
# f.write(doc)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import copy
|
2
|
-
|
2
|
+
from kevin_toolbox.nested_dict_list import traverse
|
3
3
|
import torch
|
4
4
|
|
5
5
|
|
@@ -121,9 +121,9 @@ def _copy_structure(var, b_keep_internal_references):
|
|
121
121
|
else:
|
122
122
|
return value.copy()
|
123
123
|
|
124
|
-
return
|
125
|
-
|
126
|
-
|
124
|
+
return traverse(var=[var], match_cond=lambda _, __, value: isinstance(value, (list, dict,)),
|
125
|
+
action_mode="replace", converter=converter,
|
126
|
+
traversal_mode="dfs_pre_order", b_traverse_matched_element=True)[0]
|
127
127
|
|
128
128
|
|
129
129
|
def _copy_nodes(var, b_keep_internal_references):
|
@@ -147,9 +147,9 @@ def _copy_nodes(var, b_keep_internal_references):
|
|
147
147
|
else:
|
148
148
|
return copy_item(value)
|
149
149
|
|
150
|
-
return
|
151
|
-
|
152
|
-
|
150
|
+
return traverse(var=[var], match_cond=lambda _, __, value: not isinstance(value, (list, dict,)),
|
151
|
+
action_mode="replace", converter=converter,
|
152
|
+
traversal_mode="dfs_pre_order", b_traverse_matched_element=True)[0]
|
153
153
|
|
154
154
|
|
155
155
|
if __name__ == '__main__':
|
@@ -0,0 +1,12 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
|
3
|
+
|
4
|
+
class Strictness_Level(Enum):
    """
        How strictly correctness and completeness are required.
    """
    COMPLETE = "high"  # Every node is matched by one or more backends, and the first matched backend writes it successfully.
    COMPATIBLE = "normal"  # Every node is matched by one or more backends, but the first matched backend fails to write and another backend matched afterwards writes it successfully.
    # (This is mostly attributable to a backend's writable() method failing to reject every bad input, or to the backend itself not working as expected. In general it has little effect on the correctness of what is finally written.)
    # This level is the default level.
    IGNORE_FAILURE = "low"  # The matching is incomplete, or some nodes still cannot be written after every matched backend has been tried.
|