kevin-toolbox-dev 1.3.4__py3-none-any.whl → 1.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kevin_toolbox/__init__.py +2 -2
- kevin_toolbox/computer_science/algorithm/pareto_front/__init__.py +1 -0
- kevin_toolbox/computer_science/algorithm/pareto_front/optimum_picker.py +218 -0
- kevin_toolbox/computer_science/algorithm/statistician/__init__.py +1 -0
- kevin_toolbox/computer_science/algorithm/statistician/accumulator_base.py +2 -2
- kevin_toolbox/computer_science/algorithm/statistician/accumulator_for_ndl.py +69 -0
- kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +16 -4
- kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +5 -12
- kevin_toolbox/data_flow/file/json_/read_json.py +11 -5
- kevin_toolbox/data_flow/file/kevin_notation/kevin_notation_writer.py +4 -1
- kevin_toolbox/data_flow/file/kevin_notation/test/test_kevin_notation_debug.py +27 -0
- kevin_toolbox/data_flow/file/kevin_notation/write.py +3 -0
- kevin_toolbox/data_flow/file/markdown/__init__.py +3 -0
- kevin_toolbox/data_flow/file/markdown/find_tables.py +65 -0
- kevin_toolbox/data_flow/file/markdown/generate_table.py +19 -5
- kevin_toolbox/data_flow/file/markdown/parse_table.py +135 -0
- kevin_toolbox/data_flow/file/markdown/save_images_in_ndl.py +81 -0
- kevin_toolbox/data_flow/file/markdown/variable.py +17 -0
- kevin_toolbox/nested_dict_list/serializer/backends/_ndl.py +4 -1
- kevin_toolbox/nested_dict_list/serializer/read.py +18 -14
- kevin_toolbox/nested_dict_list/serializer/write.py +23 -7
- kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +6 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +54 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +60 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +65 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +61 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +53 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +54 -0
- kevin_toolbox/patches/for_os/__init__.py +3 -0
- kevin_toolbox/patches/for_os/copy.py +33 -0
- kevin_toolbox/patches/for_os/find_files_in_dir.py +30 -0
- kevin_toolbox/patches/for_os/path/__init__.py +2 -0
- kevin_toolbox/patches/for_os/path/find_illegal_chars.py +47 -0
- kevin_toolbox/patches/for_os/path/replace_illegal_chars.py +49 -0
- kevin_toolbox/patches/for_test/check_consistency.py +104 -33
- kevin_toolbox_dev-1.3.6.dist-info/METADATA +95 -0
- {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/RECORD +39 -20
- kevin_toolbox_dev-1.3.4.dist-info/METADATA +0 -67
- {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/WHEEL +0 -0
- {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,53 @@
|
|
1
|
+
import os
|
2
|
+
import matplotlib.pyplot as plt
|
3
|
+
from kevin_toolbox.patches.for_matplotlib import generate_color_list
|
4
|
+
from kevin_toolbox.patches.for_os.path import replace_illegal_chars
|
5
|
+
|
6
|
+
|
7
|
+
def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None, **kwargs):
    """
    Draw a scatter plot of data_s[y_name] against data_s[x_name].

    Parameters:
        data_s:     mapping from a column name to a sequence of values.
        title:      title of the figure. When the figure is saved, the title
                    (passed through replace_illegal_chars()) is also the file name.
        x_name:     key in data_s holding the x coordinates.
        y_name:     key in data_s holding the y coordinates.
        cate_name:  optional key in data_s holding one category label per point.
                    When given, every distinct label gets its own colour from
                    generate_color_list() and a legend is added.
                    When None, all points are drawn in blue.
        output_dir: when None the figure is shown with plt.show();
                    otherwise the directory is created if needed and the figure
                    is saved there as "<title>.png".
        **kwargs:   overrides for the defaults "dpi" (200) and "scatter_size" (5).

    Returns:
        None when output_dir is None, otherwise the path of the saved image.
    """
    paras = {
        "dpi": 200,
        "scatter_size": 5
    }
    paras.update(kwargs)

    plt.clf()
    color_s = None
    if cate_name is not None:
        # dict.fromkeys() de-duplicates while keeping first-appearance order, so the
        # category -> colour assignment and the legend order are the same on every run
        # (iterating a set of strings depends on the per-process hash seed).
        cates = list(dict.fromkeys(data_s[cate_name]))
        color_s = dict(zip(cates, generate_color_list(nums=len(cates))))
        c = [color_s[i] for i in data_s[cate_name]]
    else:
        c = "blue"
    # draw the scatter plot
    plt.scatter(data_s[x_name], data_s[y_name], s=paras["scatter_size"], c=c, alpha=0.8)
    plt.xlabel(f'{x_name}')
    plt.ylabel(f'{y_name}')
    plt.title(f'{title}')
    # add the legend: one proxy marker per category, capped at size 5
    if cate_name is not None:
        plt.legend(handles=[
            plt.Line2D([0], [0], marker='o', color='w', label=i, markerfacecolor=j,
                       markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
        ])

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
|
44
|
+
|
45
|
+
|
46
|
+
if __name__ == '__main__':
    # Manual demo: five points in two categories, shown interactively
    # because no output_dir is passed.
    demo_data = {
        "x": [1, 2, 3, 4, 5],
        "y": [2, 4, 6, 8, 10],
        "categories": ['A', 'B', 'A', 'B', 'A'],
    }

    plot_scatters(data_s=demo_data, title='test', x_name='x', y_name='y', cate_name='categories')
|
@@ -0,0 +1,54 @@
|
|
1
|
+
import os
|
2
|
+
import pandas as pd
|
3
|
+
import seaborn as sns
|
4
|
+
import matplotlib.pyplot as plt
|
5
|
+
from kevin_toolbox.patches.for_matplotlib import generate_color_list
|
6
|
+
from kevin_toolbox.patches.for_os.path import replace_illegal_chars
|
7
|
+
|
8
|
+
|
9
|
+
def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=None, cate_color_s=None, **kwargs):
    """
    Draw a scatter-plot matrix (seaborn pairplot) over the columns listed in x_name_ls.

    Parameters:
        data_s:        mapping from a column name to a sequence of values;
                       it is converted with pd.DataFrame(data_s).
        title:         figure title. When the figure is saved, the title
                       (passed through replace_illegal_chars()) is also the file name.
        x_name_ls:     column names to plot; used for both x_vars and y_vars,
                       so it also fixes the order of the sub-plots.
        cate_name:     optional column holding a category label per row (used as hue).
                       When None, no hue/palette is applied.
        output_dir:    when None the figure is shown with plt.show();
                       otherwise the directory is created if needed and the figure
                       is saved there as "<title>.png".
        cate_color_s:  optional dict {category: colour}. Its keys must be exactly the
                       categories present in data_s[cate_name]. When None, colours are
                       taken from generate_color_list(). Ignored when cate_name is None.
        **kwargs:      overrides for the defaults "dpi" (200) and
                       "diag_kind" ("kde"; the diagonal plots, one of {'hist', 'kde'}).

    Returns:
        None when output_dir is None, otherwise the path of the saved image.

    Raises:
        AssertionError: a name in x_name_ls / cate_name is missing from data_s, or
                        cate_color_s does not cover exactly the categories in the data.
    """
    paras = {
        "dpi": 200,
        "diag_kind": "kde"  # kind of plot on the diagonal {'hist', 'kde'}
    }
    paras.update(kwargs)

    missing = set(x_name_ls).difference(set(data_s.keys()))
    assert len(missing) == 0, f'columns {missing} not found in data_s'
    if cate_name is None:
        # The signature advertises cate_name=None, so honour it: no hue means no palette.
        cate_color_s = None
    else:
        assert cate_name in data_s, f'cate_name {cate_name} not found in data_s'
        if cate_color_s is None:
            # first-appearance order keeps the category -> colour assignment reproducible
            cates = list(dict.fromkeys(data_s[cate_name]))
            cate_color_s = dict(zip(cates, generate_color_list(len(cates))))
        assert set(cate_color_s.keys()) == set(data_s[cate_name])

    plt.clf()
    # draw the scatter-plot matrix with seaborn
    sns.pairplot(
        pd.DataFrame(data_s),
        diag_kind=paras["diag_kind"],  # kind of plot on the diagonal {'hist', 'kde'}
        hue=cate_name,  # rows are grouped / coloured by the values of this column
        palette=cate_color_s, x_vars=x_name_ls, y_vars=x_name_ls,  # x_vars, y_vars fix the sub-plot order
    )
    # leave room at the top of the grid for the overall title
    plt.subplots_adjust(top=0.95)
    plt.suptitle(f'{title}', y=0.98, x=0.47)

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
|
42
|
+
|
43
|
+
|
44
|
+
if __name__ == '__main__':
    # Manual demo: three numeric columns coloured by "categories", shown interactively.
    # NOTE(review): the "title" entry is not one of the plotted columns; it looks like
    # a leftover, but it is kept so the demo behaves exactly as before.
    demo_data = {
        "x": [1, 2, 3, 4, 5],
        "y": [2, 4, 6, 8, 10],
        "z": [2, 4, 6, 8, 10],
        "categories": ['A', 'B', 'A', 'B', 'A'],
        "title": 'test',
    }
    demo_colors = {'A': 'red', 'B': 'blue'}

    plot_scatters_matrix(data_s=demo_data, title='test', x_name_ls=['y', 'x', 'z'], cate_name='categories',
                         cate_color_s=demo_colors)
|
@@ -0,0 +1,33 @@
|
|
1
|
+
import os
|
2
|
+
import shutil
|
3
|
+
from kevin_toolbox.patches.for_os import remove
|
4
|
+
|
5
|
+
|
6
|
+
def copy(src, dst, follow_symlinks=True, remove_dst_if_exists=False):
    """
    Copy a file, a directory or a symbolic link from src to dst.

    Parameters:
        src:                    path of the source.
        dst:                    path of the destination. Missing parent directories are created.
        follow_symlinks:        <boolean> whether to follow symbolic links.
                                    Applies when src itself is a link or when the directory
                                    at src contains links.
                                    True:  the content the links point to is copied.
                                    False: only the links themselves are re-created.
                                    Default is True.
        remove_dst_if_exists:   <boolean> try to remove dst first when it already exists.

    Raises:
        AssertionError: src does not exist, or dst still exists after the optional removal.
    """
    # When only the link itself is copied, a dangling link is still a valid source,
    # so its existence must be tested without resolving it.
    src_exists = os.path.exists(src) if follow_symlinks else os.path.lexists(src)
    assert src_exists, f'failed to copy, src not exists: {src}'
    if remove_dst_if_exists:
        remove(path=dst, ignore_errors=True)
    assert not os.path.exists(dst), f'failed to copy, dst exists: {dst}'

    # dirname() is '' for a bare file name such as "b.txt"; os.makedirs('') would raise.
    parent_dir = os.path.dirname(dst)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    if os.path.isdir(src):
        if not follow_symlinks and os.path.islink(src):
            # src is a link to a directory and links are NOT followed: re-create the link itself
            os.symlink(os.readlink(src), dst)
        else:
            # otherwise copy the directory tree recursively
            shutil.copytree(src, dst, symlinks=not follow_symlinks)
    else:
        # plain file (or, with follow_symlinks=False, a link to a file / dangling link)
        shutil.copy2(src, dst, follow_symlinks=follow_symlinks)
|
@@ -0,0 +1,30 @@
|
|
1
|
+
import os
|
2
|
+
from kevin_toolbox.patches.for_os import walk
|
3
|
+
|
4
|
+
|
5
|
+
def find_files_in_dir(input_dir, suffix_ls, b_relative_path=True, b_ignore_case=True):
    """
    Generator over the files below input_dir whose name ends with one of the given suffixes.

    The filtering is delegated to the ignore rules of for_os.walk: a file is ignored
    when it is a symbolic link or when its path does not end with any suffix.

    Parameters:
        input_dir:          directory to search.
        suffix_ls:          <list/tuple of str> accepted suffixes.
                                A single string is also accepted and treated as one suffix.
        b_relative_path:    <bool> yield paths relative to input_dir instead of
                                paths joined onto the walked root.
        b_ignore_case:      <bool> compare suffixes case-insensitively.

    Yields:
        <str> path of each matching file.
    """
    # A bare string would otherwise be exploded into single characters by set().
    if isinstance(suffix_ls, str):
        suffix_ls = [suffix_ls]
    suffix_ls = tuple(set(suffix_ls))
    if b_ignore_case:
        suffix_ls = tuple(suffix.lower() for suffix in suffix_ls)

    def _should_ignore(_, b_is_symlink, path):
        # str.endswith() accepts the whole tuple of suffixes at once
        candidate = path.lower() if b_ignore_case else path
        return b_is_symlink or not candidate.endswith(suffix_ls)

    for root, dirs, files in walk(top=input_dir, topdown=True,
                                  ignore_s=[{
                                      "func": _should_ignore,
                                      "scope": ["files", ]
                                  }]):
        for file in files:
            file_path = os.path.join(root, file)
            if b_relative_path:
                file_path = os.path.relpath(file_path, start=input_dir)
            yield file_path
|
28
|
+
|
29
|
+
|
30
|
+
|
@@ -0,0 +1,47 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
|
4
|
+
# Characters that are not allowed in file names on some systems, mapped to a
# visually equivalent full-width substitute (ASCII code point + 0xFEE0).
# The substitutes are written as escapes so they survive tools that normalise
# full-width characters back to ASCII.
ILLEGAL_CHARS = {
    '<': '\uff1c',  # FULLWIDTH LESS-THAN SIGN
    '>': '\uff1e',  # FULLWIDTH GREATER-THAN SIGN
    ':': '\uff1a',  # FULLWIDTH COLON
    '"': '\uff02',  # FULLWIDTH QUOTATION MARK
    '/': '\uff0f',  # FULLWIDTH SOLIDUS
    '\\': '\uff3c',  # FULLWIDTH REVERSE SOLIDUS
    '|': '\uff5c',  # FULLWIDTH VERTICAL LINE
    '?': '\uff1f',  # FULLWIDTH QUESTION MARK
    '*': '\uff0a'  # FULLWIDTH ASTERISK
}

# Matches any single illegal character.
ILLEGAL_CHARS_PATTERN = re.compile(r'[<>:"/\\|?*]')
|
17
|
+
|
18
|
+
|
19
|
+
def find_illegal_chars(file_name, b_is_path=False):
    """
    Report which illegal characters occur in a file name / path.

    "Illegal" means characters that some systems (win/mac, ...) do not allow in
    file names. Using them is discouraged on every system, even where they are
    legal, to avoid compatibility problems when moving files across systems.

    Parameters:
        file_name:  <str> the name (or path) to inspect.
        b_is_path:  <bool> treat file_name as a path.
                        Default is False.
                        When True, file_name is first split on os.sep and every
                        non-empty component is inspected separately.

    Returns:
        <list of str> every match of ILLEGAL_CHARS_PATTERN, in order of appearance
        (repeated characters are reported repeatedly).
    """
    if b_is_path:
        segments = [seg for seg in file_name.split(os.sep) if seg]
    else:
        segments = [file_name]

    found = []
    for seg in segments:
        found += ILLEGAL_CHARS_PATTERN.findall(seg)
    return found
|
42
|
+
|
43
|
+
|
44
|
+
if __name__ == '__main__':
    # Manual demo: inspect the same string first as a path, then as a plain file name.
    demo_path = '//data0//b/<?>.md'
    for as_path in (True, False):
        print(find_illegal_chars(file_name=demo_path, b_is_path=as_path))
|
@@ -0,0 +1,49 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
|
4
|
+
# Characters that are not allowed in file names on some systems, mapped to a
# visually equivalent full-width substitute (ASCII code point + 0xFEE0).
# The substitutes are written as escapes so they survive tools that normalise
# full-width characters back to ASCII.
ILLEGAL_CHARS = {
    '<': '\uff1c',  # FULLWIDTH LESS-THAN SIGN
    '>': '\uff1e',  # FULLWIDTH GREATER-THAN SIGN
    ':': '\uff1a',  # FULLWIDTH COLON
    '"': '\uff02',  # FULLWIDTH QUOTATION MARK
    '/': '\uff0f',  # FULLWIDTH SOLIDUS
    '\\': '\uff3c',  # FULLWIDTH REVERSE SOLIDUS
    '|': '\uff5c',  # FULLWIDTH VERTICAL LINE
    '?': '\uff1f',  # FULLWIDTH QUESTION MARK
    '*': '\uff0a'  # FULLWIDTH ASTERISK
}
# Alternation of every (escaped) illegal character, built from the keys above.
ILLEGAL_CHARS_PATTERN = re.compile('|'.join(re.escape(char) for char in ILLEGAL_CHARS.keys()))
|
16
|
+
|
17
|
+
|
18
|
+
def replace_illegal_chars(file_name, b_is_path=False):
    """
    Replace the illegal characters in a file name / path with their legal substitutes.

    "Illegal" means characters that some systems (win/mac, ...) do not allow in
    file names. Using them is discouraged on every system, even where they are
    legal, to avoid compatibility problems when moving files across systems.

    Parameters:
        file_name:  <str> the name (or path) to clean.
        b_is_path:  <bool> treat file_name as a path.
                        Default is False.
                        When True, file_name is split on os.sep, every non-empty
                        component is cleaned separately, and the components are
                        joined back into a path. A leading separator is kept;
                        repeated and trailing separators are dropped.

    Returns:
        <str> the cleaned name / path. With b_is_path=True an empty string yields
        an empty string and a path made only of separators yields os.sep.
    """
    if not b_is_path:
        return _replace_illegal_chars(var=file_name)

    parts = [_replace_illegal_chars(var=i) for i in file_name.split(os.sep) if len(i) > 0]
    # os.path.join() needs at least one argument, so "" and separator-only paths
    # (which leave no components) are handled explicitly instead of raising TypeError.
    res = os.path.join(*parts) if parts else ""
    if file_name.startswith(os.sep):
        res = os.sep + res
    return res
|
39
|
+
|
40
|
+
|
41
|
+
def _replace_illegal_chars(var):
    """
    Substitute every match of ILLEGAL_CHARS_PATTERN in var with the value that
    ILLEGAL_CHARS maps the matched character to, and return the new string.
    """

    def _substitute(match):
        return ILLEGAL_CHARS[match.group(0)]

    return ILLEGAL_CHARS_PATTERN.sub(_substitute, var)
|
44
|
+
|
45
|
+
|
46
|
+
if __name__ == '__main__':
    # Manual demo: clean the same string first as a path, then as a plain file name.
    demo_path = 'data0//b/<?>.md'
    for as_path in (True, False):
        print(replace_illegal_chars(file_name=demo_path, b_is_path=as_path))
|
@@ -7,7 +7,8 @@ import kevin_toolbox.nested_dict_list as ndl
|
|
7
7
|
def check_consistency(*args, tolerance=1e-7, require_same_shape=True):
|
8
8
|
"""
|
9
9
|
检查 args 中多个变量之间是否一致
|
10
|
-
变量支持python的所有内置类型,以及复杂的 nested_dict_list 结构, array 等
|
10
|
+
变量支持 python 的所有内置类型,以及复杂的 nested_dict_list 结构, array 等
|
11
|
+
对于 array,不区分 numpy 的 array,torch 的 tensor,还是 tuple of number,只要其中的值相等,即视为相同。
|
11
12
|
|
12
13
|
参数:
|
13
14
|
tolerance: <float> 判断 <np.number/np.bool_> 之间是否一致时,的容许误差。
|
@@ -28,13 +29,13 @@ def check_consistency(*args, tolerance=1e-7, require_same_shape=True):
|
|
28
29
|
try:
|
29
30
|
_check_item(*names_ls, tolerance=tolerance, require_same_shape=True)
|
30
31
|
except AssertionError as e:
|
31
|
-
|
32
|
+
raise AssertionError(f'inputs <nested_dict_list> has different structure\nthe nodes that differ are:\n{e}')
|
32
33
|
for its in zip(names_ls[0], *values_ls):
|
33
34
|
try:
|
34
35
|
_check_item(*its[1:], tolerance=tolerance, require_same_shape=require_same_shape)
|
35
36
|
except AssertionError as e:
|
36
|
-
|
37
|
-
f'value of nodes {its[0]} in inputs <nested_dict_list> are inconsistent\nthe difference is:\n{e}'
|
37
|
+
raise AssertionError(
|
38
|
+
f'value of nodes {its[0]} in inputs <nested_dict_list> are inconsistent\nthe difference is:\n{e}')
|
38
39
|
# 简单结构
|
39
40
|
else:
|
40
41
|
_check_item(*args, tolerance=tolerance, require_same_shape=require_same_shape)
|
@@ -45,44 +46,111 @@ def _check_item(*args, tolerance, require_same_shape):
|
|
45
46
|
检查 args 中多个 array 之间是否一致
|
46
47
|
|
47
48
|
工作流程:
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
49
|
+
1. 对于 args 都是 tuple/list 且每个 tuple/list 的长度都一致的情况,将会拆分为对应各个元素递归进行比较。
|
50
|
+
2. 将 args 中的 tuple 和 tensor 分别转换为 list 和 numpy。
|
51
|
+
3. 检查 args 中是否有 np.array 或者 tensor,若有则根据 require_same_shape 判断其形状是否一致。
|
52
|
+
4. 先将输入的 args 中的所有变量转换为 np.array;
|
53
|
+
然后使用 issubclass() 判断转换后得到的变量属于以下哪几种基本类型:
|
54
|
+
- 当所有变量都属于 np.number 数值(包含int、float等)或者 np.bool_ 布尔值时,
|
55
|
+
将对变量两两求差,当差值小于给定的容许误差 tolerance 时,视为一致。
|
56
|
+
注意:在比较过程中,若变量中存在 np.nan 值,将会首先比较有 np.nan 值的位置是否相等,然后再比较非 np.nan 值部分。
|
57
|
+
亦即在相同位置上都具有 np.nan 视为相同。
|
58
|
+
比如 [np.nan, np.nan, 3] 和 [np.nan, np.nan, 3] 会视为相等,
|
59
|
+
[np.nan, np.nan] 和 np.nan 在 require_same_shape=False 时会被视为相等。
|
60
|
+
- 当所有变量都属于 np.flexible 可变长度类型(包含string等)或者 np.object 时,
|
61
|
+
将使用==进行比较,当返回值都为 True 时,视为一致。
|
62
|
+
- 当变量的基本类型不一致(比如同时有np.number和np.flexible)时,
|
63
|
+
直接判断为不一致。
|
64
|
+
numpy 中基本类型之间的继承关系参见: https://numpy.org.cn/reference/arrays/scalars.html
|
57
65
|
|
58
66
|
参数:
|
59
67
|
tolerance: <float> 判断 <np.number/np.bool_> 之间是否一致时,的容许误差。
|
60
68
|
require_same_shape: <boolean> 是否强制要求变量的形状一致。
|
61
|
-
|
69
|
+
注意:仅在原始 args 中含有 np.array 或者 tensor 的情况会采取 broadcast,亦即此时该参数才会起效。
|
70
|
+
当设置为 False 时,不同形状的变量可能因为 broadcast 机制而在比较前自动 reshape 为相同维度,进而通过比较。
|
62
71
|
"""
|
63
|
-
|
72
|
+
warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
|
64
73
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
65
74
|
|
66
75
|
assert len(args) >= 2
|
67
|
-
assert isinstance(tolerance, (int, float,))
|
76
|
+
assert isinstance(tolerance, (int, float,)) and tolerance >= 0
|
77
|
+
raw_args = args
|
78
|
+
|
79
|
+
# 当 args 都是 tuple/list 且每个 tuple 的长度都一致时,将会拆分为对应各个元素进行比较
|
80
|
+
if all([isinstance(i, (tuple, list)) for i in args]) and all([len(i) == len(args[0]) for i in args[1:]]):
|
81
|
+
for i, it_ls in enumerate(zip(*args)):
|
82
|
+
try:
|
83
|
+
_check_item(*it_ls, tolerance=tolerance, require_same_shape=require_same_shape)
|
84
|
+
except AssertionError as e:
|
85
|
+
raise AssertionError(f'elements {i} are inconsistent\nthe difference is:\n{e}')
|
86
|
+
return
|
87
|
+
|
88
|
+
#
|
89
|
+
args = ndl.traverse(var=list(args), match_cond=lambda _, __, v: isinstance(v, (tuple,)), action_mode="replace",
|
90
|
+
converter=lambda _, v: list(v), b_traverse_matched_element=True)
|
91
|
+
args = ndl.traverse(var=list(args), match_cond=lambda _, __, v: torch.is_tensor(v), action_mode="replace",
|
92
|
+
converter=lambda _, v: v.detach().cpu().numpy(), b_traverse_matched_element=False)
|
93
|
+
b_has_raw_array = any([isinstance(i, np.ndarray) for i in args])
|
68
94
|
|
69
|
-
|
70
|
-
|
95
|
+
try:
|
96
|
+
args = [np.asarray(v) for v in args] # if b_has_raw_array else [np.array(v, dtype=object) for v in args]
|
97
|
+
except Exception as e:
|
98
|
+
raise RuntimeError(f'{raw_args} cannot be converted to np.array, \n'
|
99
|
+
f'because {e}')
|
71
100
|
|
72
|
-
|
101
|
+
# 比较形状
|
102
|
+
if b_has_raw_array:
|
73
103
|
if require_same_shape:
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
104
|
+
# 要求形状一致
|
105
|
+
for v in args[1:]:
|
106
|
+
assert args[0].shape == v.shape, \
|
107
|
+
f"{args[0]}, {v}, different shape: {args[0].shape}, {v.shape}"
|
108
|
+
else:
|
109
|
+
# 否则要求至少能够进行 broadcast
|
110
|
+
for v in args[1:]:
|
111
|
+
try:
|
112
|
+
np.broadcast_arrays(args[0], v)
|
113
|
+
except:
|
114
|
+
raise AssertionError(f'{args[0]}, {v}, failed to broadcast')
|
115
|
+
# 如果都是空的 array,直接视为相等
|
116
|
+
if all([i.size == 0 for i in args]):
|
117
|
+
return
|
118
|
+
b_allow_broadcast = b_has_raw_array and not require_same_shape
|
119
|
+
|
120
|
+
# 比较值
|
121
|
+
if issubclass(args[0].dtype.type, (np.number, np.bool_,)):
|
122
|
+
# 数字类型
|
123
|
+
for v in args[1:]:
|
78
124
|
assert issubclass(v.dtype.type, (np.number, np.bool_,))
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
#
|
125
|
+
v_0, v_1 = args[0].astype(dtype=float), v.astype(dtype=float)
|
126
|
+
v_0, v_1 = np.broadcast_arrays(v_0, v_1) if b_allow_broadcast else (v_0, v_1)
|
127
|
+
assert v_0.shape == v_1.shape, \
|
128
|
+
f'{v_0}, {v_1}, different shape: {v_0.shape}, {v_1.shape}'
|
129
|
+
#
|
130
|
+
if v_0.size > 0:
|
131
|
+
try:
|
132
|
+
if np.any(np.isnan(v)):
|
133
|
+
assert np.all(np.isnan(v_0) == np.isnan(v_1))
|
134
|
+
v_0 = np.nan_to_num(v_0, nan=1e10)
|
135
|
+
v_1 = np.nan_to_num(v_1, nan=1e10)
|
136
|
+
assert np.max(np.abs(v_0 - v_1)) < tolerance
|
137
|
+
except AssertionError:
|
138
|
+
raise AssertionError(f"{args[0]}, {v}, deviation: {np.max(np.abs(args[0] - v))}")
|
139
|
+
elif issubclass(args[0].dtype.type, (np.flexible, object,)):
|
140
|
+
# 可变长度类型
|
141
|
+
for v in args[1:]:
|
84
142
|
assert issubclass(v.dtype.type, (np.flexible, object,))
|
85
|
-
|
143
|
+
v_0, v_1 = np.broadcast_arrays(args[0], v) if b_allow_broadcast else (args[0], v)
|
144
|
+
assert v_0.shape == v_1.shape, \
|
145
|
+
f'{v_0}, {v_1}, different shape: {v_0.shape}, {v_1.shape}\n' + (
|
146
|
+
'' if require_same_shape else
|
147
|
+
f'\tMore details: \n'
|
148
|
+
f'\t\tAlthough require_same_shape=False has been setted, broadcast failed because the variable at \n'
|
149
|
+
f'\t\tthis position does not contain elements of type np.array and tensor.')
|
150
|
+
#
|
151
|
+
for i, j in zip(v_0.reshape(-1), v_1.reshape(-1)):
|
152
|
+
if i is j:
|
153
|
+
continue
|
86
154
|
temp = i == j
|
87
155
|
if isinstance(temp, (bool,)):
|
88
156
|
assert temp, \
|
@@ -90,14 +158,17 @@ def _check_item(*args, tolerance, require_same_shape):
|
|
90
158
|
else:
|
91
159
|
assert temp.all(), \
|
92
160
|
f"{args[0]}, {v}, diff: {temp}"
|
93
|
-
|
94
|
-
|
161
|
+
else:
|
162
|
+
raise ValueError
|
95
163
|
|
96
164
|
|
97
165
|
if __name__ == '__main__':
|
98
166
|
a = np.array([[1, 2, 3]])
|
99
167
|
b = np.array([[1, 2, 3]])
|
100
168
|
c = {'d': 3, 'c': 4}
|
101
|
-
check_consistency([c, a], [c, b])
|
102
169
|
|
103
|
-
|
170
|
+
# var = ((1, 2), (4))
|
171
|
+
#
|
172
|
+
# var = ndl.traverse(var=[var], match_cond=lambda _, __, v: isinstance(v, (tuple,)), action_mode="replace",
|
173
|
+
# converter=lambda _, v: list(v), b_traverse_matched_element=True)[0]
|
174
|
+
# print(var)
|
@@ -0,0 +1,95 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: kevin-toolbox-dev
|
3
|
+
Version: 1.3.6
|
4
|
+
Summary: 一个常用的工具代码包集合
|
5
|
+
Home-page: https://github.com/cantbeblank96/kevin_toolbox
|
6
|
+
Download-URL: https://github.com/username/your-package/archive/refs/tags/v1.0.0.tar.gz
|
7
|
+
Author: kevin hsu
|
8
|
+
Author-email: xukaiming1996@163.com
|
9
|
+
License: MIT
|
10
|
+
Keywords: mathematics,pytorch,numpy,machine-learning,algorithm
|
11
|
+
Platform: UNKNOWN
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
13
|
+
Classifier: Programming Language :: Python
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
15
|
+
Requires-Python: >=3.6
|
16
|
+
Description-Content-Type: text/markdown
|
17
|
+
Requires-Dist: torch (>=1.2.0)
|
18
|
+
Requires-Dist: numpy (>=1.19.0)
|
19
|
+
Provides-Extra: plot
|
20
|
+
Requires-Dist: matplotlib (>=3.0) ; extra == 'plot'
|
21
|
+
Provides-Extra: rest
|
22
|
+
Requires-Dist: pytest (>=6.2.5) ; extra == 'rest'
|
23
|
+
Requires-Dist: line-profiler (>=3.5) ; extra == 'rest'
|
24
|
+
|
25
|
+
# kevin_toolbox
|
26
|
+
|
27
|
+
一个通用的工具代码包集合
|
28
|
+
|
29
|
+
|
30
|
+
|
31
|
+
环境要求
|
32
|
+
|
33
|
+
```shell
|
34
|
+
numpy>=1.19
|
35
|
+
pytorch>=1.2
|
36
|
+
```
|
37
|
+
|
38
|
+
安装方法:
|
39
|
+
|
40
|
+
```shell
|
41
|
+
pip install kevin-toolbox --no-dependencies
|
42
|
+
```
|
43
|
+
|
44
|
+
|
45
|
+
|
46
|
+
[项目地址 Repo](https://github.com/cantbeblank96/kevin_toolbox)
|
47
|
+
|
48
|
+
[使用指南 User_Guide](./notes/User_Guide.md)
|
49
|
+
|
50
|
+
[免责声明 Disclaimer](./notes/Disclaimer.md)
|
51
|
+
|
52
|
+
[版本更新记录](./notes/Release_Record.md):
|
53
|
+
|
54
|
+
- v 1.3.6 (2024-07-03)【new feature】
|
55
|
+
- patches
|
56
|
+
- for_os
|
57
|
+
- 【new feature】add find_files_in_dir(),找出目录下带有给定后缀的所有文件的生成器。
|
58
|
+
- for_os.path
|
59
|
+
- 【new feature】add find_illegal_chars(),找出给定的文件名/路径中出现了哪些非法符号。
|
60
|
+
- 【new feature】add replace_illegal_chars(),将给定的文件名/路径中的非法符号替换为合法形式。
|
61
|
+
- for_matplotlib
|
62
|
+
- 【new feature】add module common_charts,新增模块——常用图表,该模块下包含以下函数:
|
63
|
+
- plot_bars(),绘制柱状图
|
64
|
+
- plot_scatters(),绘制散点图
|
65
|
+
- plot_lines(),绘制折线图
|
66
|
+
- plot_distribution(),绘制分布图
|
67
|
+
- plot_scatters_matrix(),绘制散点图矩阵(常用于多变量关系分析)
|
68
|
+
- plot_confusion_matrix(),绘制混淆矩阵(常用于混淆矩阵、相关性矩阵、特征图可视化)
|
69
|
+
- 添加了测试用例。
|
70
|
+
- data_flow.file
|
71
|
+
- markdown
|
72
|
+
- 【new feature】add save_images_in_ndl(),将ndl结构叶节点下的图片对象保存到 plot_dir 中,并替换为该图片的markdown链接。
|
73
|
+
- 便于对表格中的图片或者列表中的图片进行保存和替换。
|
74
|
+
- 【new feature】add find_tables(),用于从文本中找出markdown格式的表格,并以二维数组的列表形式返回。
|
75
|
+
- 【new feature】add parse_table(),将二维数组形式的表格(比如find_tables()的返回列表的元素),解析成指定的格式。
|
76
|
+
- kevin_notation
|
77
|
+
- 【bug fix】fix bug in Kevin_Notation_Writer,增加检验写入的列的元素数量是否一致,不一致时进行报错。
|
78
|
+
- 【bug fix】fix bug in write(),避免对输入参数 metadata 中的内容进行意料之外的改动。
|
79
|
+
- nested_dict_list
|
80
|
+
- add para b_allow_override to serializer.write to allow overwriting,增加参数用于允许强制覆盖已有文件。
|
81
|
+
- computer_science.algorithm
|
82
|
+
- pareto_front
|
83
|
+
- 【new feature】add Optimum_Picker,帕累托最优值选取器。
|
84
|
+
- 记录并更新帕累托最优值
|
85
|
+
- 同时支持监控以下行为,并触发设定的执行器,详见参数 trigger_for_new 和 trigger_for_out。
|
86
|
+
- 新加值是一个新的帕累托最优值
|
87
|
+
- 抛弃一个不再是最优的旧的最优值
|
88
|
+
- statistician
|
89
|
+
- 【new feature】add Accumulator_for_Ndl,适用于 ndl 结构的统计器。
|
90
|
+
- 【bug fix】fix bug in Accumulator_Base._init_var()
|
91
|
+
- 【new feature】modify Average_Accumulator,在 add() 中新增了 weight 参数用于计算带权重的平均值
|
92
|
+
- modify Exponential_Moving_Average,add_sequence() 不再支持 weight_ls 参数,让该接口与其他类更加一致。
|
93
|
+
- 添加了测试用例。
|
94
|
+
|
95
|
+
|