kevin-toolbox-dev 1.3.5__py3-none-any.whl → 1.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kevin_toolbox/__init__.py +2 -2
- kevin_toolbox/computer_science/algorithm/pareto_front/__init__.py +1 -0
- kevin_toolbox/computer_science/algorithm/pareto_front/optimum_picker.py +218 -0
- kevin_toolbox/computer_science/algorithm/statistician/__init__.py +1 -0
- kevin_toolbox/computer_science/algorithm/statistician/accumulator_base.py +2 -2
- kevin_toolbox/computer_science/algorithm/statistician/accumulator_for_ndl.py +69 -0
- kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +16 -4
- kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +5 -12
- kevin_toolbox/data_flow/file/kevin_notation/kevin_notation_writer.py +3 -0
- kevin_toolbox/data_flow/file/kevin_notation/test/test_kevin_notation_debug.py +27 -0
- kevin_toolbox/data_flow/file/kevin_notation/write.py +3 -0
- kevin_toolbox/data_flow/file/markdown/__init__.py +3 -0
- kevin_toolbox/data_flow/file/markdown/find_tables.py +65 -0
- kevin_toolbox/data_flow/file/markdown/generate_table.py +19 -5
- kevin_toolbox/data_flow/file/markdown/parse_table.py +135 -0
- kevin_toolbox/data_flow/file/markdown/save_images_in_ndl.py +81 -0
- kevin_toolbox/data_flow/file/markdown/variable.py +17 -0
- kevin_toolbox/nested_dict_list/serializer/read.py +1 -1
- kevin_toolbox/nested_dict_list/serializer/write.py +9 -2
- kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +6 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +54 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +60 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +65 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +61 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +53 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +54 -0
- kevin_toolbox/patches/for_os/__init__.py +2 -0
- kevin_toolbox/patches/for_os/find_files_in_dir.py +30 -0
- kevin_toolbox/patches/for_os/path/__init__.py +2 -0
- kevin_toolbox/patches/for_os/path/find_illegal_chars.py +47 -0
- kevin_toolbox/patches/for_os/path/replace_illegal_chars.py +49 -0
- kevin_toolbox_dev-1.3.6.dist-info/METADATA +95 -0
- {kevin_toolbox_dev-1.3.5.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/RECORD +35 -17
- kevin_toolbox_dev-1.3.5.dist-info/METADATA +0 -74
- {kevin_toolbox_dev-1.3.5.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/WHEEL +0 -0
- {kevin_toolbox_dev-1.3.5.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,135 @@
|
|
1
|
+
import re
|
2
|
+
from typing import Union
|
3
|
+
from kevin_toolbox.data_flow.file.markdown.variable import Table_Format
|
4
|
+
|
5
|
+
|
6
|
+
def parse_table(raw_table, output_format: Union[Table_Format, str] = Table_Format.COMPLETE_DICT, orientation="vertical",
                chunk_size=None, chunk_nums=None, b_remove_empty_lines=False, f_gen_order_of_values=None):
    """
    Parse a table given as a 2-D array (e.g. an element of the list returned by find_tables()) into the
    requested format.

    Parameters:
        raw_table:              <list of list> the table as a 2-D array (list of rows)
        output_format:          <Table_Format or str> target format, see Table_Format.
                                    Strings such as "simple_dict" / "complete_dict" are accepted and converted;
                                    an unknown string raises ValueError.
        orientation:            <str> direction in which to read the table
                                    "vertical" / "v":    the first row holds the titles
                                    "horizontal" / "h":  the first column holds the titles
                                    (any value other than "vertical" / "v" is treated as horizontal)
        chunk_nums:             <int> number of equal chunks the table was split into and laid out side by side
        chunk_size:             <int> maximum length of each chunk when the table was split and laid out side by side
                                    These two parameters interpret tables produced by generate_table() with the same
                                    parameters. chunk_size is only used to check the number of value rows and does not
                                    otherwise affect parsing. However, when it is given, the table is assumed to
                                    possibly consist of several tables side by side. chunk_nums is then inferred from
                                    the repetition pattern of the titles, and the table is split accordingly.
        b_remove_empty_lines:   <boolean> remove empty rows and columns
        f_gen_order_of_values:  <callable> function producing the sort key of a row of values.
                                    It is called with a dict {<title>: <value>, ...} describing one row.
                                    See the parameter of the same name in generate_table().

    Returns:
        the parsed table as a dict in the layout selected by output_format
    """
    assert isinstance(raw_table, (list, tuple,))
    # accept both an enum member and its string value
    output_format = Table_Format(output_format)

    # convert to a vertical layout, i.e. titles in the first row
    if orientation not in ["vertical", "v"]:
        raw_table = list(zip(*raw_table))
    r_nums, c_nums = len(raw_table), len(raw_table[0])
    if chunk_size is not None:
        assert chunk_size == r_nums - 1, \
            (f'The number of values {r_nums - 1} actually contained in the table '
             f'does not match the specified chunk_size {chunk_size}')
        chunk_nums = c_nums // _find_shortest_repeating_pattern_size(arr=raw_table[0])
    chunk_nums = 1 if chunk_nums is None else chunk_nums
    assert c_nums % chunk_nums == 0, \
        f'The number of headers actually contained in the table does not match the specified chunk_nums, ' \
        f'Expected n*{chunk_nums}, but got {c_nums}'
    # titles: the first chunk of the header row
    keys = raw_table[0][0:c_nums // chunk_nums]
    # values: side-by-side chunks are stacked back into a single set of columns, chunk by chunk
    if chunk_nums == 1:
        values = raw_table[1:]
    else:
        values = []
        for i in range(chunk_nums):
            for j in range(1, r_nums):
                values.append(raw_table[j][i * len(keys):(i + 1) * len(keys)])
    # remove empty rows
    if b_remove_empty_lines:
        values = [line for line in values if any(i != '' for i in line)]
    # transpose rows into columns; keep the titles even when there are no value rows at all
    columns = list(zip(*values)) if len(values) > 0 else [()] * len(keys)
    table_s = {i: {"title": k, "values": list(v)} for i, (k, v) in enumerate(zip(keys, columns))}
    # remove empty columns
    if b_remove_empty_lines:
        table_s = {k: v_s for k, v_s in table_s.items() if v_s["title"] != '' and any(i != '' for i in v_s["values"])}
    # sort the rows
    if callable(f_gen_order_of_values):
        # each row is handed to f_gen_order_of_values as a {title: value} dict, so titles must be unique
        temp = [v["title"] for v in table_s.values()]
        assert len(set(temp)) == len(temp), \
            f'table has duplicate titles, thus cannot be sorted using f_gen_order_of_values'
        idx_ls = list(range(len(values)))
        idx_ls.sort(key=lambda x: f_gen_order_of_values({v["title"]: v["values"][x] for v in table_s.values()}))
        for v in table_s.values():
            v["values"] = [v["values"][i] for i in idx_ls]

    # convert to the requested output format
    if output_format is Table_Format.SIMPLE_DICT:
        # titles become dict keys, so duplicates would silently overwrite each other
        temp = [v_s["title"] for v_s in table_s.values()]
        if len(temp) != len(set(temp)):
            raise AssertionError(
                f'There are columns with the same title in the table, '
                f'please check the orientation of the table or use output_format="complete_dict"')
        table_s = {v_s["title"]: v_s["values"] for v_s in table_s.values()}

    return table_s
|
83
|
+
|
84
|
+
|
85
|
+
def _find_shortest_repeating_pattern_size(arr):
|
86
|
+
n = len(arr)
|
87
|
+
|
88
|
+
# 部分匹配表
|
89
|
+
pi = [0] * n
|
90
|
+
k = 0
|
91
|
+
for i in range(1, n):
|
92
|
+
if k > 0 and arr[k] != arr[i]:
|
93
|
+
k = 0
|
94
|
+
if arr[k] == arr[i]:
|
95
|
+
k += 1
|
96
|
+
pi[i] = k
|
97
|
+
|
98
|
+
# 最短重复模式的长度
|
99
|
+
pattern_length = n - pi[n - 1]
|
100
|
+
# 是否是完整的重复模式
|
101
|
+
if n % pattern_length != 0:
|
102
|
+
pattern_length = n
|
103
|
+
return pattern_length
|
104
|
+
|
105
|
+
|
106
|
+
if __name__ == '__main__':
    from kevin_toolbox.data_flow.file.markdown import find_tables

    # demo: a table whose header row "", "a", "b" repeats three times (three chunks laid out side by side)
    markdown_text = """
| | a | b | | a | b | | a | b |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| | 0 | 2 | | 4 | 6 | | 7 | 9 |
| | 1 | 3 | | 5 | 7 | | 8 | : |
| | 2 | 4 | | 6 | 8 | | 9 | ; |
| | 3 | 5 | | | | | | |
"""
    found_table_ls = find_tables(text=markdown_text)
    first_table = found_table_ls[0]

    # split the table back into its 3 chunks and drop empty rows/columns, then show the result
    parsed_table_s = parse_table(raw_table=first_table, output_format="complete_dict", chunk_nums=3,
                                 b_remove_empty_lines=True)
    print(parsed_table_s)
|
@@ -0,0 +1,81 @@
|
|
1
|
+
import os
|
2
|
+
import warnings
|
3
|
+
import torch
|
4
|
+
import cv2
|
5
|
+
import numpy as np
|
6
|
+
from PIL import Image
|
7
|
+
from collections import defaultdict
|
8
|
+
from kevin_toolbox.data_flow.file import markdown
|
9
|
+
from kevin_toolbox.patches.for_os.path import replace_illegal_chars, find_illegal_chars
|
10
|
+
import kevin_toolbox.nested_dict_list as ndl
|
11
|
+
|
12
|
+
|
13
|
+
def save_images_in_ndl(var, plot_dir, doc_dir=None, setting_s=None):
    """
    Save the image objects found at the leaf nodes of an ndl structure into plot_dir, and replace each of them
    with a markdown link to the saved image.

    Parameters:
        var:            <dict> the ndl structure to process (its leaves are replaced in place; var is returned)
        plot_dir:       <path> directory the images are saved to
        doc_dir:        <path> directory the output markdown document is saved to
                            When given, image links are written as paths relative to doc_dir.
                            Defaults to None, in which case the links use the image path built from plot_dir
                            as-is (absolute only if plot_dir is absolute).
        setting_s:      <dict> settings
                            Selects the nodes under which image objects are searched for, and the parameters used
                            to convert them.
                            Form: {<node name>: {"b_is_rgb": <boolean>, ...}, ...}
                            Supported keys:
                                - b_is_rgb:             whether the image to save is in RGB (True) or BGR (False) order
                                - saved_image_format:   suffix used when saving the image
                            Keys missing from a node's setting fall back to
                            {"b_is_rgb": False, "saved_image_format": ".jpg"}.
                            Defaults to None, equivalent to {"": {"b_is_rgb": False, "saved_image_format": ".jpg"}}.
                            When several node names cover the same leaf, the setting of the deepest one wins.

    Leaf handling:
        - PIL images and torch tensors are first converted to numpy arrays
        - numpy arrays are written with cv2.imwrite and replaced by a markdown image link
        - None is replaced by "/"
        - any other value is kept unchanged

    Returns:
        var
    """
    default_setting_s = {"b_is_rgb": False, "saved_image_format": ".jpg"}
    if len(find_illegal_chars(file_name=plot_dir, b_is_path=True)) > 0:
        warnings.warn(f'plot_dir {plot_dir} contains illegal symbols, '
                      f'which may cause compatibility issues on certain systems.', UserWarning)
    setting_s = setting_s or {"": {"b_is_rgb": False, "saved_image_format": ".jpg"}}

    # resolve which setting applies to each leaf node:
    # start from the shallowest node name so that settings of deeper node names override shallower ones
    root_ls = list(setting_s.keys())
    root_ls.sort(key=lambda x: len(ndl.name_handler.parse_name(name=x)[-1]))
    root_to_leaf_s = defaultdict(set)
    leaf_to_root_s = dict()
    leaf_to_value_s = dict()
    for root in root_ls:
        for leaf, v in ndl.get_nodes(var=ndl.get_value(var=var, name=root, b_pop=False), level=-1, b_strict=True):
            leaf = root + leaf
            if leaf in leaf_to_root_s:
                # a deeper setting takes over this leaf
                root_to_leaf_s[leaf_to_root_s[leaf]].remove(leaf)
            root_to_leaf_s[root].add(leaf)
            leaf_to_root_s[leaf] = root
            leaf_to_value_s[leaf] = v

    for root, leaf_ls in root_to_leaf_s.items():
        # fill in any keys the user left out of this node's setting
        setting_ = {**default_setting_s, **setting_s[root]}
        for leaf in leaf_ls:
            v = leaf_to_value_s[leaf]
            if isinstance(v, Image.Image):
                v = np.asarray(v)
            elif torch.is_tensor(v):
                v = v.detach().cpu().numpy()
            #
            if isinstance(v, np.ndarray):
                image_path = os.path.join(
                    plot_dir, replace_illegal_chars(
                        file_name=f'{leaf}_{setting_["saved_image_format"]}', b_is_path=False)
                )
                os.makedirs(os.path.dirname(image_path), exist_ok=True)
                # only colour images (3 or 4 channels) have a channel order to swap; grayscale arrays are
                # written as-is, since cv2.cvtColor would reject them
                if setting_["b_is_rgb"] and v.ndim == 3 and v.shape[-1] in (3, 4):
                    v = cv2.cvtColor(v, cv2.COLOR_RGB2BGR)
                # cv2.imwrite reports failure through its return value rather than raising
                if not cv2.imwrite(image_path, v):
                    warnings.warn(f'failed to save the image of node {leaf} to {image_path}', UserWarning)
                v_new = markdown.generate_link(
                    name=os.path.basename(image_path),
                    target=os.path.relpath(image_path, doc_dir) if doc_dir is not None else image_path, type_="image")
            elif v is None:
                v_new = "/"
            else:
                v_new = v
            ndl.set_value(var=var, name=leaf, b_force=False, value=v_new)
    return var
|
79
|
+
|
80
|
+
|
81
|
+
|
@@ -0,0 +1,17 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
|
3
|
+
|
4
|
+
class Table_Format(Enum):
    """
    Supported table layouts.

    SIMPLE_DICT ("simple_dict"), the simple dict layout:
        content_s = {<title>: <list of value>, ...}
        Each key is a title and its value is the list of values under that title.
        The order of the titles is not guaranteed in this layout; use COMPLETE_DICT to fix it explicitly.

    COMPLETE_DICT ("complete_dict"), the complete dict layout:
        content_s = {<index>: {"title": <title>, "values": <list of value>}, ...}
        The "title" of entry <index> becomes the <index>-th title, and likewise for "values".
        Some <index> entries may be omitted; the corresponding rows/columns are then left empty.
    """
    # {<title>: <list of value>, ...}
    SIMPLE_DICT = "simple_dict"
    # {<index>: {"title": <title>, "values": <list of value>}, ...}
    COMPLETE_DICT = "complete_dict"
|
@@ -18,7 +18,7 @@ def read(input_path, **kwargs):
|
|
18
18
|
assert os.path.exists(input_path)
|
19
19
|
|
20
20
|
with tempfile.TemporaryDirectory(dir=os.path.dirname(input_path)) as temp_dir:
|
21
|
-
if os.path.isfile(input_path) and input_path.endswith(".tar"):
|
21
|
+
if os.path.isfile(input_path) and input_path.endswith(".tar"): # 解压
|
22
22
|
for_os.unpack(source=input_path, target=temp_dir)
|
23
23
|
input_path = os.path.join(temp_dir, os.listdir(temp_dir)[0])
|
24
24
|
var = _read_unpacked_ndl(input_path, **kwargs)
|
@@ -13,7 +13,7 @@ from .saved_node_name_builder import Saved_Node_Name_Builder
|
|
13
13
|
|
14
14
|
def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_pack_into_tar=True,
|
15
15
|
strictness_level=Strictness_Level.COMPATIBLE, saved_node_name_format='{count}_{hash_name}',
|
16
|
-
b_keep_identical_relations=False, **kwargs):
|
16
|
+
b_keep_identical_relations=False, b_allow_overwrite=False, **kwargs):
|
17
17
|
"""
|
18
18
|
将输入的嵌套字典列表 var 的结构和节点值保存到文件中
|
19
19
|
遍历 var,匹配并使用 settings 中设置的保存方式来对各部分结构/节点进行序列化
|
@@ -106,6 +106,8 @@ def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_p
|
|
106
106
|
替换为单个节点和其多个引用的形式。
|
107
107
|
对于 ndl 中存在大量具有相同 id 的重复节点的情况,使用该操作可以额外达到压缩的效果。
|
108
108
|
默认为 False
|
109
|
+
b_allow_overwrite: <boolean> 是否允许强制覆盖已有文件
|
110
|
+
默认为 False,此时若目标文件已存在则报错
|
109
111
|
"""
|
110
112
|
from kevin_toolbox.nested_dict_list.serializer.variable import SERIALIZER_BACKEND
|
111
113
|
|
@@ -113,7 +115,12 @@ def write(var, output_dir, settings=None, traversal_mode=Traversal_Mode.BFS, b_p
|
|
113
115
|
traversal_mode = Traversal_Mode(traversal_mode)
|
114
116
|
strictness_level = Strictness_Level(strictness_level)
|
115
117
|
#
|
116
|
-
|
118
|
+
tgt_path = output_dir + ".tar" if b_pack_into_tar else output_dir
|
119
|
+
if os.path.exists(tgt_path):
|
120
|
+
if b_allow_overwrite:
|
121
|
+
for_os.remove(path=tgt_path, ignore_errors=True)
|
122
|
+
else:
|
123
|
+
raise FileExistsError(f"target {tgt_path} already exists")
|
117
124
|
os.makedirs(os.path.dirname(output_dir), exist_ok=True)
|
118
125
|
temp_dir = tempfile.TemporaryDirectory(dir=os.path.dirname(output_dir))
|
119
126
|
temp_output_dir = os.path.join(temp_dir.name, os.path.basename(output_dir))
|
@@ -0,0 +1,6 @@
|
|
1
|
+
from .plot_lines import plot_lines
|
2
|
+
from .plot_scatters import plot_scatters
|
3
|
+
from .plot_distribution import plot_distribution
|
4
|
+
from .plot_bars import plot_bars
|
5
|
+
from .plot_scatters_matrix import plot_scatters_matrix
|
6
|
+
from .plot_confusion_matrix import plot_confusion_matrix
|
@@ -0,0 +1,54 @@
|
|
1
|
+
import os
|
2
|
+
import copy
|
3
|
+
from kevin_toolbox.computer_science.algorithm import for_seq
|
4
|
+
import matplotlib.pyplot as plt
|
5
|
+
from kevin_toolbox.patches.for_os.path import replace_illegal_chars
|
6
|
+
|
7
|
+
# TODO 在 linux 系统下遇到中文时,尝试自动下载中文字体,并尝试自动设置字体
|
8
|
+
# font_path = os.path.join(root_dir, "utils/SimHei.ttf")
|
9
|
+
# font_name = FontProperties(fname=font_path)
|
10
|
+
|
11
|
+
|
12
|
+
def plot_bars(data_s, title, x_name, y_label=None, output_dir=None, **kwargs):
    """
    Draw a grouped bar chart.

    Parameters:
        data_s:         <dict> {<name>: <list of value>, ...}. data_s[x_name] holds the x tick labels, and every
                            other entry is drawn as one series of bars. The caller's dict is not modified
                            (a deep copy is used).
        title:          <str> chart title; also used (after replacing illegal characters) as the file name
        x_name:         <str> key of data_s holding the x tick labels
        y_label:        <str> y axis label, defaults to "value"
        output_dir:     <path> directory to save the image in; when None the chart is shown with plt.show()
        **kwargs:
            dpi:        <int> resolution of the saved image, defaults to 200

    Returns:
        path of the saved image, or None when output_dir is None
    """
    data_s = copy.deepcopy(data_s)
    paras = {
        "dpi": 200
    }
    paras.update(kwargs)

    plt.clf()
    #
    x_all_ls = data_s.pop(x_name)
    # place the series side by side around each tick. Bars are 0.2 wide as long as the group fits into 0.8 of the
    # tick spacing (narrower otherwise), and the group is centred on the tick. Two series therefore get the
    # offsets -0.1 / +0.1.
    series_nums = len(data_s)
    width = min(0.2, 0.8 / max(series_nums, 1))
    for i, (k, y_ls) in enumerate(data_s.items()):
        offset = (i - (series_nums - 1) / 2) * width
        plt.bar([j + offset for j in range(len(x_all_ls))], y_ls, width=width, align='center', label=k)

    plt.xlabel(f'{x_name}')
    plt.ylabel(f'{y_label if y_label else "value"}')
    temp = for_seq.flatten_list([list(i) for i in data_s.values()])
    if len(temp) > 0:
        y_min, y_max = min(temp), max(temp)
        # constant data would give identical limits; fall back to the magnitude of the value (or 1)
        span = (y_max - y_min) or abs(y_max) or 1
        plt.ylim(max(min(y_min, 0), y_min - span * 0.2), y_max + span * 0.1)
    plt.xticks(list(range(len(x_all_ls))), labels=x_all_ls)  # , fontproperties=font_name
    plt.title(f'{title}')
    # show the legend
    plt.legend()

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
|
47
|
+
|
48
|
+
|
49
|
+
if __name__ == '__main__':
    # demo: column 'a' provides the x tick labels, 'b' and 'c' are drawn as bars
    demo_data_s = {
        'a': [1.5, 2, 3, 4, 5],
        'b': [5, 4, 3, 2, 1],
        'c': [1, 2, 3, 4, 5],
    }
    demo_output_dir = os.path.join(os.path.dirname(__file__), "temp")
    plot_bars(data_s=demo_data_s, title='test', x_name='a', output_dir=demo_output_dir)
|
@@ -0,0 +1,60 @@
|
|
1
|
+
import os
|
2
|
+
from sklearn.metrics import confusion_matrix
|
3
|
+
import matplotlib.pyplot as plt
|
4
|
+
import seaborn as sns
|
5
|
+
from kevin_toolbox.patches.for_os.path import replace_illegal_chars
|
6
|
+
|
7
|
+
|
8
|
+
def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None, output_dir=None, **kwargs):
    """
    Draw the confusion matrix of ground-truth versus predicted labels as a heatmap.

    Parameters:
        data_s:             <dict> contains the ground-truth sequence data_s[gt_name] and the predicted
                                sequence data_s[pd_name]
        title:              <str> chart title; also used (after replacing illegal characters) as the file name
        gt_name:            <str> key of the ground-truth labels; used as the y axis label
        pd_name:            <str> key of the predicted labels; used as the x axis label
        label_to_value_s:   <dict> {<display label>: <value>, ...} selects the values to show and their order.
                                Every value must occur in the data.
                                Defaults to None: every value present in the data is shown, sorted when the values
                                are mutually comparable, with str(value) as its label.
        output_dir:         <path> directory to save the image in; when None the chart is shown with plt.show()
        **kwargs:
            dpi:            <int> resolution of the saved image, defaults to 200
            normalize:      passed to sklearn.metrics.confusion_matrix: None, "true", "pred" or "all"
            b_return_cfm:   <boolean> whether to also return the confusion matrix, defaults to False

    Returns:
        output_path, or (output_path, cfm) when b_return_cfm is True; output_path is None when output_dir is None
    """
    paras = {
        "dpi": 200,
        "normalize": None,  # "true", "pred", "all",
        "b_return_cfm": False,  # whether to return the confusion matrix as well
    }
    paras.update(kwargs)

    value_set = set(data_s[gt_name]).union(set(data_s[pd_name]))
    if label_to_value_s is None:
        # sort so that the axis order does not depend on set iteration order
        try:
            value_ls = sorted(value_set)
        except TypeError:
            # values of mixed, non-comparable types: keep set order
            value_ls = list(value_set)
        label_to_value_s = {f'{i}': i for i in value_ls}
    else:
        assert all(i in value_set for i in label_to_value_s.values())
    # compute the confusion matrix
    cfm = confusion_matrix(y_true=data_s[gt_name], y_pred=data_s[pd_name], labels=list(label_to_value_s.values()),
                           normalize=paras["normalize"])
    # draw the heatmap on a dedicated figure (plt.clf() is not called, so the caller's current figure is untouched)
    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(cfm, annot=True, fmt='.2%' if paras["normalize"] is not None else 'd',
                xticklabels=list(label_to_value_s.keys()), yticklabels=list(label_to_value_s.keys()),
                cmap='viridis')

    plt.xlabel(f'{pd_name}')
    plt.ylabel(f'{gt_name}')
    plt.title(f'{title}')

    if output_dir is None:
        plt.show()
        output_path = None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        # a new figure is created on every call; release it once saved so repeated calls do not accumulate figures
        plt.close(fig)

    if paras["b_return_cfm"]:
        return output_path, cfm
    else:
        return output_path
|
47
|
+
|
48
|
+
|
49
|
+
if __name__ == '__main__':
    import numpy as np

    # demo ground-truth and predicted labels
    gt_arr = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 5])
    pd_arr = np.array([0, 2, 1, 0, 2, 1, 0, 1, 1, 5])

    # show a row-normalised confusion matrix with custom display labels
    plot_confusion_matrix(
        data_s={'a': gt_arr, 'b': pd_arr},
        title='test', gt_name='a', pd_name='b',
        label_to_value_s={"A": 5, "B": 0, "C": 1, "D": 2},
        normalize="true",
    )
|
@@ -0,0 +1,65 @@
|
|
1
|
+
import os
|
2
|
+
import math
|
3
|
+
import matplotlib.pyplot as plt
|
4
|
+
import numpy as np
|
5
|
+
from kevin_toolbox.patches.for_os.path import replace_illegal_chars
|
6
|
+
|
7
|
+
|
8
|
+
def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist", output_dir=None, **kwargs):
    """
    Plot the distribution of one or more data series on a shared chart.

    Parameters:
        data_s:         <dict> {<name>: <list of value>, ...}
        title:          <str> chart title; also used (after replacing illegal characters) as the file name
        x_name:         <str> name of a single series to plot; takes precedence over x_name_ls
        x_name_ls:      <list/tuple of str> names of the series to plot
        type_:          <str> kind of plot
                            "hist" / "histogram":   numeric data, drawn as a probability density histogram
                            "cate" / "category":    any hashable data, drawn as bars of the frequency of each value
        output_dir:     <path> directory to save the image in; when None the chart is shown with plt.show()
        **kwargs:
            dpi:            <int> resolution of the saved image, defaults to 200
            steps:          <number> bin width for "hist"; when given, bin edges are multiples of steps covering
                                the data
            bin_nums:       <int> number of bins for "hist" when steps is not given, defaults to 10
            min / max:      bin range for "hist" when steps is not given, defaults to the data's min / max

    Returns:
        path of the saved image, or None when output_dir is None
    """
    paras = {
        "dpi": 200,
        "bin_nums": 10,
    }
    paras.update(kwargs)
    if x_name is not None:
        x_name_ls = [x_name, ]
    assert isinstance(x_name_ls, (list, tuple)) and len(x_name_ls) > 0

    plt.clf()

    # overlapping series are drawn semi-transparent, but never below 0.3 opacity
    alpha = max(1 / len(x_name_ls), 0.3)
    if type_ in ["histogram", "hist"]:
        # numeric data: probability density histogram
        for name in x_name_ls:
            data = data_s[name]
            assert all(isinstance(x, (int, float, np.integer, np.floating)) for x in data), \
                f'输入数组中的元素类型不一致'
            if "steps" in paras:
                min_ = math.floor(min(data) / paras["steps"]) * paras["steps"]
                max_ = math.ceil(max(data) / paras["steps"]) * paras["steps"]
                if max_ == min_:
                    # every value sits on the same multiple of steps: make sure at least one bin exists
                    max_ += paras["steps"]
                bins = np.arange(min_, max_ + paras["steps"], paras["steps"])
            else:
                lo, hi = paras.get("min", min(data)), paras.get("max", max(data))
                if hi == lo:
                    # identical bounds would give zero-width bins; widen the range around the value
                    lo, hi = lo - 0.5, hi + 0.5
                bins = np.linspace(lo, hi, paras["bin_nums"] + 1)
            plt.hist(data, density=True, bins=bins, alpha=alpha, label=name)
    elif type_ in ["category", "cate"]:
        # categorical data: bar chart of the relative frequency of each value
        for name in x_name_ls:
            data = data_s[name]
            unique_values, counts = np.unique(data, return_counts=True)
            probabilities = counts / len(data)
            plt.bar([f'{i}' for i in unique_values], probabilities, label=name, alpha=alpha)
    else:
        raise ValueError(f'unsupported plot type {type_}')

    plt.xlabel(f'value')
    plt.ylabel('prob')
    plt.title(f'{title}')
    # show the legend
    plt.legend()

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
|
58
|
+
|
59
|
+
|
60
|
+
if __name__ == '__main__':
    # demo: compare the value frequencies of two categorical series
    demo_data_s = {
        'a': [1, 2, 3, 4, 5, 3, 2, 1],
        'c': [1, 2, 3, 4, 5, 0, 0, 0],
    }
    demo_output_dir = os.path.join(os.path.dirname(__file__), "temp")
    plot_distribution(data_s=demo_data_s, title='test', x_name_ls=['a', 'c'], type_="category",
                      output_dir=demo_output_dir)
|
@@ -0,0 +1,61 @@
|
|
1
|
+
import os
|
2
|
+
import matplotlib.pyplot as plt
|
3
|
+
from kevin_toolbox.patches.for_os.path import replace_illegal_chars
|
4
|
+
from kevin_toolbox.patches.for_matplotlib import generate_color_list
|
5
|
+
|
6
|
+
|
7
|
+
def plot_lines(data_s, title, x_name, output_dir=None, **kwargs):
    """
    Draw one line per series of data_s against the shared x values data_s[x_name].

    Parameters:
        data_s:         <dict> {<name>: <list of value>, ...}. data_s[x_name] holds the x values, and every other
                            entry is one line. Points whose x or y is None are skipped, and a series with no
                            remaining points is not drawn. The caller's dict is not modified.
        title:          <str> chart title; also used (after replacing illegal characters) as the file name
        x_name:         <str> key of data_s holding the x values
        output_dir:     <path> directory to save the image in; when None the chart is shown with plt.show()
        **kwargs:
            dpi:            <int> resolution of the saved image, defaults to 200
            color_ls:       colors of the lines, defaults to generate_color_list(nums=<number of lines>)
            marker_ls:      markers of the lines, defaults to None
            linestyle_ls:   line styles of the lines, defaults to '-'
            Any "*_ls" option given as a single value (not a list/tuple) is repeated for every line.

    Returns:
        path of the saved image, or None when output_dir is None
    """
    line_nums = len(data_s) - 1
    paras = {
        "dpi": 200,
        "color_ls": generate_color_list(nums=line_nums),
        "marker_ls": None,
        "linestyle_ls": '-',
    }
    paras.update(kwargs)
    for k, v in paras.items():
        if k.endswith("_ls") and not isinstance(v, (list, tuple)):
            paras[k] = [v] * line_nums
    assert line_nums == len(paras["color_ls"]) == len(paras["marker_ls"]) == len(paras["linestyle_ls"])

    plt.clf()
    # work on a shallow copy so that popping x_name does not modify the caller's dict
    data_s = dict(data_s)
    x_all_ls = data_s.pop(x_name)
    # drop the points where x or y is None. Each series keeps its original position i, so its
    # color / marker / linestyle stay the same even when an earlier series is skipped entirely.
    series_ls = []
    for i, (k, v_ls) in enumerate(data_s.items()):
        point_ls = [(x, v) for x, v in zip(x_all_ls, v_ls) if x is not None and v is not None]
        if len(point_ls) == 0:
            continue
        series_ls.append((i, k, [p[0] for p in point_ls], [p[1] for p in point_ls]))
    #
    for i, k, x_ls, y_ls in series_ls:
        plt.plot(x_ls, y_ls, label=f'{k}', color=paras["color_ls"][i], marker=paras["marker_ls"][i],
                 linestyle=paras["linestyle_ls"][i])
    plt.xlabel(f'{x_name}')
    plt.ylabel('value')
    plt.title(f'{title}')
    # show the legend
    plt.legend()

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        # illegal characters in the title are replaced before it is used as a file name
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
|
54
|
+
|
55
|
+
|
56
|
+
if __name__ == '__main__':
    # demo: 'a' supplies the x values, 'b' and 'c' are drawn as lines
    demo_data_s = {
        'a': [1, 2, 3, 4, 5],
        'b': [5, 4, 3, 2, 1],
        'c': [1, 2, 3, 4, 5],
    }
    demo_output_dir = os.path.join(os.path.dirname(__file__), "temp")
    plot_lines(data_s=demo_data_s, title='test', x_name='a', output_dir=demo_output_dir)
|
@@ -0,0 +1,53 @@
|
|
1
|
+
import os
|
2
|
+
import matplotlib.pyplot as plt
|
3
|
+
from kevin_toolbox.patches.for_matplotlib import generate_color_list
|
4
|
+
from kevin_toolbox.patches.for_os.path import replace_illegal_chars
|
5
|
+
|
6
|
+
|
7
|
+
def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None, **kwargs):
    """
    Draw a scatter plot of data_s[y_name] against data_s[x_name], optionally coloured by category.

    Parameters:
        data_s:         <dict> {<name>: <list of value>, ...}
        title:          <str> chart title; also used (after replacing illegal characters) as the file name
        x_name:         <str> key of the x coordinates
        y_name:         <str> key of the y coordinates
        cate_name:      <str> key of the per-point categories. When given, each category gets its own colour and a
                            legend entry. Categories are sorted (when mutually comparable) so that the colour
                            assignment is reproducible. Defaults to None: all points blue, no legend.
        output_dir:     <path> directory to save the image in; when None the chart is shown with plt.show()
        **kwargs:
            dpi:            <int> resolution of the saved image, defaults to 200
            scatter_size:   <number> marker size passed to plt.scatter, defaults to 5
            alpha:          <float> marker opacity, defaults to 0.8

    Returns:
        path of the saved image, or None when output_dir is None
    """
    paras = {
        "dpi": 200,
        "scatter_size": 5,
        "alpha": 0.8,
    }
    paras.update(kwargs)

    plt.clf()
    #
    color_s = None
    if cate_name is not None:
        cates = list(set(data_s[cate_name]))
        try:
            cates = sorted(cates)
        except TypeError:
            pass  # mixed, non-comparable categories: fall back to set order
        color_s = {i: j for i, j in zip(cates, generate_color_list(nums=len(cates)))}
        c = [color_s[i] for i in data_s[cate_name]]
    else:
        c = "blue"
    # create the scatter plot
    plt.scatter(data_s[x_name], data_s[y_name], s=paras["scatter_size"], c=c, alpha=paras["alpha"])
    #
    plt.xlabel(f'{x_name}')
    plt.ylabel(f'{y_name}')
    plt.title(f'{title}')
    # legend: one proxy marker per category
    if cate_name is not None:
        plt.legend(handles=[
            plt.Line2D([0], [0], marker='o', color='w', label=i, markerfacecolor=j,
                       markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
        ])

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        return output_path
|
44
|
+
|
45
|
+
|
46
|
+
if __name__ == '__main__':
    # demo: five points in two categories
    demo_data_s = {
        "x": [1, 2, 3, 4, 5],
        "y": [2, 4, 6, 8, 10],
        "categories": ['A', 'B', 'A', 'B', 'A'],
    }
    plot_scatters(data_s=demo_data_s, title='test', x_name='x', y_name='y', cate_name='categories')
|
@@ -0,0 +1,54 @@
|
|
1
|
+
import os
|
2
|
+
import pandas as pd
|
3
|
+
import seaborn as sns
|
4
|
+
import matplotlib.pyplot as plt
|
5
|
+
from kevin_toolbox.patches.for_matplotlib import generate_color_list
|
6
|
+
from kevin_toolbox.patches.for_os.path import replace_illegal_chars
|
7
|
+
|
8
|
+
|
9
|
+
def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=None, cate_color_s=None, **kwargs):
    """
    Draw a scatter-plot matrix (seaborn pairplot) of the series in x_name_ls, optionally coloured by category.

    Parameters:
        data_s:         <dict> {<name>: <list of value>, ...}, converted with pd.DataFrame(data_s)
        title:          <str> chart title; also used (after replacing illegal characters) as the file name
        x_name_ls:      <list of str> names of the series to plot; also fixes the order of the rows/columns
        cate_name:      <str> key of the per-point categories, used as the pairplot hue.
                            Defaults to None: no colouring by category.
        output_dir:     <path> directory to save the image in; when None the chart is shown with plt.show()
        cate_color_s:   <dict> {<category>: <color>, ...}; its keys must be exactly the categories in the data.
                            Defaults to None, in which case colours are generated for the categories, sorted
                            when they are mutually comparable. Ignored when cate_name is None.
        **kwargs:
            dpi:            <int> resolution of the saved image, defaults to 200
            diag_kind:      kind of the diagonal plots, "hist" or "kde", defaults to "kde"

    Returns:
        path of the saved image, or None when output_dir is None
    """
    paras = {
        "dpi": 200,
        "diag_kind": "kde"  # kind of the diagonal plots: histogram / density {'hist', 'kde'}
    }
    assert len(set(x_name_ls).difference(set(data_s.keys()))) == 0
    if cate_name is not None:
        assert cate_name in data_s
        if cate_color_s is None:
            cates = list(set(data_s[cate_name]))
            try:
                cates = sorted(cates)
            except TypeError:
                pass  # mixed, non-comparable categories: fall back to set order
            cate_color_s = {k: v for k, v in zip(cates, generate_color_list(len(cates)))}
        assert set(cate_color_s.keys()) == set(data_s[cate_name])
    else:
        # a palette is only meaningful together with a hue column
        cate_color_s = None
    paras.update(kwargs)

    # pairplot creates its own figure, so the caller's current figure is not cleared
    sns.pairplot(
        pd.DataFrame(data_s),
        diag_kind=paras["diag_kind"],  # kind of the diagonal plots {'hist', 'kde'}
        hue=cate_name,  # colour the points by the values of this column
        palette=cate_color_s, x_vars=x_name_ls, y_vars=x_name_ls,  # x_vars / y_vars fix the subplot order
    )
    #
    plt.subplots_adjust(top=0.95)
    plt.suptitle(f'{title}', y=0.98, x=0.47)

    if output_dir is None:
        plt.show()
        return None
    else:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
        plt.savefig(output_path, dpi=paras["dpi"])
        # release the figure created by pairplot so repeated calls do not accumulate open figures
        plt.close()
        return output_path
|
42
|
+
|
43
|
+
|
44
|
+
if __name__ == '__main__':
    # demo: three numeric series in two categories, with explicit category colours
    demo_data_s = {
        "x": [1, 2, 3, 4, 5],
        "y": [2, 4, 6, 8, 10],
        "z": [2, 4, 6, 8, 10],
        "categories": ['A', 'B', 'A', 'B', 'A'],
        "title": 'test',
    }
    demo_color_s = {'A': 'red', 'B': 'blue'}
    plot_scatters_matrix(data_s=demo_data_s, title='test', x_name_ls=['y', 'x', 'z'], cate_name='categories',
                         cate_color_s=demo_color_s)
|
@@ -0,0 +1,30 @@
|
|
1
|
+
import os
|
2
|
+
from kevin_toolbox.patches.for_os import walk
|
3
|
+
|
4
|
+
|
5
|
+
def find_files_in_dir(input_dir, suffix_ls, b_relative_path=True, b_ignore_case=True):
    """
    Generator yielding every file under input_dir whose path ends with one of the given suffixes.
    Implemented with the filtering rules of for_os.walk.

    Files for which walk reports b_is_symlink are skipped.

    Parameters:
        input_dir:          <path> directory to search
        suffix_ls:          <list/tuple of str> accepted suffixes; a single str is treated as one suffix
        b_relative_path:    <bool> yield paths relative to input_dir (True) or os.path.join(root, file) as produced
                                by walk (False)
        b_ignore_case:      <bool> whether to ignore case when matching the suffixes
    """
    if isinstance(suffix_ls, str):
        # a bare string would otherwise be split into its individual characters by set()/tuple()
        suffix_ls = [suffix_ls]
    suffix_ls = tuple(set(suffix_ls))
    if b_ignore_case:
        suffix_ls = tuple(x.lower() for x in suffix_ls)

    def _is_ignored(_, b_is_symlink, path):
        # ignore symlinks and every file whose path does not end with an accepted suffix
        return b_is_symlink or not (path.lower() if b_ignore_case else path).endswith(suffix_ls)

    for root, dirs, files in walk(top=input_dir, topdown=True,
                                  ignore_s=[{"func": _is_ignored, "scope": ["files", ]}]):
        for file in files:
            file_path = os.path.join(root, file)
            if b_relative_path:
                file_path = os.path.relpath(file_path, start=input_dir)
            yield file_path
|
28
|
+
|
29
|
+
|
30
|
+
|