kevin-toolbox-dev 1.3.4__py3-none-any.whl → 1.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/computer_science/algorithm/pareto_front/__init__.py +1 -0
  3. kevin_toolbox/computer_science/algorithm/pareto_front/optimum_picker.py +218 -0
  4. kevin_toolbox/computer_science/algorithm/statistician/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/statistician/accumulator_base.py +2 -2
  6. kevin_toolbox/computer_science/algorithm/statistician/accumulator_for_ndl.py +69 -0
  7. kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +16 -4
  8. kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +5 -12
  9. kevin_toolbox/data_flow/file/json_/read_json.py +11 -5
  10. kevin_toolbox/data_flow/file/kevin_notation/kevin_notation_writer.py +4 -1
  11. kevin_toolbox/data_flow/file/kevin_notation/test/test_kevin_notation_debug.py +27 -0
  12. kevin_toolbox/data_flow/file/kevin_notation/write.py +3 -0
  13. kevin_toolbox/data_flow/file/markdown/__init__.py +3 -0
  14. kevin_toolbox/data_flow/file/markdown/find_tables.py +65 -0
  15. kevin_toolbox/data_flow/file/markdown/generate_table.py +19 -5
  16. kevin_toolbox/data_flow/file/markdown/parse_table.py +135 -0
  17. kevin_toolbox/data_flow/file/markdown/save_images_in_ndl.py +81 -0
  18. kevin_toolbox/data_flow/file/markdown/variable.py +17 -0
  19. kevin_toolbox/nested_dict_list/serializer/backends/_ndl.py +4 -1
  20. kevin_toolbox/nested_dict_list/serializer/read.py +18 -14
  21. kevin_toolbox/nested_dict_list/serializer/write.py +23 -7
  22. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +6 -0
  23. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +54 -0
  24. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +60 -0
  25. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +65 -0
  26. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +61 -0
  27. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +53 -0
  28. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +54 -0
  29. kevin_toolbox/patches/for_os/__init__.py +3 -0
  30. kevin_toolbox/patches/for_os/copy.py +33 -0
  31. kevin_toolbox/patches/for_os/find_files_in_dir.py +30 -0
  32. kevin_toolbox/patches/for_os/path/__init__.py +2 -0
  33. kevin_toolbox/patches/for_os/path/find_illegal_chars.py +47 -0
  34. kevin_toolbox/patches/for_os/path/replace_illegal_chars.py +49 -0
  35. kevin_toolbox/patches/for_test/check_consistency.py +104 -33
  36. kevin_toolbox_dev-1.3.6.dist-info/METADATA +95 -0
  37. {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/RECORD +39 -20
  38. kevin_toolbox_dev-1.3.4.dist-info/METADATA +0 -67
  39. {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/WHEEL +0 -0
  40. {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,53 @@
1
+ import os
2
+ import matplotlib.pyplot as plt
3
+ from kevin_toolbox.patches.for_matplotlib import generate_color_list
4
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
5
+
6
+
7
+ def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None, **kwargs):
8
+ paras = {
9
+ "dpi": 200,
10
+ "scatter_size": 5
11
+ }
12
+ paras.update(kwargs)
13
+
14
+ plt.clf()
15
+ #
16
+ color_s = None
17
+ if cate_name is not None:
18
+ cates = list(set(data_s[cate_name]))
19
+ color_s = {i: j for i, j in zip(cates, generate_color_list(nums=len(cates)))}
20
+ c = [color_s[i] for i in data_s[cate_name]]
21
+ else:
22
+ c = "blue"
23
+ # 创建散点图
24
+ plt.scatter(data_s[x_name], data_s[y_name], s=paras["scatter_size"], c=c, alpha=0.8)
25
+ #
26
+ plt.xlabel(f'{x_name}')
27
+ plt.ylabel(f'{y_name}')
28
+ plt.title(f'{title}')
29
+ # 添加图例
30
+ if cate_name is not None:
31
+ plt.legend(handles=[
32
+ plt.Line2D([0], [0], marker='o', color='w', label=i, markerfacecolor=j,
33
+ markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
34
+ ])
35
+
36
+ if output_dir is None:
37
+ plt.show()
38
+ return None
39
+ else:
40
+ os.makedirs(output_dir, exist_ok=True)
41
+ output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
42
+ plt.savefig(output_path, dpi=paras["dpi"])
43
+ return output_path
44
+
45
+
46
+ if __name__ == '__main__':
47
+ data_s_ = dict(
48
+ x=[1, 2, 3, 4, 5],
49
+ y=[2, 4, 6, 8, 10],
50
+ categories=['A', 'B', 'A', 'B', 'A']
51
+ )
52
+
53
+ plot_scatters(data_s=data_s_, title='test', x_name='x', y_name='y', cate_name='categories')
@@ -0,0 +1,54 @@
1
+ import os
2
+ import pandas as pd
3
+ import seaborn as sns
4
+ import matplotlib.pyplot as plt
5
+ from kevin_toolbox.patches.for_matplotlib import generate_color_list
6
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
7
+
8
+
9
+ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=None, cate_color_s=None, **kwargs):
10
+ paras = {
11
+ "dpi": 200,
12
+ "diag_kind": "kde" # 设置对角线图直方图/密度图 {‘hist’, ‘kde’}
13
+ }
14
+ assert cate_name in data_s and len(set(x_name_ls).difference(set(data_s.keys()))) == 0
15
+ if cate_color_s is None:
16
+ temp = set(data_s[cate_name])
17
+ cate_color_s = {k: v for k, v in zip(temp, generate_color_list(len(temp)))}
18
+ assert set(cate_color_s.keys()) == set(data_s[cate_name])
19
+ paras.update(kwargs)
20
+
21
+ plt.clf()
22
+ # 使用seaborn绘制散点图矩阵
23
+ sns.pairplot(
24
+ pd.DataFrame(data_s),
25
+ diag_kind=paras["diag_kind"], # 设置对角线图直方图/密度图 {‘hist’, ‘kde’}
26
+ hue=cate_name, # hue 表示根据该列的值进行分类
27
+ palette=cate_color_s, x_vars=x_name_ls, y_vars=x_name_ls, # x_vars,y_vars 指定子图的排列顺序
28
+ )
29
+ #
30
+ plt.subplots_adjust(top=0.95)
31
+ plt.suptitle(f'{title}', y=0.98, x=0.47)
32
+ # g.fig.suptitle(f'{title}', y=1.05)
33
+
34
+ if output_dir is None:
35
+ plt.show()
36
+ return None
37
+ else:
38
+ os.makedirs(output_dir, exist_ok=True)
39
+ output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
40
+ plt.savefig(output_path, dpi=paras["dpi"])
41
+ return output_path
42
+
43
+
44
+ if __name__ == '__main__':
45
+ data_s_ = dict(
46
+ x=[1, 2, 3, 4, 5],
47
+ y=[2, 4, 6, 8, 10],
48
+ z=[2, 4, 6, 8, 10],
49
+ categories=['A', 'B', 'A', 'B', 'A'],
50
+ title='test',
51
+ )
52
+
53
+ plot_scatters_matrix(data_s=data_s_, title='test', x_name_ls=['y', 'x', 'z'], cate_name='categories',
54
+ cate_color_s={'A': 'red', 'B': 'blue'})
@@ -1,4 +1,7 @@
1
1
  from .remove import remove
2
+ from .copy import copy
2
3
  from .pack import pack
3
4
  from .unpack import unpack
4
5
  from .walk import walk, Path_Ignorer, Ignore_Scope
6
+ from .find_files_in_dir import find_files_in_dir
7
+ from . import path
@@ -0,0 +1,33 @@
1
+ import os
2
+ import shutil
3
+ from kevin_toolbox.patches.for_os import remove
4
+
5
+
6
+ def copy(src, dst, follow_symlinks=True, remove_dst_if_exists=False):
7
+ """
8
+ 复制文件/文件夹/软连接
9
+
10
+ 参数:
11
+ follow_symlinks: <boolean> 是否跟随符号链接
12
+ 对于 src 是软连接或者 src 指向的目录下具有软连接的情况,
13
+ 当设置为 True 时,会复制所有链接指向的实际文件或目录内容。
14
+ 当设置为 False 时,则仅仅软连接本身,而不是指向的内容
15
+ 默认为 True
16
+ remove_dst_if_exists: <boolean> 当目标存在时,是否尝试进行移除
17
+ """
18
+ assert os.path.exists(src), f'failed to copy, src not exists: {src}'
19
+ if remove_dst_if_exists:
20
+ remove(path=dst, ignore_errors=True)
21
+ assert not os.path.exists(dst), f'failed to copy, dst exists: {dst}'
22
+
23
+ os.makedirs(os.path.dirname(dst), exist_ok=True)
24
+ if os.path.isdir(src):
25
+ if not follow_symlinks and os.path.islink(src):
26
+ # 如果是符号链接,并且我们跟随符号链接,则复制链接本身
27
+ os.symlink(os.readlink(src), dst)
28
+ else:
29
+ # 否则,递归复制目录
30
+ shutil.copytree(src, dst, symlinks=not follow_symlinks)
31
+ else:
32
+ # 复制文件
33
+ shutil.copy2(src, dst, follow_symlinks=follow_symlinks)
@@ -0,0 +1,30 @@
1
+ import os
2
+ from kevin_toolbox.patches.for_os import walk
3
+
4
+
5
+ def find_files_in_dir(input_dir, suffix_ls, b_relative_path=True, b_ignore_case=True):
6
+ """
7
+ 找出目录下带有给定后缀的所有文件的生成器
8
+ 主要利用了 for_os.walk 中的过滤规则进行实现
9
+
10
+ 参数:
11
+ suffix_ls: <list/tuple of str> 可选的后缀
12
+ b_relative_path: <bool> 是否返回相对路径
13
+ b_ignore_case: <bool> 是否忽略大小写
14
+ """
15
+ suffix_ls = tuple(set(suffix_ls))
16
+ suffix_ls = tuple(map(lambda x: x.lower(), suffix_ls)) if b_ignore_case else suffix_ls
17
+ for root, dirs, files in walk(top=input_dir, topdown=True,
18
+ ignore_s=[{
19
+ "func": lambda _, b_is_symlink, path: b_is_symlink or not (
20
+ path.lower() if b_ignore_case else path).endswith(suffix_ls),
21
+ "scope": ["files", ]
22
+ }]):
23
+ for file in files:
24
+ file_path = os.path.join(root, file)
25
+ if b_relative_path:
26
+ file_path = os.path.relpath(file_path, start=input_dir)
27
+ yield file_path
28
+
29
+
30
+
@@ -0,0 +1,2 @@
1
+ from .find_illegal_chars import find_illegal_chars
2
+ from .replace_illegal_chars import replace_illegal_chars
@@ -0,0 +1,47 @@
1
+ import os
2
+ import re
3
+
4
+ ILLEGAL_CHARS = {
5
+ '<': '<',
6
+ '>': '>',
7
+ ':': ':',
8
+ '"': '"',
9
+ '/': '/',
10
+ '\\': '\',
11
+ '|': '|',
12
+ '?': '?',
13
+ '*': '*'
14
+ }
15
+
16
+ ILLEGAL_CHARS_PATTERN = re.compile(r'[<>:"/\\|?*]')
17
+
18
+
19
+ def find_illegal_chars(file_name, b_is_path=False):
20
+ """
21
+ 找出给定的文件名/路径中出现了哪些非法符号
22
+ 所谓非法符号是指在特定系统(win/mac等)中文件名不允许出现的字符,
23
+ 不建议在任何系统下使用带有这些符号的文件名,即使这些符号在当前系统中是合法的,
24
+ 以免在跨系统时出现兼容性问题
25
+
26
+ 参数:
27
+ file_name: <str>
28
+ b_is_path: <bool> 是否将file_name视为路径
29
+ 默认为 False
30
+ 当设置为 True 时,将会首先将file_name分割,再逐级从目录名中查找
31
+ """
32
+ global ILLEGAL_CHARS_PATTERN
33
+
34
+ if b_is_path:
35
+ temp = [i for i in file_name.split(os.sep, -1) if len(i) > 0]
36
+ else:
37
+ temp = [file_name]
38
+ res = []
39
+ for i in temp:
40
+ res.extend(ILLEGAL_CHARS_PATTERN.findall(i))
41
+ return res
42
+
43
+
44
+ if __name__ == '__main__':
45
+ file_path = '//data0//b/<?>.md'
46
+ print(find_illegal_chars(file_name=file_path, b_is_path=True))
47
+ print(find_illegal_chars(file_name=file_path, b_is_path=False))
@@ -0,0 +1,49 @@
1
+ import os
2
+ import re
3
+
4
+ ILLEGAL_CHARS = {
5
+ '<': '<',
6
+ '>': '>',
7
+ ':': ':',
8
+ '"': '"',
9
+ '/': '/',
10
+ '\\': '\',
11
+ '|': '|',
12
+ '?': '?',
13
+ '*': '*'
14
+ }
15
+ ILLEGAL_CHARS_PATTERN = re.compile('|'.join(re.escape(char) for char in ILLEGAL_CHARS.keys()))
16
+
17
+
18
+ def replace_illegal_chars(file_name, b_is_path=False):
19
+ """
20
+ 将给定的文件名/路径中的非法符号替换为合法形式
21
+ 所谓非法符号是指在特定系统(win/mac等)中文件名不允许出现的字符,
22
+ 不建议在任何系统下使用带有这些符号的文件名,即使这些符号在当前系统中是合法的,
23
+ 以免在跨系统时出现兼容性问题。
24
+
25
+ 参数:
26
+ file_name: <str>
27
+ b_is_path: <bool> 是否将file_name视为路径
28
+ 默认为 False
29
+ 当设置为 True 时,将会首先将file_name分割,再逐级处理目录名,最后合并为路径
30
+ """
31
+ if not b_is_path:
32
+ res = _replace_illegal_chars(var=file_name)
33
+ else:
34
+ temp = file_name.split(os.sep, -1)
35
+ res = os.path.join(*[_replace_illegal_chars(var=i) for i in temp if len(i) > 0])
36
+ if len(temp[0]) == 0:
37
+ res = os.sep + res
38
+ return res
39
+
40
+
41
+ def _replace_illegal_chars(var):
42
+ global ILLEGAL_CHARS_PATTERN
43
+ return ILLEGAL_CHARS_PATTERN.sub(lambda m: ILLEGAL_CHARS[m.group(0)], var)
44
+
45
+
46
+ if __name__ == '__main__':
47
+ file_path = 'data0//b/<?>.md'
48
+ print(replace_illegal_chars(file_name=file_path, b_is_path=True))
49
+ print(replace_illegal_chars(file_name=file_path, b_is_path=False))
@@ -7,7 +7,8 @@ import kevin_toolbox.nested_dict_list as ndl
7
7
  def check_consistency(*args, tolerance=1e-7, require_same_shape=True):
8
8
  """
9
9
  检查 args 中多个变量之间是否一致
10
- 变量支持python的所有内置类型,以及复杂的 nested_dict_list 结构, array 等
10
+ 变量支持 python 的所有内置类型,以及复杂的 nested_dict_list 结构, array 等
11
+ 对于 array,不区分 numpy 的 array,torch 的 tensor,还是 tuple of number,只要其中的值相等,即视为相同。
11
12
 
12
13
  参数:
13
14
  tolerance: <float> 判断 <np.number/np.bool_> 之间是否一致时,的容许误差。
@@ -28,13 +29,13 @@ def check_consistency(*args, tolerance=1e-7, require_same_shape=True):
28
29
  try:
29
30
  _check_item(*names_ls, tolerance=tolerance, require_same_shape=True)
30
31
  except AssertionError as e:
31
- assert False, f'inputs <nested_dict_list> has different structure\nthe nodes that differ are:\n{e}'
32
+ raise AssertionError(f'inputs <nested_dict_list> has different structure\nthe nodes that differ are:\n{e}')
32
33
  for its in zip(names_ls[0], *values_ls):
33
34
  try:
34
35
  _check_item(*its[1:], tolerance=tolerance, require_same_shape=require_same_shape)
35
36
  except AssertionError as e:
36
- assert False, \
37
- f'value of nodes {its[0]} in inputs <nested_dict_list> are inconsistent\nthe difference is:\n{e}'
37
+ raise AssertionError(
38
+ f'value of nodes {its[0]} in inputs <nested_dict_list> are inconsistent\nthe difference is:\n{e}')
38
39
  # 简单结构
39
40
  else:
40
41
  _check_item(*args, tolerance=tolerance, require_same_shape=require_same_shape)
@@ -45,44 +46,111 @@ def _check_item(*args, tolerance, require_same_shape):
45
46
  检查 args 中多个 array 之间是否一致
46
47
 
47
48
  工作流程:
48
- 本函数会首先将输入的 args 中的所有变量转换为 np.array;
49
- 然后使用 issubclass() 判断转换后得到的变量属于以下哪几种基本类型:
50
- - 当所有变量都属于 np.number 数值(包含int、float等)或者 np.bool_ 布尔值时,
51
- 将对变量两两求差,当差值小于给定的容许误差 tolerance 时,视为一致。
52
- - 当所有变量都属于 np.flexible 可变长度类型(包含string等)或者 np.object 时,
53
- 将使用==进行比较,当返回值都为 True 时,视为一致。
54
- - 当变量的基本类型不一致(比如同时有np.number和np.flexible)时,
55
- 直接判断为不一致。
56
- numpy 中基本类型之间的继承关系参见: https://numpy.org.cn/reference/arrays/scalars.html
49
+ 1. 对于 args 都是 tuple/list 且每个 tuple/list 的长度都一致的情况,将会拆分为对应各个元素递归进行比较。
50
+ 2. args 中的 tuple 和 tensor 分别转换为 list 和 numpy。
51
+ 3. 检查 args 中是否有 np.array 或者 tensor,若有则根据 require_same_shape 判断其形状是否一致。
52
+ 4. 先将输入的 args 中的所有变量转换为 np.array;
53
+ 然后使用 issubclass() 判断转换后得到的变量属于以下哪几种基本类型:
54
+ - 当所有变量都属于 np.number 数值(包含int、float等)或者 np.bool_ 布尔值时,
55
+ 将对变量两两求差,当差值小于给定的容许误差 tolerance 时,视为一致。
56
+ 注意:在比较过程中,若变量中存在 np.nan 值,将会首先比较有 np.nan 值的位置是否相等,然后再比较非 np.nan 值部分。
57
+ 亦即在相同位置上都具有 np.nan 视为相同。
58
+ 比如 [np.nan, np.nan, 3] 和 [np.nan, np.nan, 3] 会视为相等,
59
+ [np.nan, np.nan] 和 np.nan 在 require_same_shape=False 时会被视为相等。
60
+ - 当所有变量都属于 np.flexible 可变长度类型(包含string等)或者 np.object 时,
61
+ 将使用==进行比较,当返回值都为 True 时,视为一致。
62
+ - 当变量的基本类型不一致(比如同时有np.number和np.flexible)时,
63
+ 直接判断为不一致。
64
+ numpy 中基本类型之间的继承关系参见: https://numpy.org.cn/reference/arrays/scalars.html
57
65
 
58
66
  参数:
59
67
  tolerance: <float> 判断 <np.number/np.bool_> 之间是否一致时,的容许误差。
60
68
  require_same_shape: <boolean> 是否强制要求变量的形状一致。
61
- 当设置为 False 时,不同形状的变量可能因为 numpy broadcast 机制而在比较前自动 reshape 为相同维度,进而可能通过比较。
69
+ 注意:仅在原始 args 中含有 np.array 或者 tensor 的情况会采取 broadcast,亦即此时该参数才会起效。
70
+ 当设置为 False 时,不同形状的变量可能因为 broadcast 机制而在比较前自动 reshape 为相同维度,进而通过比较。
62
71
  """
63
- np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
72
+ warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
64
73
  warnings.filterwarnings("ignore", category=DeprecationWarning)
65
74
 
66
75
  assert len(args) >= 2
67
- assert isinstance(tolerance, (int, float,))
76
+ assert isinstance(tolerance, (int, float,)) and tolerance >= 0
77
+ raw_args = args
78
+
79
+ # 当 args 都是 tuple/list 且每个 tuple 的长度都一致时,将会拆分为对应各个元素进行比较
80
+ if all([isinstance(i, (tuple, list)) for i in args]) and all([len(i) == len(args[0]) for i in args[1:]]):
81
+ for i, it_ls in enumerate(zip(*args)):
82
+ try:
83
+ _check_item(*it_ls, tolerance=tolerance, require_same_shape=require_same_shape)
84
+ except AssertionError as e:
85
+ raise AssertionError(f'elements {i} are inconsistent\nthe difference is:\n{e}')
86
+ return
87
+
88
+ #
89
+ args = ndl.traverse(var=list(args), match_cond=lambda _, __, v: isinstance(v, (tuple,)), action_mode="replace",
90
+ converter=lambda _, v: list(v), b_traverse_matched_element=True)
91
+ args = ndl.traverse(var=list(args), match_cond=lambda _, __, v: torch.is_tensor(v), action_mode="replace",
92
+ converter=lambda _, v: v.detach().cpu().numpy(), b_traverse_matched_element=False)
93
+ b_has_raw_array = any([isinstance(i, np.ndarray) for i in args])
68
94
 
69
- args = [v.detach().cpu() if torch.is_tensor(v) else v for v in args]
70
- args = [np.asarray(v) for v in args]
95
+ try:
96
+ args = [np.asarray(v) for v in args] # if b_has_raw_array else [np.array(v, dtype=object) for v in args]
97
+ except Exception as e:
98
+ raise RuntimeError(f'{raw_args} cannot be converted to np.array, \n'
99
+ f'because {e}')
71
100
 
72
- for v in args[1:]:
101
+ # 比较形状
102
+ if b_has_raw_array:
73
103
  if require_same_shape:
74
- assert args[0].shape == v.shape, \
75
- f"{args[0]}, {v}, different shape: {args[0].shape}, {v.shape}"
76
- if issubclass(args[0].dtype.type, (np.number, np.bool_,)):
77
- # 数字类型
104
+ # 要求形状一致
105
+ for v in args[1:]:
106
+ assert args[0].shape == v.shape, \
107
+ f"{args[0]}, {v}, different shape: {args[0].shape}, {v.shape}"
108
+ else:
109
+ # 否则要求至少能够进行 broadcast
110
+ for v in args[1:]:
111
+ try:
112
+ np.broadcast_arrays(args[0], v)
113
+ except:
114
+ raise AssertionError(f'{args[0]}, {v}, failed to broadcast')
115
+ # 如果都是空的 array,直接视为相等
116
+ if all([i.size == 0 for i in args]):
117
+ return
118
+ b_allow_broadcast = b_has_raw_array and not require_same_shape
119
+
120
+ # 比较值
121
+ if issubclass(args[0].dtype.type, (np.number, np.bool_,)):
122
+ # 数字类型
123
+ for v in args[1:]:
78
124
  assert issubclass(v.dtype.type, (np.number, np.bool_,))
79
- if args[0].size > 0:
80
- assert np.max(np.abs(args[0] - v.astype(dtype=float))) < tolerance, \
81
- f"{args[0]}, {v}, deviation: {np.max(np.abs(args[0] - v))}"
82
- elif issubclass(args[0].dtype.type, (np.flexible, object,)):
83
- # 可变长度类型
125
+ v_0, v_1 = args[0].astype(dtype=float), v.astype(dtype=float)
126
+ v_0, v_1 = np.broadcast_arrays(v_0, v_1) if b_allow_broadcast else (v_0, v_1)
127
+ assert v_0.shape == v_1.shape, \
128
+ f'{v_0}, {v_1}, different shape: {v_0.shape}, {v_1.shape}'
129
+ #
130
+ if v_0.size > 0:
131
+ try:
132
+ if np.any(np.isnan(v)):
133
+ assert np.all(np.isnan(v_0) == np.isnan(v_1))
134
+ v_0 = np.nan_to_num(v_0, nan=1e10)
135
+ v_1 = np.nan_to_num(v_1, nan=1e10)
136
+ assert np.max(np.abs(v_0 - v_1)) < tolerance
137
+ except AssertionError:
138
+ raise AssertionError(f"{args[0]}, {v}, deviation: {np.max(np.abs(args[0] - v))}")
139
+ elif issubclass(args[0].dtype.type, (np.flexible, object,)):
140
+ # 可变长度类型
141
+ for v in args[1:]:
84
142
  assert issubclass(v.dtype.type, (np.flexible, object,))
85
- for i, j in zip(args[0].reshape(-1), v.reshape(-1)):
143
+ v_0, v_1 = np.broadcast_arrays(args[0], v) if b_allow_broadcast else (args[0], v)
144
+ assert v_0.shape == v_1.shape, \
145
+ f'{v_0}, {v_1}, different shape: {v_0.shape}, {v_1.shape}\n' + (
146
+ '' if require_same_shape else
147
+ f'\tMore details: \n'
148
+ f'\t\tAlthough require_same_shape=False has been setted, broadcast failed because the variable at \n'
149
+ f'\t\tthis position does not contain elements of type np.array and tensor.')
150
+ #
151
+ for i, j in zip(v_0.reshape(-1), v_1.reshape(-1)):
152
+ if i is j:
153
+ continue
86
154
  temp = i == j
87
155
  if isinstance(temp, (bool,)):
88
156
  assert temp, \
@@ -90,14 +158,17 @@ def _check_item(*args, tolerance, require_same_shape):
90
158
  else:
91
159
  assert temp.all(), \
92
160
  f"{args[0]}, {v}, diff: {temp}"
93
- else:
94
- raise ValueError
161
+ else:
162
+ raise ValueError
95
163
 
96
164
 
97
165
  if __name__ == '__main__':
98
166
  a = np.array([[1, 2, 3]])
99
167
  b = np.array([[1, 2, 3]])
100
168
  c = {'d': 3, 'c': 4}
101
- check_consistency([c, a], [c, b])
102
169
 
103
- check_consistency(True, True)
170
+ # var = ((1, 2), (4))
171
+ #
172
+ # var = ndl.traverse(var=[var], match_cond=lambda _, __, v: isinstance(v, (tuple,)), action_mode="replace",
173
+ # converter=lambda _, v: list(v), b_traverse_matched_element=True)[0]
174
+ # print(var)
@@ -0,0 +1,95 @@
1
+ Metadata-Version: 2.1
2
+ Name: kevin-toolbox-dev
3
+ Version: 1.3.6
4
+ Summary: 一个常用的工具代码包集合
5
+ Home-page: https://github.com/cantbeblank96/kevin_toolbox
6
+ Download-URL: https://github.com/username/your-package/archive/refs/tags/v1.0.0.tar.gz
7
+ Author: kevin hsu
8
+ Author-email: xukaiming1996@163.com
9
+ License: MIT
10
+ Keywords: mathematics,pytorch,numpy,machine-learning,algorithm
11
+ Platform: UNKNOWN
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python
14
+ Classifier: Programming Language :: Python :: 3
15
+ Requires-Python: >=3.6
16
+ Description-Content-Type: text/markdown
17
+ Requires-Dist: torch (>=1.2.0)
18
+ Requires-Dist: numpy (>=1.19.0)
19
+ Provides-Extra: plot
20
+ Requires-Dist: matplotlib (>=3.0) ; extra == 'plot'
21
+ Provides-Extra: rest
22
+ Requires-Dist: pytest (>=6.2.5) ; extra == 'rest'
23
+ Requires-Dist: line-profiler (>=3.5) ; extra == 'rest'
24
+
25
+ # kevin_toolbox
26
+
27
+ 一个通用的工具代码包集合
28
+
29
+
30
+
31
+ 环境要求
32
+
33
+ ```shell
34
+ numpy>=1.19
35
+ pytorch>=1.2
36
+ ```
37
+
38
+ 安装方法:
39
+
40
+ ```shell
41
+ pip install kevin-toolbox --no-dependencies
42
+ ```
43
+
44
+
45
+
46
+ [项目地址 Repo](https://github.com/cantbeblank96/kevin_toolbox)
47
+
48
+ [使用指南 User_Guide](./notes/User_Guide.md)
49
+
50
+ [免责声明 Disclaimer](./notes/Disclaimer.md)
51
+
52
+ [版本更新记录](./notes/Release_Record.md):
53
+
54
+ - v 1.3.6 (2024-07-03)【new feature】
55
+ - patches
56
+ - for_os
57
+ - 【new feature】add find_files_in_dir(),找出目录下带有给定后缀的所有文件的生成器。
58
+ - for_os.path
59
+ - 【new feature】add find_illegal_chars(),找出给定的文件名/路径中出现了哪些非法符号。
60
+ - 【new feature】add replace_illegal_chars(),将给定的文件名/路径中的非法符号替换为合法形式。
61
+ - for_matplotlib
62
+ - 【new feature】add module common_charts,新增模块——常用图表,该模块下包含以下函数:
63
+ - plot_bars(),绘制柱状图
64
+ - plot_scatters(),绘制散点图
65
+ - plot_lines(),绘制折线图
66
+ - plot_distribution(),绘制分布图
67
+ - plot_scatters_matrix(),绘制散点图矩阵(常用于多变量关系分析)
68
+ - plot_confusion_matrix(),绘制混淆矩阵(常用于混淆矩阵、相关性矩阵、特征图可视化)
69
+ - 添加了测试用例。
70
+ - data_flow.file
71
+ - markdown
72
+ - 【new feature】add save_images_in_ndl(),将ndl结构叶节点下的图片对象保存到 plot_dir 中,并替换为该图片的markdown链接。
73
+ - 便于对表格中的图片或者列表中的图片进行保存和替换。
74
+ - 【new feature】add find_tables(),用于从文本中找出markdown格式的表格,并以二维数组的列表形式返回。
75
+ - 【new feature】add parse_table(),将二维数组形式的表格(比如find_tables()的返回列表的元素),解析成指定的格式。
76
+ - kevin_notation
77
+ - 【bug fix】fix bug in Kevin_Notation_Writer,增加检验写入的列的元素数量是否一致,不一致时进行报错。
78
+ - 【bug fix】fix bug in write(),避免对输入参数 metadata 中的内容进行意料之外的改动。
79
+ - nested_dict_list
80
+ - add para b_allow_override to serializer.write to allow overwriting,增加参数用于允许强制覆盖已有文件。
81
+ - computer_science.algorithm
82
+ - pareto_front
83
+ - 【new feature】add Optimum_Picker,帕累托最优值选取器。
84
+ - 记录并更新帕累托最优值
85
+ - 同时支持监控以下行为,并触发设定的执行器,详见参数 trigger_for_new 和 trigger_for_out。
86
+ - 新加值是一个新的帕累托最优值
87
+ - 抛弃一个不再是最优的旧的最优值
88
+ - statistician
89
+ - 【new feature】add Accumulator_for_Ndl,适用于 ndl 结构的统计器。
90
+ - 【bug fix】fix bug in Accumulator_Base._init_var()
91
+ - 【new feature】modify Average_Accumulator,在 add() 中新增了 weight 参数用于计算带权重的平均值
92
+ - modify Exponential_Moving_Average,add_sequence() 不再支持 weight_ls 参数,让该接口与其他类更加一致。
93
+ - 添加了测试用例。
94
+
95
+