kevin-toolbox-dev 1.3.5__py3-none-any.whl → 1.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/computer_science/algorithm/pareto_front/__init__.py +1 -0
  3. kevin_toolbox/computer_science/algorithm/pareto_front/optimum_picker.py +218 -0
  4. kevin_toolbox/computer_science/algorithm/statistician/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/statistician/accumulator_base.py +2 -2
  6. kevin_toolbox/computer_science/algorithm/statistician/accumulator_for_ndl.py +69 -0
  7. kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +16 -4
  8. kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +5 -12
  9. kevin_toolbox/data_flow/file/kevin_notation/kevin_notation_writer.py +3 -0
  10. kevin_toolbox/data_flow/file/kevin_notation/test/test_kevin_notation_debug.py +27 -0
  11. kevin_toolbox/data_flow/file/kevin_notation/write.py +3 -0
  12. kevin_toolbox/data_flow/file/markdown/__init__.py +3 -0
  13. kevin_toolbox/data_flow/file/markdown/find_tables.py +65 -0
  14. kevin_toolbox/data_flow/file/markdown/generate_table.py +19 -5
  15. kevin_toolbox/data_flow/file/markdown/parse_table.py +135 -0
  16. kevin_toolbox/data_flow/file/markdown/save_images_in_ndl.py +81 -0
  17. kevin_toolbox/data_flow/file/markdown/variable.py +17 -0
  18. kevin_toolbox/nested_dict_list/serializer/read.py +1 -1
  19. kevin_toolbox/nested_dict_list/serializer/write.py +10 -2
  20. kevin_toolbox/patches/for_matplotlib/__init__.py +1 -1
  21. kevin_toolbox/patches/for_matplotlib/clear_border_of_axes.py +34 -0
  22. kevin_toolbox/patches/for_matplotlib/color/__init__.py +4 -0
  23. kevin_toolbox/patches/for_matplotlib/color/color_format.py +7 -0
  24. kevin_toolbox/patches/for_matplotlib/color/convert_format.py +108 -0
  25. kevin_toolbox/patches/for_matplotlib/color/generate_color_list.py +50 -0
  26. kevin_toolbox/patches/for_matplotlib/color/get_format.py +12 -0
  27. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +6 -0
  28. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +54 -0
  29. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +60 -0
  30. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +65 -0
  31. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +61 -0
  32. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +53 -0
  33. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +54 -0
  34. kevin_toolbox/patches/for_optuna/build_study.py +1 -1
  35. kevin_toolbox/patches/for_os/__init__.py +2 -0
  36. kevin_toolbox/patches/for_os/find_files_in_dir.py +30 -0
  37. kevin_toolbox/patches/for_os/path/__init__.py +2 -0
  38. kevin_toolbox/patches/for_os/path/find_illegal_chars.py +47 -0
  39. kevin_toolbox/patches/for_os/path/replace_illegal_chars.py +49 -0
  40. kevin_toolbox_dev-1.3.7.dist-info/METADATA +103 -0
  41. {kevin_toolbox_dev-1.3.5.dist-info → kevin_toolbox_dev-1.3.7.dist-info}/RECORD +43 -20
  42. kevin_toolbox/patches/for_matplotlib/generate_color_list.py +0 -33
  43. kevin_toolbox_dev-1.3.5.dist-info/METADATA +0 -74
  44. {kevin_toolbox_dev-1.3.5.dist-info → kevin_toolbox_dev-1.3.7.dist-info}/WHEEL +0 -0
  45. {kevin_toolbox_dev-1.3.5.dist-info → kevin_toolbox_dev-1.3.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,135 @@
1
+ import re
2
+ from typing import Union
3
+ from kevin_toolbox.data_flow.file.markdown.variable import Table_Format
4
+
5
+
6
def parse_table(raw_table, output_format: Union[Table_Format, str] = Table_Format.COMPLETE_DICT, orientation="vertical",
                chunk_size=None, chunk_nums=None, b_remove_empty_lines=False, f_gen_order_of_values=None):
    """
    Parse a table given as a 2-D array (e.g. an element of the list returned by find_tables()) into the requested format.

    Parameters:
        raw_table:              <list of list> the table as a 2-D array (a list of rows of cells).
        output_format:          <Table_Format or str> target format, see Table_Format.
                                    "simple_dict"   -> {<title>: <list of values>, ...}
                                    "complete_dict" -> {<index>: {"title": <title>, "values": <list of values>}, ...}
        orientation:            <str> direction in which the table is read.
                                    "vertical" / "v":  the first row holds the titles.
                                    any other value (e.g. "horizontal" / "h"): the table is transposed first,
                                    so that the first column holds the titles.
        chunk_nums:             <int> number of equal chunks the table was split into and laid out side by side.
        chunk_size:             <int> maximum number of value rows per chunk.
                                    These two parameters undo the corresponding options of generate_table().
                                    chunk_size is only used to check the number of value rows. When it is given,
                                    the table is assumed to possibly consist of several tables placed side by side,
                                    and chunk_nums is inferred from the repetition pattern of the title row
                                    (overriding any chunk_nums passed in).
        b_remove_empty_lines:   <boolean> remove rows whose cells are all '' and columns whose title is ''
                                    or whose values are all ''.
        f_gen_order_of_values:  <callable> receives one row as {<title>: <value>, ...} and returns a sort key;
                                    rows are reordered by that key. See the parameter of the same name in
                                    generate_table().

    Returns:
        dict in the layout selected by output_format.

    Raises:
        AssertionError: when the table shape contradicts chunk_size / chunk_nums, or when titles are duplicated
                        where they must be unique (sorting with f_gen_order_of_values, or "simple_dict" output).
        ValueError:     when output_format is not a valid Table_Format value.
    """
    assert isinstance(raw_table, (list, tuple,))
    # Accept both Table_Format members and their string values (e.g. "simple_dict").
    output_format = Table_Format(output_format)

    # Bring the table into vertical orientation, i.e. titles in the first row.
    if orientation not in ["vertical", "v"]:
        raw_table = list(zip(*raw_table))
    r_nums, c_nums = len(raw_table), len(raw_table[0])
    if chunk_size is not None:
        assert chunk_size == r_nums - 1, \
            (f'The number of values {r_nums - 1} actually contained in the table '
             f'does not match the specified chunk_size {chunk_size}')
        # Side-by-side chunks repeat the same titles, so the period of the header row gives the chunk width.
        chunk_nums = c_nums // _find_shortest_repeating_pattern_size(arr=raw_table[0])
    chunk_nums = 1 if chunk_nums is None else chunk_nums
    assert c_nums % chunk_nums == 0, \
        f'The number of headers actually contained in the table does not match the specified chunk_nums, ' \
        f'Expected n*{chunk_nums}, but got {c_nums}'

    # Titles are taken from the first chunk of the header row.
    width = c_nums // chunk_nums
    keys = raw_table[0][0:width]
    # Values: the chunks are stacked on top of each other, chunk by chunk.
    values = [row[i * width:(i + 1) * width] for i in range(chunk_nums) for row in raw_table[1:]]

    # Drop empty rows.
    if b_remove_empty_lines:
        values = [line for line in values if any(i != '' for i in line)]
    # Transpose rows into columns. Without any value rows every column is still kept, with an empty value list,
    # so that the titles are not lost.
    columns = list(zip(*values)) if values else [()] * len(keys)
    table_s = {i: {"title": k, "values": list(v)} for i, (k, v) in enumerate(zip(keys, columns))}
    # Drop columns that have no title or no non-empty value.
    if b_remove_empty_lines:
        table_s = {k: v_s for k, v_s in table_s.items() if v_s["title"] != '' and any(i != '' for i in v_s["values"])}

    # Reorder the rows.
    if callable(f_gen_order_of_values):
        titles = [v_s["title"] for v_s in table_s.values()]
        assert len(set(titles)) == len(titles), \
            f'table has duplicate titles, thus cannot be sorted using f_gen_order_of_values'
        idx_ls = sorted(
            range(len(values)),
            key=lambda x: f_gen_order_of_values({v_s["title"]: v_s["values"][x] for v_s in table_s.values()})
        )
        for v_s in table_s.values():
            v_s["values"] = [v_s["values"][i] for i in idx_ls]

    if output_format is Table_Format.SIMPLE_DICT:
        # This must be a list: a set would already have merged the duplicates this check is looking for.
        titles = [v_s["title"] for v_s in table_s.values()]
        if len(titles) != len(set(titles)):
            raise AssertionError(
                f'There are columns with the same title in the table, '
                f'please check the orientation of the table or use output_format="complete_dict"')
        table_s = {v_s["title"]: v_s["values"] for v_s in table_s.values()}

    return table_s
83
+
84
+
85
+ def _find_shortest_repeating_pattern_size(arr):
86
+ n = len(arr)
87
+
88
+ # 部分匹配表
89
+ pi = [0] * n
90
+ k = 0
91
+ for i in range(1, n):
92
+ if k > 0 and arr[k] != arr[i]:
93
+ k = 0
94
+ if arr[k] == arr[i]:
95
+ k += 1
96
+ pi[i] = k
97
+
98
+ # 最短重复模式的长度
99
+ pattern_length = n - pi[n - 1]
100
+ # 是否是完整的重复模式
101
+ if n % pattern_length != 0:
102
+ pattern_length = n
103
+ return pattern_length
104
+
105
+
106
if __name__ == '__main__':
    from kevin_toolbox.data_flow.file.markdown import find_tables

    # A table laid out as three side-by-side chunks, as generate_table() produces with chunk_nums=3
    markdown_text = """
| | a | b | | a | b | | a | b |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| | 0 | 2 | | 4 | 6 | | 7 | 9 |
| | 1 | 3 | | 5 | 7 | | 8 | : |
| | 2 | 4 | | 6 | 8 | | 9 | ; |
| | 3 | 5 | | | | | | |
"""
    found = find_tables(text=markdown_text)

    # Parse the first table found and show the result
    print(parse_table(raw_table=found[0], output_format="complete_dict", chunk_nums=3, b_remove_empty_lines=True))
@@ -0,0 +1,81 @@
1
+ import os
2
+ import warnings
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ from collections import defaultdict
8
+ from kevin_toolbox.data_flow.file import markdown
9
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars, find_illegal_chars
10
+ import kevin_toolbox.nested_dict_list as ndl
11
+
12
+
13
def save_images_in_ndl(var, plot_dir, doc_dir=None, setting_s=None):
    """
    Save the image objects found at the leaf nodes of a nested dict/list (ndl) into plot_dir and
    replace each of them, in place, with a markdown link to the saved file.

    Parameters:
        var:        <dict> the ndl structure to process (modified in place and also returned).
        plot_dir:   <path> directory in which the images are saved.
        doc_dir:    <path> directory where the output markdown document will be saved.
                        When given, image links are written relative to doc_dir.
                        Default None, in which case the link is the image path built from plot_dir
                        (absolute only if plot_dir is absolute).
        setting_s:  <dict> which nodes to search for images, and how to convert them.
                        Form: {<node name>: {"b_is_rgb": <boolean>, ...}, ...}
                        Supported options:
                            - b_is_rgb:             whether the image channels are in RGB (True) or BGR (False) order.
                            - saved_image_format:   suffix used for the saved file, e.g. ".jpg".
                        Options missing from an entry fall back to b_is_rgb=False, saved_image_format=".jpg".
                        Default None, equivalent to {"": {"b_is_rgb": False, "saved_image_format": ".jpg"}}.
                        When several node names cover the same leaf, the setting of the deepest node name wins.

    Leaf handling:
        - PIL.Image.Image objects and torch tensors are converted to numpy arrays first;
        - numpy arrays are written with cv2.imwrite and replaced by a markdown image link;
        - None is replaced by "/";
        - every other value is written back unchanged.

    Returns:
        var
    """
    if len(find_illegal_chars(file_name=plot_dir, b_is_path=True)) > 0:
        warnings.warn(f'plot_dir {plot_dir} contains illegal symbols, '
                      f'which may cause compatibility issues on certain systems.', UserWarning)
    default_setting = {"b_is_rgb": False, "saved_image_format": ".jpg"}
    setting_s = setting_s or {"": dict(default_setting)}

    # Assign every leaf to a setting. Node names are processed from shallow to deep,
    # so a deeper node name overrides the setting of a shallower one for the leaves it covers.
    root_ls = sorted(setting_s.keys(), key=lambda x: len(ndl.name_handler.parse_name(name=x)[-1]))
    root_to_leaf_s = defaultdict(set)
    leaf_to_root_s = dict()
    leaf_to_value_s = dict()
    for root in root_ls:
        for leaf, v in ndl.get_nodes(var=ndl.get_value(var=var, name=root, b_pop=False), level=-1, b_strict=True):
            leaf = root + leaf
            if leaf in leaf_to_root_s:
                root_to_leaf_s[leaf_to_root_s[leaf]].remove(leaf)
            root_to_leaf_s[root].add(leaf)
            leaf_to_root_s[leaf] = root
            leaf_to_value_s[leaf] = v

    for root, leaf_ls in root_to_leaf_s.items():
        # Fill in options the user left out instead of failing with a KeyError.
        setting_ = {**default_setting, **setting_s[root]}
        for leaf in leaf_ls:
            v = leaf_to_value_s[leaf]
            if isinstance(v, Image.Image):
                v = np.asarray(v)
            elif torch.is_tensor(v):
                v = v.detach().cpu().numpy()
            #
            if isinstance(v, np.ndarray):
                image_path = os.path.join(
                    plot_dir, replace_illegal_chars(
                        file_name=f'{leaf}_{setting_["saved_image_format"]}', b_is_path=False)
                )
                os.makedirs(os.path.dirname(image_path), exist_ok=True)
                if setting_["b_is_rgb"]:
                    v = _convert_rgb_to_bgr(v)
                # cv2.imwrite reports failure through its return value rather than by raising.
                if not cv2.imwrite(image_path, v):
                    warnings.warn(f'failed to write the image of node {leaf} to {image_path}', UserWarning)
                v_new = markdown.generate_link(
                    name=os.path.basename(image_path),
                    target=os.path.relpath(image_path, doc_dir) if doc_dir is not None else image_path, type_="image")
            elif v is None:
                v_new = "/"
            else:
                v_new = v
            ndl.set_value(var=var, name=leaf, b_force=False, value=v_new)
    return var


def _convert_rgb_to_bgr(image):
    """Reorder an RGB / RGBA image array into the BGR / BGRA order cv2 expects; arrays without 3 or 4 channels are returned unchanged."""
    if image.ndim == 3 and image.shape[-1] == 3:
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if image.ndim == 3 and image.shape[-1] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)
    # grayscale / single-channel images have no channel order to swap
    return image
79
+
80
+
81
+
@@ -0,0 +1,17 @@
1
+ from enum import Enum
2
+
3
+
4
class Table_Format(Enum):
    """
    Layouts in which a markdown table can be represented as a dict.

    SIMPLE_DICT ("simple_dict"):
        content_s = {<title>: <list of values>, ...}
        Each key is a column title; its value is the list of entries under that title.
        Column order simply follows the dict's insertion order, and titles must be unique.
        Use COMPLETE_DICT when explicit column indices (or gaps) are needed.

    COMPLETE_DICT ("complete_dict"):
        content_s = {<index>: {"title": <title>, "values": <list of values>}, ...}
        The entry stored under <index> gives the title and the values of the <index>-th column.
        Some <index> may be left out; the corresponding rows/columns are then left empty.
    """
    SIMPLE_DICT = "simple_dict"  # {<title>: <list of values>}
    COMPLETE_DICT = "complete_dict"  # {<index>: {"title": ..., "values": [...]}}
@@ -18,7 +18,7 @@ def read(input_path, **kwargs):
18
18
  assert os.path.exists(input_path)
19
19
 
20
20
  with tempfile.TemporaryDirectory(dir=os.path.dirname(input_path)) as temp_dir:
21
- if os.path.isfile(input_path) and input_path.endswith(".tar"): # 解压
21
+ if os.path.isfile(input_path) and input_path.endswith(".tar"): # 解压
22
22
  for_os.unpack(source=input_path, target=temp_dir)
23
23
  input_path = os.path.join(temp_dir, os.listdir(temp_dir)[0])
24
24
  var = _read_unpacked_ndl(input_path, **kwargs)
@@ -13,7 +13,7 @@ from .saved_node_name_builder import Saved_Node_Name_Builder
13
13
 
14
14
  def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_pack_into_tar=True,
15
15
  strictness_level=Strictness_Level.COMPATIBLE, saved_node_name_format='{count}_{hash_name}',
16
- b_keep_identical_relations=False, **kwargs):
16
+ b_keep_identical_relations=False, b_allow_overwrite=False, **kwargs):
17
17
  """
18
18
  将输入的嵌套字典列表 var 的结构和节点值保存到文件中
19
19
  遍历 var,匹配并使用 settings 中设置的保存方式来对各部分结构/节点进行序列化
@@ -106,6 +106,8 @@ def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_p
106
106
  替换为单个节点和其多个引用的形式。
107
107
  对于 ndl 中存在大量具有相同 id 的重复节点的情况,使用该操作可以额外达到压缩的效果。
108
108
  默认为 False
109
+ b_allow_overwrite: <boolean> 是否允许强制覆盖已有文件
110
+ 默认为 False,此时若目标文件已存在则报错
109
111
  """
110
112
  from kevin_toolbox.nested_dict_list.serializer.variable import SERIALIZER_BACKEND
111
113
 
@@ -113,7 +115,12 @@ def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_p
113
115
  traversal_mode = Traversal_Mode(traversal_mode)
114
116
  strictness_level = Strictness_Level(strictness_level)
115
117
  #
116
- assert not os.path.exists(output_dir + ".tar" if b_pack_into_tar else output_dir), f'target already exists'
118
+ tgt_path = output_dir + ".tar" if b_pack_into_tar else output_dir
119
+ if os.path.exists(tgt_path):
120
+ if b_allow_overwrite:
121
+ for_os.remove(path=tgt_path, ignore_errors=True)
122
+ else:
123
+ raise FileExistsError(f"target {tgt_path} already exists")
117
124
  os.makedirs(os.path.dirname(output_dir), exist_ok=True)
118
125
  temp_dir = tempfile.TemporaryDirectory(dir=os.path.dirname(output_dir))
119
126
  temp_output_dir = os.path.join(temp_dir.name, os.path.basename(output_dir))
@@ -204,6 +211,7 @@ def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_p
204
211
  file_path=os.path.join(temp_output_dir, "record.json"), b_use_suggested_converter=True)
205
212
 
206
213
  # 打包成 .tar 文件
214
+ for_os.remove(path=tgt_path, ignore_errors=True)
207
215
  if b_pack_into_tar:
208
216
  for_os.pack(source=temp_output_dir)
209
217
  os.rename(temp_output_dir + ".tar", output_dir + ".tar")
@@ -1,4 +1,4 @@
1
1
  from .arrow3d import Arrow3D
2
2
  from .add_trajectory_2d import add_trajectory_2d
3
3
  from .add_trajectory_3d import add_trajectory_3d
4
- from .generate_color_list import generate_color_list
4
+ from .clear_border_of_axes import clear_border_of_axes
@@ -0,0 +1,34 @@
1
def clear_border_of_axes(ax):
    """
    Hide the ticks and the four spines (axis lines) of a matplotlib Axes.

    Both tick lists are emptied and every spine is made transparent (color 'none').
    The Axes is modified in place and also returned, to allow chaining.
    """
    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks([])
    for side in ('left', 'right', 'bottom', 'top'):
        ax.spines[side].set_color('none')
    return ax
12
+
13
+
14
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Draw a single straight line, then strip the ticks and spines from its axes
    fig, ax = plt.subplots()
    ax.plot([1, 4], [1, 10])

    # widen the limits so that the whole line is visible
    ax.set_xlim([0, 5])
    ax.set_ylim([0, 15])

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    clear_border_of_axes(ax)

    plt.show()
@@ -0,0 +1,4 @@
1
+ from .color_format import Color_Format
2
+ from .get_format import get_format
3
+ from .convert_format import convert_format
4
+ from .generate_color_list import generate_color_list
@@ -0,0 +1,7 @@
1
+ from enum import Enum
2
+
3
+
4
class Color_Format(Enum):
    """
    Representations of a single color understood by the color helpers.
    """
    # hex string with optional alpha byte, e.g. '#FF573380'
    HEX_STR = "hex_str"
    # (r, g, b[, alpha]) sequence, e.g. (255, 87, 51, 0.5)
    RGBA_ARRAY = "rgba_array"
    # plain color name, e.g. 'red'
    NATURAL_NAME = "natural_name"
@@ -0,0 +1,108 @@
1
+ from kevin_toolbox.patches.for_matplotlib.color import Color_Format, get_format
2
+
3
+
4
def hex_to_rgba(hex_color):
    """
    Convert '#RRGGBB' / '#RRGGBBAA' into (r, g, b) / (r, g, b, alpha).

    r, g and b are ints in 0-255. alpha is present only for 8-digit input and is a float in 0-1.
    The leading '#' is optional.

    Raises:
        AssertionError: if the string does not contain 6 or 8 hex digits.
        ValueError:     if it contains non-hex characters.
    """
    hex_color = hex_color.lstrip('#')
    assert len(hex_color) in (6, 8), \
        f'hex_color should be 6 or 8 characters long (not including #). but got {len(hex_color)}'
    # One channel per pair of hex digits; the assert above guarantees 3 or 4 channels.
    res = [int(hex_color[i:i + 2], 16) for i in range(0, len(hex_color), 2)]
    if len(res) == 4:
        res[3] /= 255
    return tuple(res)
14
+
15
+
16
def rgba_to_hex(rgba):
    """
    Convert (r, g, b) / (r, g, b, alpha) into '#RRGGBB' / '#RRGGBBAA' (upper-case hex digits).

    r, g and b are expected in 0-255 and alpha in 0-1. Every channel is rounded to the nearest integer
    and clamped to 0-255. This means float channels are accepted, and hex -> rgba -> hex round-trips exactly
    (e.g. alpha 0.5 -> '80').

    Raises:
        AssertionError: if rgba does not have 3 or 4 elements.
    """
    assert len(rgba) in (3, 4), \
        f'rgba should be 3 or 4 elements long. but got {len(rgba)}'
    channels = list(rgba)
    if len(channels) == 4:
        # alpha is stored as a fraction; scale it to a byte before rounding
        channels[3] = channels[3] * 255
    return "#" + "".join(f'{max(0, min(255, int(round(c)))):02X}' for c in channels)
26
+
27
+
28
# Color names that can be converted to / from HEX_STR (upper-case '#RRGGBB')
NAME_TO_HEX = dict(
    blue='#0000FF',
    red='#FF0000',
    green='#008000',
    orange='#FFA500',
    purple='#800080',
    yellow='#FFFF00',
    brown='#A52A2A',
    pink='#FFC0CB',
    gray='#808080',
    olive='#808000',
    cyan='#00FFFF',
)
# Reverse lookup: '#RRGGBB' -> color name
HEX_TO_NAME = dict(zip(NAME_TO_HEX.values(), NAME_TO_HEX.keys()))
42
+
43
+
44
def natural_name_to_hex(name):
    """
    Look up the '#RRGGBB' string of a color name in NAME_TO_HEX (case-insensitive).

    Raises:
        AssertionError: if the lower-cased name is not in NAME_TO_HEX.
    """
    key = name.lower()
    assert key in NAME_TO_HEX, \
        f'{key} is not a valid color name.'
    return NAME_TO_HEX[key]
50
+
51
+
52
def hex_to_natural_name(hex_color):
    """
    Look up the color name of a '#RRGGBB' / '#RRGGBBAA' string in HEX_TO_NAME.

    The lookup upper-cases the string and keeps only its first 7 characters, so any alpha digits are ignored.

    Raises:
        AssertionError: if no name is registered for that color.
    """
    key = hex_color.upper()[:7]
    assert key in HEX_TO_NAME, \
        f'{key} does not has corresponding color name.'
    return HEX_TO_NAME[key]
58
+
59
+
60
# Conversion table: (from_format, to_format) -> function converting one color.
# Conversions between NATURAL_NAME and RGBA_ARRAY go through HEX_STR.
CONVERT_PROCESS_S = {
    # starting from HEX_STR
    (Color_Format.HEX_STR, Color_Format.NATURAL_NAME): hex_to_natural_name,
    (Color_Format.HEX_STR, Color_Format.RGBA_ARRAY): hex_to_rgba,
    # starting from NATURAL_NAME
    (Color_Format.NATURAL_NAME, Color_Format.HEX_STR): natural_name_to_hex,
    (Color_Format.NATURAL_NAME, Color_Format.RGBA_ARRAY): lambda name: hex_to_rgba(natural_name_to_hex(name)),
    # starting from RGBA_ARRAY
    (Color_Format.RGBA_ARRAY, Color_Format.HEX_STR): rgba_to_hex,
    (Color_Format.RGBA_ARRAY, Color_Format.NATURAL_NAME): lambda rgba: hex_to_natural_name(rgba_to_hex(rgba)),
}
68
+
69
+
70
def convert_format(var, output_format, input_format=None):
    """
    Convert a color between the formats listed in Color_Format.

    Parameters:
        var:            the color to convert, e.g. '#FF573380', (255, 87, 51, 0.5) or 'red'.
        output_format:  <Color_Format/str or list/tuple of them> target format(s).
                            When several are given, any one of them is acceptable:
                            - if input_format is among them, var is returned unchanged;
                            - otherwise the conversions are tried in the given order, and the result of the
                              first one that succeeds is returned.
        input_format:   <Color_Format/str> format of var, e.g. HEX_STR, NATURAL_NAME or RGBA_ARRAY.
                            Default None, in which case it is inferred with get_format().

    Returns:
        the converted color.

    Raises:
        Exception: if none of the requested conversions succeeds (chained to the last underlying error).
    """
    if input_format is None:
        input_format = get_format(var=var)
    input_format = Color_Format(input_format)
    if not isinstance(output_format, (list, tuple,)):
        output_format = [output_format]
    output_format = [Color_Format(i) for i in output_format]

    if input_format in output_format:
        return var

    errors = []
    for output_format_i in output_format:
        try:
            return CONVERT_PROCESS_S[(input_format, output_format_i)](var)
        except Exception as e:
            # remember the failure and fall back to the next candidate format
            errors.append(e)
    reason = "; ".join(repr(e) for e in errors) if errors else "no output format given"
    raise Exception(f'fail to convert {var} from {input_format} to {output_format}, because: {reason}') \
        from (errors[-1] if errors else None)
101
+
102
+
103
if __name__ == '__main__':
    # Each demo is evaluated lazily and printed right away, in order
    demos = [
        lambda: hex_to_rgba('#FF57337F'),
        lambda: rgba_to_hex((255, 87, 51, 0.5)),
        lambda: natural_name_to_hex('pink'),
        lambda: convert_format(var='#FF57337F', input_format='hex_str', output_format='rgba_array'),
        lambda: convert_format(var="#0000FF", output_format="rgba_array"),
    ]
    for demo in demos:
        print(demo())
@@ -0,0 +1,50 @@
1
+ from kevin_toolbox.patches.for_matplotlib.color import Color_Format, convert_format
2
+ from kevin_toolbox.patches.for_numpy import random
3
+
4
# Colors handed out first by generate_color_list(), in this order, stored as HEX_STR
PREDEFINED = [
    convert_format(var=color_name, output_format=Color_Format.HEX_STR)
    for color_name in ('blue', 'red', 'green', 'orange', 'purple', 'yellow', 'brown', 'pink', 'gray', 'olive', 'cyan')
]

# Alphabet from which the digits of random hex colors are drawn
population = tuple('0123456789ABCDEF')
8
+
9
+
10
def generate_color_list(nums, seed=None, rng=None, exclude_ls=None, output_format=Color_Format.HEX_STR):
    """
    Generate a list of distinct colors.

    The colors in PREDEFINED are used first (in order, skipping excluded ones). If more are needed,
    random '#RRGGBB' colors are drawn until `nums` distinct colors have been collected.

    Parameters:
        nums:           <int> number of colors to generate.
        seed, rng:      random seed or random generator (use one of the two); passed to random.get_rng().
        exclude_ls:     <list/tuple of colors> colors that must not appear in the output, in any format
                            accepted by convert_format(). The comparison ignores the case of hex digits and any
                            alpha channel.
        output_format:  <Color_Format/str> output format, either HEX_STR or RGBA_ARRAY.

    Returns:
        list of colors without alpha value.
    """
    # Normalize first so that the string form ("hex_str") is accepted as documented.
    output_format = Color_Format(output_format)
    assert output_format in [Color_Format.HEX_STR, Color_Format.RGBA_ARRAY]
    if exclude_ls is None:
        exclude_ls = []
    assert isinstance(exclude_ls, (list, tuple))
    # Compare as upper-case '#RRGGBB' so that e.g. '#0000ff' or '#0000FF80' also exclude '#0000FF'.
    excluded = {convert_format(var=i, output_format=Color_Format.HEX_STR).upper()[:7] for i in exclude_ls}
    rng = random.get_rng(seed=seed, rng=rng)

    colors = [c for c in PREDEFINED if c not in excluded][:nums]  # predefined colors come first
    seen = set(colors)

    # Draw random colors for the remaining slots.
    while len(colors) < nums:
        c = "#" + ''.join(rng.choice(population, size=6, replace=True))
        if c not in seen and c not in excluded:
            colors.append(c)
            seen.add(c)

    return [convert_format(c, output_format=output_format) for c in colors]
43
+
44
+
45
if __name__ == '__main__':
    # one color, excluding blue by its name
    print(generate_color_list(1, exclude_ls=['blue']))

    # one color from a seeded generator, excluding blue by its hex value
    print(generate_color_list(nums=1, seed=114, exclude_ls=['#0000FF']))
@@ -0,0 +1,12 @@
1
+ from kevin_toolbox.patches.for_matplotlib.color import Color_Format
2
+
3
+
4
def get_format(var):
    """
    Infer the Color_Format of a color value.

    - a str starting with '#' -> HEX_STR
    - any other str           -> NATURAL_NAME
    - anything else           -> RGBA_ARRAY
    """
    if not isinstance(var, str):
        return Color_Format.RGBA_ARRAY
    return Color_Format.HEX_STR if var.startswith("#") else Color_Format.NATURAL_NAME
@@ -0,0 +1,6 @@
1
+ from .plot_lines import plot_lines
2
+ from .plot_scatters import plot_scatters
3
+ from .plot_distribution import plot_distribution
4
+ from .plot_bars import plot_bars
5
+ from .plot_scatters_matrix import plot_scatters_matrix
6
+ from .plot_confusion_matrix import plot_confusion_matrix
@@ -0,0 +1,54 @@
1
+ import os
2
+ import copy
3
+ from kevin_toolbox.computer_science.algorithm import for_seq
4
+ import matplotlib.pyplot as plt
5
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
6
+
7
+ # TODO 在 linux 系统下遇到中文时,尝试自动下载中文字体,并尝试自动设置字体
8
+ # font_path = os.path.join(root_dir, "utils/SimHei.ttf")
9
+ # font_name = FontProperties(fname=font_path)
10
+
11
+
12
def plot_bars(data_s, title, x_name, y_label=None, output_dir=None, **kwargs):
    """
    Draw a grouped bar chart with matplotlib.

    Parameters:
        data_s:     <dict> {<name>: <list of values>, ...}. The entry under x_name provides the x tick labels;
                        every other entry is drawn as one series of bars, labelled by its key in the legend.
        title:      <str> chart title; also used (after replace_illegal_chars) as the file name.
        x_name:     <str> key of data_s holding the x tick labels; also used as the x-axis label.
        y_label:    <str> y-axis label. Default None, shown as "value".
        output_dir: <path> directory in which to save the figure as '<title>.png'.
                        Default None, in which case plt.show() is called instead.
        **kwargs:   dpi: <int> resolution of the saved figure, default 200.

    Returns:
        the path of the saved png file, or None when output_dir is None.
    """
    data_s = copy.deepcopy(data_s)  # pop() below must not modify the caller's dict
    paras = {
        "dpi": 200
    }
    paras.update(kwargs)

    plt.clf()
    #
    x_all_ls = data_s.pop(x_name)
    positions = list(range(len(x_all_ls)))
    # The bars of all series share a group 0.4 wide around each tick and sit side by side.
    # With two series this gives width 0.2 at offsets -0.1 / +0.1; a single series is centred on its tick.
    n_series = len(data_s)
    bar_width = 0.4 / max(n_series, 2)
    for i, (k, y_ls) in enumerate(data_s.items()):
        offset = (i - (n_series - 1) / 2) * bar_width
        plt.bar([j + offset for j in positions], y_ls, width=bar_width, align='center', label=k)

    plt.xlabel(f'{x_name}')
    plt.ylabel(f'{y_label if y_label else "value"}')
    temp = for_seq.flatten_list([list(i) for i in data_s.values()])
    if temp:
        y_min, y_max = min(temp), max(temp)
        # Use a non-zero span when all values are equal, so that the limits never collapse to a single value.
        span = (y_max - y_min) or (abs(y_max) or 1)
        plt.ylim(max(min(y_min, 0), y_min - span * 0.2), y_max + span * 0.1)
    plt.xticks(positions, labels=x_all_ls)  # , fontproperties=font_name
    plt.title(f'{title}')
    # show the legend
    plt.legend()

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
47
+
48
+
49
if __name__ == '__main__':
    # 'a' supplies the x tick labels; 'b' and 'c' become two bar series
    demo_data = {
        'a': [1.5, 2, 3, 4, 5],
        'b': [5, 4, 3, 2, 1],
        'c': [1, 2, 3, 4, 5],
    }
    demo_dir = os.path.join(os.path.dirname(__file__), "temp")
    plot_bars(data_s=demo_data, title='test', x_name='a', output_dir=demo_dir)
@@ -0,0 +1,60 @@
1
+ import os
2
+ from sklearn.metrics import confusion_matrix
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
6
+
7
+
8
def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None, output_dir=None, **kwargs):
    """
    Compute a confusion matrix with sklearn and draw it as a seaborn heatmap.

    Parameters:
        data_s:             <dict> holds the ground-truth sequence under gt_name and the predictions under pd_name.
        title:              <str> figure title; also used (after replace_illegal_chars) as the file name.
        gt_name, pd_name:   <str> keys of the ground-truth / prediction sequences; used as the y / x axis labels.
        label_to_value_s:   <dict> {<label shown on the axes>: <value in the data>, ...}; its order is the
                                axis order. Every value must occur in the data.
                                Default None: every distinct value is used, labelled by its str(), in sorted
                                order when the values are mutually comparable.
        output_dir:         <path> directory in which to save the figure as '<title>.png'.
                                Default None, in which case plt.show() is called instead.
        **kwargs:
            dpi:            <int> resolution of the saved figure, default 200.
            normalize:      None / "true" / "pred" / "all", forwarded to sklearn's confusion_matrix.
                                When set, the cells are annotated as percentages.
            b_return_cfm:   <boolean> also return the confusion matrix, default False.

    Returns:
        output_path (None when output_dir is None), or (output_path, cfm) when b_return_cfm is True.
    """
    paras = {
        "dpi": 200,
        "normalize": None,  # "true", "pred", "all",
        "b_return_cfm": False,  # whether to also return the confusion matrix
    }
    paras.update(kwargs)

    value_set = set(data_s[gt_name]).union(set(data_s[pd_name]))
    if label_to_value_s is None:
        # Sort for a deterministic axis order; set iteration order is arbitrary.
        try:
            value_ls = sorted(value_set)
        except TypeError:
            # values of mixed, non-comparable types: keep the set's order
            value_ls = list(value_set)
        label_to_value_s = {f'{i}': i for i in value_ls}
    else:
        assert all(i in value_set for i in label_to_value_s.values())
    label_ls = list(label_to_value_s.keys())
    # compute the confusion matrix
    cfm = confusion_matrix(y_true=data_s[gt_name], y_pred=data_s[pd_name], labels=list(label_to_value_s.values()),
                           normalize=paras["normalize"])
    # Draw the heatmap on a dedicated figure; no other (possibly unrelated) figure is cleared.
    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(cfm, annot=True, fmt='.2%' if paras["normalize"] is not None else 'd',
                xticklabels=label_ls, yticklabels=label_ls,
                cmap='viridis')

    plt.xlabel(f'{pd_name}')
    plt.ylabel(f'{gt_name}')
    plt.title(f'{title}')

    if output_dir is None:
        plt.show()
        output_path = None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        # The figure is only needed for saving; close it so repeated calls do not accumulate open figures.
        plt.close(fig)

    if paras["b_return_cfm"]:
        return output_path, cfm
    else:
        return output_path
47
+
48
+
49
if __name__ == '__main__':
    import numpy as np

    # ground-truth and predicted labels for a small example
    gt_demo = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 5])
    pred_demo = np.array([0, 2, 1, 0, 2, 1, 0, 1, 1, 5])

    plot_confusion_matrix(
        data_s={'a': gt_demo, 'b': pred_demo},
        title='test',
        gt_name='a',
        pd_name='b',
        label_to_value_s={"A": 5, "B": 0, "C": 1, "D": 2},
        normalize="true",
    )