kevin-toolbox-dev 1.3.4__py3-none-any.whl → 1.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (40)
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/computer_science/algorithm/pareto_front/__init__.py +1 -0
  3. kevin_toolbox/computer_science/algorithm/pareto_front/optimum_picker.py +218 -0
  4. kevin_toolbox/computer_science/algorithm/statistician/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/statistician/accumulator_base.py +2 -2
  6. kevin_toolbox/computer_science/algorithm/statistician/accumulator_for_ndl.py +69 -0
  7. kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +16 -4
  8. kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +5 -12
  9. kevin_toolbox/data_flow/file/json_/read_json.py +11 -5
  10. kevin_toolbox/data_flow/file/kevin_notation/kevin_notation_writer.py +4 -1
  11. kevin_toolbox/data_flow/file/kevin_notation/test/test_kevin_notation_debug.py +27 -0
  12. kevin_toolbox/data_flow/file/kevin_notation/write.py +3 -0
  13. kevin_toolbox/data_flow/file/markdown/__init__.py +3 -0
  14. kevin_toolbox/data_flow/file/markdown/find_tables.py +65 -0
  15. kevin_toolbox/data_flow/file/markdown/generate_table.py +19 -5
  16. kevin_toolbox/data_flow/file/markdown/parse_table.py +135 -0
  17. kevin_toolbox/data_flow/file/markdown/save_images_in_ndl.py +81 -0
  18. kevin_toolbox/data_flow/file/markdown/variable.py +17 -0
  19. kevin_toolbox/nested_dict_list/serializer/backends/_ndl.py +4 -1
  20. kevin_toolbox/nested_dict_list/serializer/read.py +18 -14
  21. kevin_toolbox/nested_dict_list/serializer/write.py +23 -7
  22. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +6 -0
  23. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +54 -0
  24. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +60 -0
  25. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +65 -0
  26. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +61 -0
  27. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +53 -0
  28. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +54 -0
  29. kevin_toolbox/patches/for_os/__init__.py +3 -0
  30. kevin_toolbox/patches/for_os/copy.py +33 -0
  31. kevin_toolbox/patches/for_os/find_files_in_dir.py +30 -0
  32. kevin_toolbox/patches/for_os/path/__init__.py +2 -0
  33. kevin_toolbox/patches/for_os/path/find_illegal_chars.py +47 -0
  34. kevin_toolbox/patches/for_os/path/replace_illegal_chars.py +49 -0
  35. kevin_toolbox/patches/for_test/check_consistency.py +104 -33
  36. kevin_toolbox_dev-1.3.6.dist-info/METADATA +95 -0
  37. {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/RECORD +39 -20
  38. kevin_toolbox_dev-1.3.4.dist-info/METADATA +0 -67
  39. {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/WHEEL +0 -0
  40. {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,135 @@
1
+ import re
2
+ from typing import Union
3
+ from kevin_toolbox.data_flow.file.markdown.variable import Table_Format
4
+
5
+
6
def parse_table(raw_table, output_format: Union[Table_Format, str] = Table_Format.COMPLETE_DICT, orientation="vertical",
                chunk_size=None, chunk_nums=None, b_remove_empty_lines=False, f_gen_order_of_values=None):
    """
    Parse a table given as a 2-D array (e.g. an element of the list returned by find_tables()) into the given format.

    Parameters:
        raw_table:              <list of list> the table as a 2-D array (a list of rows of cells)
        output_format:          <Table_Format or str> target format
                                    See Table_Format for details. Strings ("simple_dict" / "complete_dict") are
                                    converted with Table_Format(...).
        orientation:            <str> direction in which the table is read
                                    Supported values:
                                        "vertical" / "v":    the first row holds the titles
                                        "horizontal" / "h":  the first column holds the titles
                                    Any value other than "vertical"/"v" is treated as horizontal.
        chunk_nums:             <int> number of equal chunks the table was split into and laid side by side
        chunk_size:             <int> maximum length of each chunk the table was split into and laid side by side
                                    These two parameters interpret tables produced by generate_table() with the same
                                    arguments. chunk_size is only used to check that the number of value rows matches.
                                    When it is given, the table is assumed to possibly consist of several tables
                                    laid side by side. chunk_nums is then inferred from the repetition pattern of
                                    the titles (overriding any chunk_nums passed in), and the table is split
                                    accordingly.
        b_remove_empty_lines:   <boolean> remove rows and columns whose cells are all ''
                                    A column with an empty title is also removed.
        f_gen_order_of_values:  <callable> produces the sort key of a row
                                    It is called with {<title>: <value>, ...} for each row. Titles must be unique.
                                    See the parameter of the same name in generate_table().

    Returns:
        Table_Format.COMPLETE_DICT: {<index>: {"title": <title>, "values": <list of value>}, ...}
        Table_Format.SIMPLE_DICT:   {<title>: <list of value>, ...}
    """
    assert isinstance(raw_table, (list, tuple,))
    # Normalize str -> enum so that e.g. "simple_dict" is honoured by the identity check at the end
    output_format = Table_Format(output_format)

    # Convert to vertical orientation (titles in the first row)
    if orientation not in ["vertical", "v"]:
        raw_table = list(zip(*raw_table))
    r_nums, c_nums = len(raw_table), len(raw_table[0])
    if chunk_size is not None:
        assert chunk_size == r_nums - 1, \
            (f'The number of values {r_nums - 1} actually contained in the table '
             f'does not match the specified chunk_size {chunk_size}')
        chunk_nums = c_nums // _find_shortest_repeating_pattern_size(arr=raw_table[0])
    chunk_nums = 1 if chunk_nums is None else chunk_nums
    assert c_nums % chunk_nums == 0, \
        f'The number of headers actually contained in the table does not match the specified chunk_nums, ' \
        f'Expected n*{chunk_nums}, but got {c_nums}'
    # Titles: the first chunk of the header row
    keys = raw_table[0][0:c_nums // chunk_nums]
    # Values: stack the side-by-side chunks on top of each other
    if chunk_nums == 1:
        values = raw_table[1:]
    else:
        values = []
        for i in range(chunk_nums):
            for j in range(1, r_nums):
                values.append(raw_table[j][i * len(keys):(i + 1) * len(keys)])
    # Remove empty rows
    if b_remove_empty_lines:
        values = [line for line in values if any(i != '' for i in line)]
    table_s = {i: {"title": k, "values": list(v)} for i, (k, v) in enumerate(zip(keys, list(zip(*values))))}
    # Remove empty columns
    if b_remove_empty_lines:
        table_s = {k: v_s for k, v_s in table_s.items() if v_s["title"] != '' and any(i != '' for i in v_s["values"])}
    # Sort the rows
    if callable(f_gen_order_of_values):
        # Each row is passed to the key function as {title: value}, so titles must be unique
        temp = [v["title"] for v in table_s.values()]
        assert len(set(temp)) == len(temp), \
            f'table has duplicate titles, thus cannot be sorted using f_gen_order_of_values'
        idx_ls = list(range(len(values)))
        idx_ls.sort(key=lambda x: f_gen_order_of_values({v["title"]: v["values"][x] for v in table_s.values()}))
        for v in table_s.values():
            v["values"] = [v["values"][i] for i in idx_ls]

    if output_format is Table_Format.SIMPLE_DICT:
        # Must be a list (not a set) so that duplicated titles are actually detected
        temp = [v_s["title"] for v_s in table_s.values()]
        if len(temp) != len(set(temp)):
            raise AssertionError(
                f'There are columns with the same title in the table, '
                f'please check the orientation of the table or use output_format="complete_dict"')
        table_s = {v_s["title"]: v_s["values"] for v_s in table_s.values()}

    return table_s
83
+
84
+
85
+ def _find_shortest_repeating_pattern_size(arr):
86
+ n = len(arr)
87
+
88
+ # 部分匹配表
89
+ pi = [0] * n
90
+ k = 0
91
+ for i in range(1, n):
92
+ if k > 0 and arr[k] != arr[i]:
93
+ k = 0
94
+ if arr[k] == arr[i]:
95
+ k += 1
96
+ pi[i] = k
97
+
98
+ # 最短重复模式的长度
99
+ pattern_length = n - pi[n - 1]
100
+ # 是否是完整的重复模式
101
+ if n % pattern_length != 0:
102
+ pattern_length = n
103
+ return pattern_length
104
+
105
+
106
if __name__ == '__main__':
    # Demo: parse a markdown table that was laid out as 3 side-by-side chunks, dropping empty rows/columns.
    # (To try other input, read markdown text from a file or use a table whose header row repeats vertically.)
    from kevin_toolbox.data_flow.file.markdown import find_tables

    markdown_text = """
| | a | b | | a | b | | a | b |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| | 0 | 2 | | 4 | 6 | | 7 | 9 |
| | 1 | 3 | | 5 | 7 | | 8 | : |
| | 2 | 4 | | 6 | 8 | | 9 | ; |
| | 3 | 5 | | | | | | |
"""
    first_table = find_tables(text=markdown_text)[0]
    parsed = parse_table(raw_table=first_table, output_format="complete_dict", chunk_nums=3,
                         b_remove_empty_lines=True)
    print(parsed)
@@ -0,0 +1,81 @@
1
+ import os
2
+ import warnings
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ from collections import defaultdict
8
+ from kevin_toolbox.data_flow.file import markdown
9
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars, find_illegal_chars
10
+ import kevin_toolbox.nested_dict_list as ndl
11
+
12
+
13
def save_images_in_ndl(var, plot_dir, doc_dir=None, setting_s=None):
    """
    Save the image objects found at the leaves of an ndl structure into plot_dir, and replace each of them with a
    markdown link to the saved file.

    Parameters:
        var:        <dict> the ndl structure to process
        plot_dir:   <path> directory in which the images are saved
        doc_dir:    <path> directory of the markdown document that will contain the links
                        When given, link targets are paths relative to doc_dir.
                        Defaults to None, in which case the link target is the image path built from plot_dir
                        (absolute only if plot_dir is absolute).
        setting_s:  <dict> settings
                        Selects the nodes under which to look for images, and the parameters used to save them.
                        Form: {<node name>: {"b_is_rgb": <boolean>, ...}, ...}
                        Supported items:
                            - b_is_rgb:             whether the image is in RGB order (True) or BGR order (False).
                                                    RGB images are converted to BGR before being written by cv2.
                            - saved_image_format:   suffix used for the saved file, e.g. ".jpg"
                        Items left out fall back to {"b_is_rgb": False, "saved_image_format": ".jpg"}.
                        Node names are processed from shallow to deep (as measured by
                        ndl.name_handler.parse_name). A leaf covered by several node names therefore uses the
                        setting of the last one processed.
                        Defaults to None, equivalent to {"": {"b_is_rgb": False, "saved_image_format": ".jpg"}}

    Returns:
        var, after each processed leaf has been rewritten through ndl.set_value():
            - PIL images, torch tensors and numpy arrays -> markdown image link
            - None -> "/"
            - any other value is written back unchanged
    """
    if len(find_illegal_chars(file_name=plot_dir, b_is_path=True)) > 0:
        warnings.warn(f'plot_dir {plot_dir} contains illegal symbols, '
                      f'which may cause compatibility issues on certain systems.', UserWarning)
    default_setting = {"b_is_rgb": False, "saved_image_format": ".jpg"}
    setting_s = setting_s or {"": default_setting}

    # Assign each leaf to a root: roots are visited from shallow to deep,
    # and a later (deeper) root re-claims the leaves it shares with an earlier one.
    root_ls = list(setting_s.keys())
    root_ls.sort(key=lambda x: len(ndl.name_handler.parse_name(name=x)[-1]))
    root_to_leaf_s = defaultdict(set)
    leaf_to_root_s = dict()
    leaf_to_value_s = dict()
    for root in root_ls:
        for leaf, v in ndl.get_nodes(var=ndl.get_value(var=var, name=root, b_pop=False), level=-1, b_strict=True):
            leaf = root + leaf
            if leaf in leaf_to_root_s:
                root_to_leaf_s[leaf_to_root_s[leaf]].remove(leaf)
            root_to_leaf_s[root].add(leaf)
            leaf_to_root_s[leaf] = root
            leaf_to_value_s[leaf] = v

    for root, leaf_ls in root_to_leaf_s.items():
        # Fill in items the user left out instead of failing with KeyError further down
        setting_ = {**default_setting, **setting_s[root]}
        for leaf in leaf_ls:
            v = leaf_to_value_s[leaf]
            # Normalize the supported image containers to numpy arrays
            if isinstance(v, Image.Image):
                v = np.asarray(v)
            elif torch.is_tensor(v):
                v = v.detach().cpu().numpy()
            #
            if isinstance(v, np.ndarray):
                # NOTE(review): the array goes to cv2 as-is, so it presumably needs an HxW / HxWxC layout and a
                # dtype cv2.imwrite accepts — confirm for tensor inputs (often CxHxW).
                image_path = os.path.join(
                    plot_dir, replace_illegal_chars(
                        file_name=f'{leaf}_{setting_["saved_image_format"]}', b_is_path=False)
                )
                os.makedirs(os.path.dirname(image_path), exist_ok=True)
                if setting_["b_is_rgb"]:
                    v = cv2.cvtColor(v, cv2.COLOR_RGB2BGR)
                # cv2.imwrite reports failure through its return value rather than raising
                if not cv2.imwrite(image_path, v):
                    warnings.warn(f'failed to save image of node {leaf} to {image_path}', UserWarning)
                v_new = markdown.generate_link(
                    name=os.path.basename(image_path),
                    target=os.path.relpath(image_path, doc_dir) if doc_dir is not None else image_path, type_="image")
            elif v is None:
                v_new = "/"
            else:
                v_new = v
            ndl.set_value(var=var, name=leaf, b_force=False, value=v_new)
    return var
79
+
80
+
81
+
@@ -0,0 +1,17 @@
1
+ from enum import Enum
2
+
3
+
4
class Table_Format(Enum):
    """
    Layouts in which a table can be represented.

    SIMPLE_DICT ("simple_dict"):
        content_s = {<title>: <list of value>, ...}
        Each key is a title and its value is the list of entries under that title.
        Dict ordering is not relied upon to fix the order of the titles; use COMPLETE_DICT when an explicit order
        is required.

    COMPLETE_DICT ("complete_dict"):
        content_s = {<index>: {"title": <title>, "values": <list of value>}, ...}
        The entry at <index> supplies the title and the values of the <index>-th column (or row).
        Some <index> may be omitted; the corresponding rows/columns are then left empty.
    """
    # keyed by title
    SIMPLE_DICT = "simple_dict"
    # keyed by position; each entry holds "title" and "values"
    COMPLETE_DICT = "complete_dict"
@@ -27,7 +27,10 @@ class NDL(Backend_Base):
27
27
  """
28
28
  是否可以写
29
29
  """
30
- return True
30
+ if "name" in kwargs:
31
+ return not os.path.exists(os.path.join(self.paras["folder"], f'{kwargs["name"]}'))
32
+ else:
33
+ return True
31
34
 
32
35
  def readable(self, name, **kwargs):
33
36
  """
@@ -3,6 +3,7 @@ import time
3
3
  from kevin_toolbox.patches import for_os
4
4
  from kevin_toolbox.data_flow.file import json_
5
5
  import kevin_toolbox.nested_dict_list as ndl
6
+ import tempfile
6
7
 
7
8
 
8
9
  def read(input_path, **kwargs):
@@ -16,19 +17,26 @@ def read(input_path, **kwargs):
16
17
 
17
18
  assert os.path.exists(input_path)
18
19
 
19
- # 解压
20
- temp_dir = None
21
- if os.path.isfile(input_path) and input_path.endswith(".tar"):
22
- while True:
23
- temp_dir = os.path.join(os.path.dirname(input_path), f'temp{time.time()}')
24
- if not os.path.isdir(temp_dir):
25
- os.makedirs(temp_dir)
26
- break
27
- for_os.unpack(source=input_path, target=temp_dir)
28
- input_path = os.path.join(temp_dir, os.listdir(temp_dir)[0])
20
+ with tempfile.TemporaryDirectory(dir=os.path.dirname(input_path)) as temp_dir:
21
+ if os.path.isfile(input_path) and input_path.endswith(".tar"): # 解压
22
+ for_os.unpack(source=input_path, target=temp_dir)
23
+ input_path = os.path.join(temp_dir, os.listdir(temp_dir)[0])
24
+ var = _read_unpacked_ndl(input_path, **kwargs)
25
+
26
+ return var
27
+
28
+
29
+ def _read_unpacked_ndl(input_path, **kwargs):
30
+ """
31
+ 读取 input_path 中保存的嵌套字典列表
32
+ """
33
+ from kevin_toolbox.nested_dict_list.serializer.variable import SERIALIZER_BACKEND
34
+
35
+ assert os.path.exists(input_path)
29
36
 
30
37
  # 读取 var
31
38
  var = json_.read(file_path=os.path.join(input_path, "var.json"), b_use_suggested_converter=True)
39
+
32
40
  # 读取 record
33
41
  record_s = dict()
34
42
  if os.path.isfile(os.path.join(input_path, "record.json")):
@@ -63,10 +71,6 @@ def read(input_path, **kwargs):
63
71
  from kevin_toolbox.nested_dict_list import value_parser
64
72
  var = value_parser.replace_identical_with_reference(var=var, flag="same", b_reverse=True)
65
73
 
66
- #
67
- if temp_dir is not None:
68
- for_os.remove(path=temp_dir, ignore_errors=True)
69
-
70
74
  return var
71
75
 
72
76
 
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import time
3
3
  import warnings
4
+ import tempfile
4
5
  import kevin_toolbox
5
6
  from kevin_toolbox.data_flow.file import json_
6
7
  from kevin_toolbox.patches import for_os
@@ -12,7 +13,7 @@ from .saved_node_name_builder import Saved_Node_Name_Builder
12
13
 
13
14
  def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_pack_into_tar=True,
14
15
  strictness_level=Strictness_Level.COMPATIBLE, saved_node_name_format='{count}_{hash_name}',
15
- b_keep_identical_relations=False, **kwargs):
16
+ b_keep_identical_relations=False, b_allow_overwrite=False, **kwargs):
16
17
  """
17
18
  将输入的嵌套字典列表 var 的结构和节点值保存到文件中
18
19
  遍历 var,匹配并使用 settings 中设置的保存方式来对各部分结构/节点进行序列化
@@ -105,13 +106,26 @@ def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_p
105
106
  替换为单个节点和其多个引用的形式。
106
107
  对于 ndl 中存在大量具有相同 id 的重复节点的情况,使用该操作可以额外达到压缩的效果。
107
108
  默认为 False
109
+ b_allow_overwrite: <boolean> 是否允许强制覆盖已有文件
110
+ 默认为 False,此时若目标文件已存在则报错
108
111
  """
109
112
  from kevin_toolbox.nested_dict_list.serializer.variable import SERIALIZER_BACKEND
110
113
 
111
114
  # 检查参数
112
115
  traversal_mode = Traversal_Mode(traversal_mode)
113
116
  strictness_level = Strictness_Level(strictness_level)
114
- os.makedirs(output_dir, exist_ok=True)
117
+ #
118
+ tgt_path = output_dir + ".tar" if b_pack_into_tar else output_dir
119
+ if os.path.exists(tgt_path):
120
+ if b_allow_overwrite:
121
+ for_os.remove(path=tgt_path, ignore_errors=True)
122
+ else:
123
+ raise FileExistsError(f"target {tgt_path} already exists")
124
+ os.makedirs(os.path.dirname(output_dir), exist_ok=True)
125
+ temp_dir = tempfile.TemporaryDirectory(dir=os.path.dirname(output_dir))
126
+ temp_output_dir = os.path.join(temp_dir.name, os.path.basename(output_dir))
127
+ os.makedirs(temp_output_dir, exist_ok=True)
128
+ #
115
129
  var = ndl.copy_(var=var, b_deepcopy=False)
116
130
  if b_keep_identical_relations:
117
131
  from kevin_toolbox.nested_dict_list import value_parser
@@ -144,7 +158,7 @@ def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_p
144
158
  backend_name_ls = setting["backend"] if isinstance(setting["backend"], (list, tuple)) else [setting["backend"]]
145
159
  for i in backend_name_ls:
146
160
  if i not in backend_s:
147
- backend_s[i] = SERIALIZER_BACKEND.get(name=i)(folder=os.path.join(output_dir, "nodes"))
161
+ backend_s[i] = SERIALIZER_BACKEND.get(name=i)(folder=os.path.join(temp_output_dir, "nodes"))
148
162
  #
149
163
  t_mode = Traversal_Mode(setting.get("traversal_mode", traversal_mode))
150
164
  # _process and paras
@@ -189,17 +203,19 @@ def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_p
189
203
  f'please check settings to make sure all nodes have been covered and can be deal with backend'
190
204
 
191
205
  # 保存 var 的结构
192
- json_.write(content=var, file_path=os.path.join(output_dir, "var.json"), b_use_suggested_converter=True)
206
+ json_.write(content=var, file_path=os.path.join(temp_output_dir, "var.json"), b_use_suggested_converter=True)
193
207
  # 保存处理结果(非必要)
194
208
  json_.write(content=dict(processed=processed_s, raw_structure=processed_s_bak, timestamp=time.time(),
195
209
  kt_version=kevin_toolbox.__version__,
196
210
  b_keep_identical_relations=b_keep_identical_relations),
197
- file_path=os.path.join(output_dir, "record.json"), b_use_suggested_converter=True)
211
+ file_path=os.path.join(temp_output_dir, "record.json"), b_use_suggested_converter=True)
198
212
 
199
213
  # 打包成 .tar 文件
200
214
  if b_pack_into_tar:
201
- for_os.pack(source=output_dir)
202
- for_os.remove(path=output_dir, ignore_errors=True)
215
+ for_os.pack(source=temp_output_dir)
216
+ os.rename(temp_output_dir + ".tar", output_dir + ".tar")
217
+ else:
218
+ os.rename(temp_output_dir, output_dir)
203
219
 
204
220
 
205
221
  def _judge_processed_or_not(processed_s, name):
@@ -0,0 +1,6 @@
1
+ from .plot_lines import plot_lines
2
+ from .plot_scatters import plot_scatters
3
+ from .plot_distribution import plot_distribution
4
+ from .plot_bars import plot_bars
5
+ from .plot_scatters_matrix import plot_scatters_matrix
6
+ from .plot_confusion_matrix import plot_confusion_matrix
@@ -0,0 +1,54 @@
1
+ import os
2
+ import copy
3
+ from kevin_toolbox.computer_science.algorithm import for_seq
4
+ import matplotlib.pyplot as plt
5
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
6
+
7
+ # TODO 在 linux 系统下遇到中文时,尝试自动下载中文字体,并尝试自动设置字体
8
+ # font_path = os.path.join(root_dir, "utils/SimHei.ttf")
9
+ # font_name = FontProperties(fname=font_path)
10
+
11
+
12
def plot_bars(data_s, title, x_name, y_label=None, output_dir=None, **kwargs):
    """
    Draw a grouped bar chart.

    Parameters:
        data_s:     <dict> {<name>: <list of value>, ...}
                        The entry under x_name supplies the x tick labels. Every other entry is drawn as one series
                        of bars. data_s itself is not modified (a deep copy is used).
        title:      <str> chart title; also used, after replace_illegal_chars(), as the png file name
        x_name:     <str> key of data_s holding the x tick labels
        y_label:    <str> y-axis label, defaults to "value"
        output_dir: <path> directory to save the figure into; when None the figure is shown with plt.show()
        **kwargs:
            dpi:        <int> resolution of the saved image, defaults to 200

    Returns:
        The path of the saved png, or None when output_dir is None.
    """
    data_s = copy.deepcopy(data_s)
    paras = {
        "dpi": 200
    }
    paras.update(kwargs)

    plt.clf()
    #
    x_all_ls = data_s.pop(x_name)
    x_pos_ls = list(range(len(x_all_ls)))
    # Place the series side by side around each tick. Two series keep the previous layout
    # (width 0.2 at -0.1 / +0.1); three or more no longer overlap; a single series is centred.
    series_nums = len(data_s)
    width = min(0.2, 0.8 / max(series_nums, 1))
    for i, (k, y_ls) in enumerate(data_s.items()):
        offset = (i - (series_nums - 1) / 2) * width
        plt.bar([j + offset for j in x_pos_ls], y_ls, width=width, align='center', label=k)

    plt.xlabel(f'{x_name}')
    plt.ylabel(f'{y_label if y_label else "value"}')
    temp = for_seq.flatten_list([list(i) for i in data_s.values()])
    y_min, y_max = min(temp), max(temp)
    y_range = y_max - y_min
    if y_range == 0:
        # All values equal: avoid a zero-height ylim that would hide the bars
        y_range = abs(y_max) if y_max != 0 else 1
    plt.ylim(max(min(y_min, 0), y_min - y_range * 0.2), y_max + y_range * 0.1)
    plt.xticks(x_pos_ls, labels=x_all_ls)  # , fontproperties=font_name
    plt.title(f'{title}')
    # Show the legend
    plt.legend()

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
47
+
48
+
49
if __name__ == '__main__':
    # Demo: bars for series "b" and "c" over the x labels in "a"
    demo_data = {
        'a': [1.5, 2, 3, 4, 5],
        'b': [5, 4, 3, 2, 1],
        'c': [1, 2, 3, 4, 5],
    }
    demo_dir = os.path.join(os.path.dirname(__file__), "temp")
    plot_bars(data_s=demo_data, title='test', x_name='a', output_dir=demo_dir)
@@ -0,0 +1,60 @@
1
+ import os
2
+ from sklearn.metrics import confusion_matrix
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
6
+
7
+
8
def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None, output_dir=None, **kwargs):
    """
    Compute a confusion matrix with sklearn and draw it as a seaborn heatmap.

    Parameters:
        data_s:             <dict> holds the ground-truth sequence under gt_name and the prediction under pd_name
        title:              <str> chart title; also used, after replace_illegal_chars(), as the png file name
        gt_name:            <str> key of the ground truth in data_s (also the y-axis label)
        pd_name:            <str> key of the prediction in data_s (also the x-axis label)
        label_to_value_s:   <dict> {<label shown on the axes>: <value in the data>, ...}
                                Also fixes the order of rows/columns. Every value must occur in the data.
                                Defaults to None, in which case every observed value is used, labelled by its str().
                                Values are sorted when they are mutually comparable.
        output_dir:         <path> directory to save the figure into; when None the figure is shown with plt.show()
        **kwargs:
            dpi:            <int> resolution of the saved image, defaults to 200
            normalize:      None / "true" / "pred" / "all", passed to sklearn.metrics.confusion_matrix
            b_return_cfm:   <boolean> also return the confusion matrix, defaults to False

    Returns:
        output_path (None when output_dir is None), or (output_path, cfm) when b_return_cfm is True.
    """
    paras = {
        "dpi": 200,
        "normalize": None,  # "true", "pred", "all",
        "b_return_cfm": False,  # whether to also return the confusion matrix
    }
    paras.update(kwargs)

    value_set = set(data_s[gt_name]).union(set(data_s[pd_name]))
    if label_to_value_s is None:
        # Sort so the axis order does not depend on set iteration order (which varies between runs for str)
        try:
            value_ls = sorted(value_set)
        except TypeError:
            value_ls = list(value_set)
        label_to_value_s = {f'{i}': i for i in value_ls}
    else:
        assert all(i in value_set for i in label_to_value_s.values())
    # Compute the confusion matrix
    cfm = confusion_matrix(y_true=data_s[gt_name], y_pred=data_s[pd_name], labels=list(label_to_value_s.values()),
                           normalize=paras["normalize"])
    # Draw the heatmap on a dedicated figure. No plt.clf() here: it only blanked whatever figure was current
    # (possibly the caller's) before a new one was created.
    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(cfm, annot=True, fmt='.2%' if paras["normalize"] is not None else 'd',
                xticklabels=list(label_to_value_s.keys()), yticklabels=list(label_to_value_s.keys()),
                cmap='viridis')

    plt.xlabel(f'{pd_name}')
    plt.ylabel(f'{gt_name}')
    plt.title(f'{title}')

    if output_dir is None:
        plt.show()
        output_path = None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        # The figure is only needed for the file; close it so repeated calls do not accumulate open figures
        plt.close(fig)

    if paras["b_return_cfm"]:
        return output_path, cfm
    else:
        return output_path
47
+
48
+
49
if __name__ == '__main__':
    import numpy as np

    # Ground-truth and predicted labels for the demo
    gt_arr = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 5])
    pd_arr = np.array([0, 2, 1, 0, 2, 1, 0, 1, 1, 5])

    # Pass output_dir=os.path.join(os.path.dirname(__file__), "temp") to save instead of showing
    plot_confusion_matrix(
        data_s={'a': gt_arr, 'b': pd_arr},
        title='test',
        gt_name='a',
        pd_name='b',
        label_to_value_s={"A": 5, "B": 0, "C": 1, "D": 2},
        normalize="true",
    )
@@ -0,0 +1,65 @@
1
+ import os
2
+ import math
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
6
+
7
+
8
def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist", output_dir=None, **kwargs):
    """
    Plot the empirical distribution of one or more entries of data_s on a shared axis.

    Parameters:
        data_s:     <dict> {<name>: <list of value>, ...}
        title:      <str> chart title; also used, after replace_illegal_chars(), as the png file name
        x_name:     <str> single entry to plot; when given it takes precedence over x_name_ls
        x_name_ls:  <list of str> entries to plot
        type_:      <str> kind of plot
                        "hist" / "histogram":   density histogram of numeric data
                        "cate" / "category":    bar chart of the relative frequency of each distinct value
        output_dir: <path> directory to save the figure into; when None the figure is shown with plt.show()
        **kwargs:
            dpi:        <int> resolution of the saved image, defaults to 200
            steps:      <number> bin width for "hist"; bin edges are aligned to multiples of steps
            min / max:  <number> range of the bins for "hist" when steps is not given (default: data min / max)
            bin_nums:   <int> number of bins for "hist" when steps is not given, defaults to 10

    Returns:
        The path of the saved png, or None when output_dir is None.
    """
    import numbers

    paras = {
        "dpi": 200,
        # Used when "steps" is absent; without it the histogram branch raised KeyError
        "bin_nums": 10,
    }
    paras.update(kwargs)
    if x_name is not None:
        x_name_ls = [x_name, ]
    assert isinstance(x_name_ls, (list, tuple)) and len(x_name_ls) > 0

    plt.clf()

    # More overlapping series -> more transparency, but never below 0.3
    alpha = max(1 / len(x_name_ls), 0.3)
    if type_ in ["histogram", "hist"]:
        # Numeric data: draw a probability density histogram
        for name in x_name_ls:
            data = data_s[name]
            # numbers.Real also accepts numpy integer/floating scalars (np.int64 is not a subclass of int)
            assert all(isinstance(x, numbers.Real) for x in data), \
                f'输入数组中的元素类型不一致'
            if "steps" in paras:
                min_ = math.floor(min(data) / paras["steps"]) * paras["steps"]
                max_ = math.ceil(max(data) / paras["steps"]) * paras["steps"]
                bins = np.arange(min_, max_ + paras["steps"], paras["steps"])
            else:
                bins = np.linspace(paras.get("min", min(data)), paras.get("max", max(data)), paras["bin_nums"] + 1)
            plt.hist(data, density=True, bins=bins, alpha=alpha, label=name)
    elif type_ in ["category", "cate"]:
        # Categorical data: draw the relative frequency of each value
        for name in x_name_ls:
            data = data_s[name]
            unique_values, counts = np.unique(data, return_counts=True)
            probabilities = counts / len(data)
            plt.bar([f'{i}' for i in unique_values], probabilities, label=name, alpha=alpha)
    else:
        raise ValueError(f'unsupported plot type {type_}')

    plt.xlabel(f'value')
    plt.ylabel('prob')
    plt.title(f'{title}')
    # Show the legend
    plt.legend()

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
58
+
59
+
60
if __name__ == '__main__':
    # Demo: compare the value frequencies of "a" and "c"
    demo_data = {
        'a': [1, 2, 3, 4, 5, 3, 2, 1],
        'c': [1, 2, 3, 4, 5, 0, 0, 0],
    }
    demo_dir = os.path.join(os.path.dirname(__file__), "temp")
    plot_distribution(data_s=demo_data, title='test', x_name_ls=['a', 'c'], type_="category",
                      output_dir=demo_dir)
@@ -0,0 +1,61 @@
1
+ import os
2
+ import matplotlib.pyplot as plt
3
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
4
+ from kevin_toolbox.patches.for_matplotlib import generate_color_list
5
+
6
+
7
def plot_lines(data_s, title, x_name, output_dir=None, **kwargs):
    """
    Draw one line per series against a shared x axis.

    Parameters:
        data_s:     <dict> {<name>: <list of value>, ...}
                        The entry under x_name supplies the x values. Every other entry is one line.
                        Points where either x or y is None are skipped; a series left with no points is not drawn.
                        data_s itself is not modified.
        title:      <str> chart title; also used, after replace_illegal_chars(), as the png file name
        x_name:     <str> key of data_s holding the x values
        output_dir: <path> directory to save the figure into; when None the figure is shown with plt.show()
        **kwargs:
            dpi:            <int> resolution of the saved image, defaults to 200
            color_ls:       colors, one per series (default: generate_color_list())
            marker_ls:      markers, one per series, defaults to None
            linestyle_ls:   line styles, one per series, defaults to '-'
                Any *_ls given as a single non-list value is repeated for every series.

    Returns:
        The path of the saved png, or None when output_dir is None.
    """
    line_nums = len(data_s) - 1
    paras = {
        "dpi": 200,
        "color_ls": generate_color_list(nums=line_nums),
        "marker_ls": None,
        "linestyle_ls": '-',
    }
    paras.update(kwargs)
    for k, v in paras.items():
        if k.endswith("_ls") and not isinstance(v, (list, tuple)):
            paras[k] = [v] * line_nums
    assert line_nums == len(paras["color_ls"]) == len(paras["marker_ls"]) == len(paras["linestyle_ls"])

    plt.clf()
    # Shallow copy so that popping x_name does not remove it from the caller's dict
    data_s = dict(data_s)
    x_all_ls = data_s.pop(x_name)
    # i is the series' position among all series (not only the drawn ones),
    # so colors/markers/linestyles stay bound to the same series even when an empty one is skipped.
    for i, (k, v_ls) in enumerate(data_s.items()):
        point_ls = [(x, v) for x, v in zip(x_all_ls, v_ls) if x is not None and v is not None]
        if len(point_ls) == 0:
            continue
        x_ls = [x for x, _ in point_ls]
        y_ls = [v for _, v in point_ls]
        plt.plot(x_ls, y_ls, label=f'{k}', color=paras["color_ls"][i], marker=paras["marker_ls"][i],
                 linestyle=paras["linestyle_ls"][i])
    plt.xlabel(f'{x_name}')
    plt.ylabel('value')
    plt.title(f'{title}')
    # Show the legend
    plt.legend()

    if output_dir is None:
        plt.show()
        return None
    else:
        # Illegal characters in the title are replaced before it is used as a file name
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
54
+
55
+
56
if __name__ == '__main__':
    # Demo: lines "b" and "c" against the x values in "a"
    demo_data = {
        'a': [1, 2, 3, 4, 5],
        'b': [5, 4, 3, 2, 1],
        'c': [1, 2, 3, 4, 5],
    }
    demo_dir = os.path.join(os.path.dirname(__file__), "temp")
    plot_lines(data_s=demo_data, title='test', x_name='a', output_dir=demo_dir)
+ title='test', x_name='a', output_dir=os.path.join(os.path.dirname(__file__), "temp"))