kevin-toolbox-dev 1.4.11__py3-none-any.whl → 1.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kevin_toolbox/__init__.py +2 -2
- kevin_toolbox/computer_science/algorithm/pareto_front/get_pareto_points_idx.py +2 -0
- kevin_toolbox/computer_science/algorithm/redirector/redirectable_sequence_fetcher.py +3 -3
- kevin_toolbox/computer_science/algorithm/sampler/__init__.py +1 -0
- kevin_toolbox/computer_science/algorithm/sampler/recent_sampler.py +128 -0
- kevin_toolbox/computer_science/algorithm/sampler/reservoir_sampler.py +2 -2
- kevin_toolbox/computer_science/algorithm/statistician/__init__.py +2 -0
- kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +1 -1
- kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +1 -1
- kevin_toolbox/computer_science/algorithm/statistician/maximum_accumulator.py +80 -0
- kevin_toolbox/computer_science/algorithm/statistician/minimum_accumulator.py +34 -0
- kevin_toolbox/data_flow/file/markdown/table/find_tables.py +38 -12
- kevin_toolbox/developing/file_management/__init__.py +1 -0
- kevin_toolbox/developing/file_management/file_feature_extractor.py +263 -0
- kevin_toolbox/nested_dict_list/serializer/read.py +4 -1
- kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +5 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_2d_matrix.py +134 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_3d.py +198 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +7 -4
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +11 -4
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_contour.py +157 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +19 -8
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +72 -21
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_mean_std_lines.py +135 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +9 -3
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +9 -3
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +1 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/log_scaling.py +69 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +19 -3
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_record.py +1 -1
- kevin_toolbox/patches/for_numpy/__init__.py +1 -0
- kevin_toolbox/patches/for_numpy/linalg/softmax.py +4 -1
- kevin_toolbox_dev-1.4.13.dist-info/METADATA +77 -0
- {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/RECORD +36 -26
- kevin_toolbox_dev-1.4.11.dist-info/METADATA +0 -67
- {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/WHEEL +0 -0
- {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@ def read(input_path, **kwargs):
|
|
12
12
|
|
13
13
|
参数:
|
14
14
|
input_path: <path> 文件夹或者 .tar 文件,具体结构参考 write()
|
15
|
+
b_keep_identical_relations: <boolean> 覆盖 record.json 中记录的同名参数,该参数的作用详见 write() 中的介绍。
|
15
16
|
"""
|
16
17
|
assert os.path.exists(input_path)
|
17
18
|
|
@@ -42,7 +43,7 @@ def _read_unpacked_ndl(input_path, **kwargs):
|
|
42
43
|
|
43
44
|
# 读取被处理的节点
|
44
45
|
processed_nodes = []
|
45
|
-
if record_s:
|
46
|
+
if "processed" in record_s:
|
46
47
|
for name, value in ndl.get_nodes(var=record_s["processed"], level=-1, b_strict=True):
|
47
48
|
if value:
|
48
49
|
processed_nodes.append(name)
|
@@ -68,6 +69,8 @@ def _read_unpacked_ndl(input_path, **kwargs):
|
|
68
69
|
ndl.set_value(var=var, name=name, value=bk.read(**value))
|
69
70
|
|
70
71
|
#
|
72
|
+
if "b_keep_identical_relations" in kwargs:
|
73
|
+
record_s["b_keep_identical_relations"] = kwargs["b_keep_identical_relations"]
|
71
74
|
if record_s.get("b_keep_identical_relations", False):
|
72
75
|
from kevin_toolbox.nested_dict_list import value_parser
|
73
76
|
var = value_parser.replace_identical_with_reference(var=var, flag="same", b_reverse=True)
|
@@ -48,4 +48,9 @@ from .plot_distribution import plot_distribution
|
|
48
48
|
from .plot_bars import plot_bars
|
49
49
|
from .plot_scatters_matrix import plot_scatters_matrix
|
50
50
|
from .plot_confusion_matrix import plot_confusion_matrix
|
51
|
+
from .plot_2d_matrix import plot_2d_matrix
|
52
|
+
from .plot_contour import plot_contour
|
53
|
+
from .plot_3d import plot_3d
|
51
54
|
from .plot_from_record import plot_from_record
|
55
|
+
# from .plot_raincloud import plot_raincloud
|
56
|
+
from .plot_mean_std_lines import plot_mean_std_lines
|
@@ -0,0 +1,134 @@
|
|
1
|
+
import copy
|
2
|
+
import warnings
|
3
|
+
import numpy as np
|
4
|
+
import matplotlib.pyplot as plt
|
5
|
+
import seaborn as sns
|
6
|
+
from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
|
7
|
+
from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
|
8
|
+
from kevin_toolbox.env_info.version import compare
|
9
|
+
|
10
|
+
__name = ":common_charts:plot_matrix"
|
11
|
+
|
12
|
+
if compare(v_0=sns.__version__, operator="<", v_1='0.13.0'):
|
13
|
+
warnings.warn("seaborn version is too low, it may cause the heat map to not be drawn properly,"
|
14
|
+
" please upgrade to 0.13.0 or higher")
|
15
|
+
|
16
|
+
|
17
|
+
@COMMON_CHARTS.register(name=__name)
|
18
|
+
def plot_2d_matrix(matrix, title, row_label="row", column_label="column", x_tick_labels=None, y_tick_labels=None,
|
19
|
+
output_dir=None, output_path=None, replace_zero_division_with=0, **kwargs):
|
20
|
+
"""
|
21
|
+
计算并绘制混淆矩阵
|
22
|
+
|
23
|
+
参数:
|
24
|
+
matrix: <np.ndarray> 矩阵
|
25
|
+
row_label: <str> 行标签。
|
26
|
+
column_label: <str> 列标签。
|
27
|
+
title: <str> 绘图标题,同时用于保存图片的文件名。
|
28
|
+
output_dir: <str or None>
|
29
|
+
图像保存的输出目录。如果同时指定了 output_path,则以 output_path 为准。
|
30
|
+
若 output_dir 和 output_path 均未指定,则图像将直接通过 plt.show() 显示而不会保存到文件。
|
31
|
+
|
32
|
+
output_dir: <str> 图片输出目录。
|
33
|
+
output_path: <str> 图片输出路径。
|
34
|
+
以上两个只需指定一个即可,同时指定时以后者为准。
|
35
|
+
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
36
|
+
若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
|
37
|
+
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
38
|
+
replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
|
39
|
+
建议使用 np.nan 或 0,默认值为 0。
|
40
|
+
|
41
|
+
其他可选参数:
|
42
|
+
dpi: <int> 图像保存的分辨率。
|
43
|
+
suffix: <str> 图片保存后缀。
|
44
|
+
目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
|
45
|
+
normalize: <str or None> 指定归一化方式。
|
46
|
+
可选值包括:
|
47
|
+
"row"(按行归一化)
|
48
|
+
"column"(按列归一化)
|
49
|
+
"all"(整体归一化)
|
50
|
+
默认为 None 表示不归一化。
|
51
|
+
value_fmt: <str> 矩阵元素数值的显示方式。
|
52
|
+
b_return_matrix: <bool> 是否在返回值中包含(当使用 normalize 操作时)修改后的矩阵。
|
53
|
+
默认为 False。
|
54
|
+
b_generate_record: <boolean> 是否保存函数参数为档案。
|
55
|
+
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
56
|
+
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
57
|
+
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
58
|
+
b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
|
59
|
+
默认为 False
|
60
|
+
b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
|
61
|
+
默认为 True
|
62
|
+
"""
|
63
|
+
paras = {
|
64
|
+
"dpi": 200,
|
65
|
+
"suffix": ".png",
|
66
|
+
"b_generate_record": False,
|
67
|
+
"b_show_plot": False,
|
68
|
+
"b_bgr_image": True,
|
69
|
+
"normalize": None, # "true", "pred", "all",
|
70
|
+
"b_return_matrix": False, # 是否输出混淆矩阵
|
71
|
+
|
72
|
+
}
|
73
|
+
paras.update(kwargs)
|
74
|
+
matrix = np.asarray(matrix)
|
75
|
+
paras.setdefault("value_fmt",
|
76
|
+
'.2%' if paras["normalize"] is not None or np.issubdtype(matrix.dtype, np.floating) else 'd')
|
77
|
+
#
|
78
|
+
_output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
|
79
|
+
save_record(_func=plot_2d_matrix, _name=__name,
|
80
|
+
_output_path=_output_path if paras["b_generate_record"] else None,
|
81
|
+
**paras)
|
82
|
+
matrix = copy.deepcopy(matrix)
|
83
|
+
|
84
|
+
# replace with nan
|
85
|
+
if paras["normalize"] is not None:
|
86
|
+
if paras["normalize"] == "all":
|
87
|
+
if matrix.sum() == 0:
|
88
|
+
matrix[matrix == 0] = replace_zero_division_with
|
89
|
+
matrix = matrix / matrix.sum()
|
90
|
+
else:
|
91
|
+
check_axis = 1 if paras["normalize"] == "row" else 0
|
92
|
+
temp = np.sum(matrix, axis=check_axis, keepdims=False)
|
93
|
+
for i in range(len(temp)):
|
94
|
+
if temp[i] == 0:
|
95
|
+
if check_axis == 0:
|
96
|
+
matrix[:, i] = replace_zero_division_with
|
97
|
+
else:
|
98
|
+
matrix[i, :] = replace_zero_division_with
|
99
|
+
matrix = matrix / np.sum(matrix, axis=check_axis, keepdims=True)
|
100
|
+
|
101
|
+
# 绘制混淆矩阵热力图
|
102
|
+
plt.clf()
|
103
|
+
plt.figure(figsize=(8, 6))
|
104
|
+
sns.heatmap(matrix, annot=True, fmt=paras["value_fmt"],
|
105
|
+
xticklabels=x_tick_labels if x_tick_labels is not None else "auto",
|
106
|
+
yticklabels=y_tick_labels if y_tick_labels is not None else "auto",
|
107
|
+
cmap='viridis')
|
108
|
+
|
109
|
+
plt.xlabel(f'{column_label}')
|
110
|
+
plt.ylabel(f'{row_label}')
|
111
|
+
plt.title(f'{title}')
|
112
|
+
|
113
|
+
save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
|
114
|
+
|
115
|
+
if paras["b_return_matrix"]:
|
116
|
+
return _output_path, matrix
|
117
|
+
else:
|
118
|
+
return _output_path
|
119
|
+
|
120
|
+
|
121
|
+
if __name__ == '__main__':
|
122
|
+
import os
|
123
|
+
|
124
|
+
# 示例真实标签和预测标签
|
125
|
+
A = np.random.randint(0, 5, (5, 5))
|
126
|
+
print(A)
|
127
|
+
|
128
|
+
plot_2d_matrix(
|
129
|
+
matrix=np.random.randint(0, 5, (5, 5)),
|
130
|
+
title="2D Matrix",
|
131
|
+
output_dir=os.path.join(os.path.dirname(__file__), "temp"),
|
132
|
+
replace_zero_division_with=-1,
|
133
|
+
# normalize="row"
|
134
|
+
)
|
@@ -0,0 +1,198 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import matplotlib.pyplot as plt
|
3
|
+
from mpl_toolkits.mplot3d import Axes3D # 兼容部分旧版 matplotlib
|
4
|
+
from scipy.interpolate import griddata
|
5
|
+
from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
|
6
|
+
from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path, \
|
7
|
+
log_scaling
|
8
|
+
from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
|
9
|
+
|
10
|
+
__name = ":common_charts:plot_3d"
|
11
|
+
|
12
|
+
|
13
|
+
@COMMON_CHARTS.register(name=__name)
|
14
|
+
def plot_3d(data_s, title, x_name, y_name, z_name, cate_name=None, type_=("scatter", "smooth_surf"), output_dir=None,
|
15
|
+
output_path=None, **kwargs):
|
16
|
+
"""
|
17
|
+
绘制3D图
|
18
|
+
支持:散点图、三角剖分曲面及其平滑版本
|
19
|
+
|
20
|
+
|
21
|
+
参数:
|
22
|
+
data_s: <dict> 数据。
|
23
|
+
形如 {<data_name>: <data list>, ...} 的字典
|
24
|
+
需要包含 x、y、z 三个键值对,分别对应 x、y、z 轴的数据。
|
25
|
+
title: <str> 绘图标题。
|
26
|
+
x_name: <str> x 轴的数据键名。
|
27
|
+
y_name: <str> y 轴的数据键名。
|
28
|
+
z_name: <str> z 轴的数据键名。
|
29
|
+
cate_name: <str> 以哪个 data_name 作为数据点的类别。
|
30
|
+
type_: <str/list of str> 图表类型。
|
31
|
+
目前支持以下取值,或者以下取值的列表:
|
32
|
+
- "scatter" 散点图
|
33
|
+
- "tri_surf" 三角曲面
|
34
|
+
- "smooth_surf" 平滑曲面
|
35
|
+
当指定列表时,将会绘制多个图表的混合。
|
36
|
+
output_dir: <str> 图片输出目录。
|
37
|
+
output_path: <str> 图片输出路径。
|
38
|
+
以上两个只需指定一个即可,同时指定时以后者为准。
|
39
|
+
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
40
|
+
若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
|
41
|
+
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
42
|
+
|
43
|
+
其他可选参数:
|
44
|
+
dpi: <int> 保存图像的分辨率。
|
45
|
+
默认为 200。
|
46
|
+
suffix: <str> 图片保存后缀。
|
47
|
+
目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
|
48
|
+
b_generate_record: <boolean> 是否保存函数参数为档案。
|
49
|
+
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
50
|
+
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
51
|
+
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
52
|
+
b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
|
53
|
+
默认为 False
|
54
|
+
b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
|
55
|
+
默认为 True
|
56
|
+
scatter_size: <int> 散点大小,默认 30。
|
57
|
+
cate_of_surf: <str or list of str> 使用哪些类别的数据来绘制曲面。
|
58
|
+
默认为 None,表示使用所有类别的数据来绘制曲面。
|
59
|
+
仅当 cate_name 非 None 时该参数起效。
|
60
|
+
tri_surf_cmap: <str> 三角剖分曲面的颜色映射,默认 "viridis"。
|
61
|
+
tri_surf_alpha: <float> 三角剖分曲面的透明度,默认 0.6。
|
62
|
+
smooth_surf_cmap: <str> 平滑曲面的颜色映射,默认 "coolwarm"。
|
63
|
+
smooth_surf_alpha: <float> 平滑曲面的透明度,默认 0.6。
|
64
|
+
smooth_surf_method: <str> 平滑的方法。
|
65
|
+
支持以下取值:
|
66
|
+
- "linear"
|
67
|
+
- "cubic"
|
68
|
+
view_elev: <float> 视角中的仰角,默认 30。
|
69
|
+
view_azim: <float> 视角中的方位角,默认 45。
|
70
|
+
x_log_scale,y_log_scale,z_log_scale: <int/float> 对 x,y,z 轴数据使用哪个底数进行对数显示。
|
71
|
+
默认为 None,此时表示不使用对数显示。
|
72
|
+
x_ticks,...: <int/list of float or int> 在哪个数字下添加坐标记号。
|
73
|
+
默认为 None,表示不添加记号。
|
74
|
+
当设置为 int 时,表示自动根据 x,y,z 数据的范围,选取等间隔选取多少个坐标作为记号。
|
75
|
+
x_tick_labels,...: <int/list> 坐标记号的label。
|
76
|
+
"""
|
77
|
+
# 默认参数设置
|
78
|
+
paras = {
|
79
|
+
"dpi": 200,
|
80
|
+
"suffix": ".png",
|
81
|
+
"b_generate_record": False,
|
82
|
+
"b_show_plot": False,
|
83
|
+
"b_bgr_image": True,
|
84
|
+
"scatter_size": 30,
|
85
|
+
"cate_of_surf": None,
|
86
|
+
"tri_surf_cmap": "viridis",
|
87
|
+
"tri_surf_alpha": 0.6,
|
88
|
+
"smooth_surf_cmap": "coolwarm",
|
89
|
+
"smooth_surf_alpha": 0.6,
|
90
|
+
"smooth_surf_method": "linear",
|
91
|
+
"view_elev": 30,
|
92
|
+
"view_azim": 45,
|
93
|
+
"x_log_scale": None,
|
94
|
+
"x_ticks": None,
|
95
|
+
"x_tick_labels": None,
|
96
|
+
"y_log_scale": None,
|
97
|
+
"y_ticks": None,
|
98
|
+
"y_tick_labels": None,
|
99
|
+
"z_log_scale": None,
|
100
|
+
"z_ticks": None,
|
101
|
+
"z_tick_labels": None,
|
102
|
+
}
|
103
|
+
paras.update(kwargs)
|
104
|
+
#
|
105
|
+
_output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
|
106
|
+
save_record(_func=plot_3d, _name=__name,
|
107
|
+
_output_path=_output_path if paras["b_generate_record"] else None,
|
108
|
+
**paras)
|
109
|
+
data_s = data_s.copy()
|
110
|
+
if isinstance(type_, str):
|
111
|
+
type_ = [type_]
|
112
|
+
#
|
113
|
+
d_s = dict()
|
114
|
+
ticks_s = dict()
|
115
|
+
tick_labels_s = dict()
|
116
|
+
for k in ("x", "y", "z"):
|
117
|
+
d_s[k], ticks_s[k], tick_labels_s[k] = log_scaling(
|
118
|
+
x_ls=data_s[eval(f'{k}_name')], log_scale=paras[f"{k}_log_scale"],
|
119
|
+
ticks=paras[f"{k}_ticks"], tick_labels=paras[f"{k}_tick_labels"]
|
120
|
+
)
|
121
|
+
|
122
|
+
x, y, z = [d_s[i].reshape(-1) for i in ("x", "y", "z")]
|
123
|
+
color_s = None
|
124
|
+
cate_of_surf = None
|
125
|
+
if cate_name is not None:
|
126
|
+
cates = list(set(data_s[cate_name]))
|
127
|
+
color_s = {i: j for i, j in zip(cates, generate_color_list(nums=len(cates)))}
|
128
|
+
c = [color_s[i] for i in data_s[cate_name]]
|
129
|
+
if paras["cate_of_surf"] is not None:
|
130
|
+
temp = [paras["cate_of_surf"], ] if isinstance(paras["cate_of_surf"], str) else paras[
|
131
|
+
"cate_of_surf"]
|
132
|
+
cate_of_surf = [i in temp for i in data_s[cate_name]]
|
133
|
+
else:
|
134
|
+
c = "red"
|
135
|
+
|
136
|
+
plt.clf()
|
137
|
+
fig = plt.figure(figsize=(10, 8))
|
138
|
+
ax = fig.add_subplot(111, projection='3d')
|
139
|
+
|
140
|
+
# 绘制数据点
|
141
|
+
if "scatter" in type_:
|
142
|
+
ax.scatter(x, y, z, s=paras["scatter_size"], c=c, depthshade=True)
|
143
|
+
|
144
|
+
if cate_of_surf is not None:
|
145
|
+
x, y, z = x[cate_of_surf], y[cate_of_surf], z[cate_of_surf]
|
146
|
+
|
147
|
+
# 绘制基于三角剖分的曲面(不平滑)
|
148
|
+
if "tri_surf" in type_:
|
149
|
+
tri_surf = ax.plot_trisurf(x, y, z, cmap=paras["tri_surf_cmap"], alpha=paras["tri_surf_alpha"])
|
150
|
+
|
151
|
+
# 构造规则网格,用于平滑曲面插值
|
152
|
+
if "smooth_surf" in type_:
|
153
|
+
grid_x, grid_y = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]
|
154
|
+
grid_z = griddata((x, y), z, (grid_x, grid_y), method=paras["smooth_surf_method"])
|
155
|
+
# 绘制平滑曲面
|
156
|
+
smooth_surf = ax.plot_surface(grid_x, grid_y, grid_z, cmap=paras["smooth_surf_cmap"],
|
157
|
+
edgecolor='none', alpha=paras["smooth_surf_alpha"])
|
158
|
+
# 添加颜色条以展示平滑曲面颜色与 z 值的对应关系
|
159
|
+
cbar = fig.colorbar(smooth_surf, ax=ax, shrink=0.5, aspect=10)
|
160
|
+
cbar.set_label(z_name, fontsize=12)
|
161
|
+
|
162
|
+
# 设置坐标轴标签和图形标题
|
163
|
+
ax.set_xlabel(x_name, fontsize=12)
|
164
|
+
ax.set_ylabel(y_name, fontsize=12)
|
165
|
+
ax.set_zlabel(z_name, fontsize=12)
|
166
|
+
ax.set_title(title, fontsize=14)
|
167
|
+
for i in ("x", "y", "z"):
|
168
|
+
if ticks_s[i] is not None:
|
169
|
+
getattr(ax, f'set_{i}ticks')(ticks_s[i])
|
170
|
+
getattr(ax, f'set_{i}ticklabels')(tick_labels_s[i])
|
171
|
+
|
172
|
+
# 调整视角
|
173
|
+
ax.view_init(elev=paras["view_elev"], azim=paras["view_azim"])
|
174
|
+
|
175
|
+
# 创建图例
|
176
|
+
if "scatter" in type_ and cate_name is not None:
|
177
|
+
plt.legend(handles=[
|
178
|
+
plt.Line2D([0], [0], marker='o', color='w', label=i, markerfacecolor=j,
|
179
|
+
markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
|
180
|
+
])
|
181
|
+
|
182
|
+
return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
|
183
|
+
b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
|
184
|
+
|
185
|
+
|
186
|
+
if __name__ == '__main__':
|
187
|
+
# 示例用法:生成示例数据并绘制3D图像
|
188
|
+
np.random.seed(42)
|
189
|
+
num_points = 200
|
190
|
+
data = {
|
191
|
+
'x': np.random.uniform(-5, 5, num_points),
|
192
|
+
'y': np.random.uniform(-5, 5, num_points),
|
193
|
+
"c": np.random.uniform(-5, 5, num_points) > 0.3,
|
194
|
+
}
|
195
|
+
# 示例 z 值:例如 z = sin(sqrt(x^2+y^2))
|
196
|
+
data['z'] = np.sin(np.sqrt(data['x'] ** 2 + data['y'] ** 2)) + 1.1
|
197
|
+
plot_3d(data, x_name='x', y_name='y', z_name='z', cate_name="c", title="3D Surface Plot", z_log_scale=10, z_ticks=5,
|
198
|
+
type_=("scatter"), output_dir="./temp")
|
@@ -27,7 +27,7 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
|
|
27
27
|
output_path: <str or None> 图片输出路径。
|
28
28
|
以上两个只需指定一个即可,同时指定时以后者为准。
|
29
29
|
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
30
|
-
|
30
|
+
若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
|
31
31
|
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
32
32
|
|
33
33
|
其他可选参数:
|
@@ -39,6 +39,10 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
|
|
39
39
|
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
40
40
|
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
41
41
|
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
42
|
+
b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
|
43
|
+
默认为 False
|
44
|
+
b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
|
45
|
+
默认为 True
|
42
46
|
|
43
47
|
返回值:
|
44
48
|
若 output_dir 非 None,则返回图像保存的文件路径。
|
@@ -78,9 +82,8 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
|
|
78
82
|
# 显示图例
|
79
83
|
plt.legend()
|
80
84
|
|
81
|
-
save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"]
|
82
|
-
|
83
|
-
return _output_path
|
85
|
+
return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
|
86
|
+
b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
|
84
87
|
|
85
88
|
|
86
89
|
if __name__ == '__main__':
|
@@ -30,7 +30,7 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
|
|
30
30
|
output_path: <str> 图片输出路径。
|
31
31
|
以上两个只需指定一个即可,同时指定时以后者为准。
|
32
32
|
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
33
|
-
|
33
|
+
若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
|
34
34
|
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
35
35
|
replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
|
36
36
|
建议使用 np.nan 或 0,默认值为 0。
|
@@ -51,6 +51,10 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
|
|
51
51
|
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
52
52
|
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
53
53
|
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
54
|
+
b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
|
55
|
+
默认为 False
|
56
|
+
b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
|
57
|
+
默认为 True
|
54
58
|
|
55
59
|
返回值:
|
56
60
|
当 b_return_cfm 为 True 时,返回值可能为一个包含 (图像路径, 混淆矩阵数据) 的元组。
|
@@ -59,6 +63,8 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
|
|
59
63
|
"dpi": 200,
|
60
64
|
"suffix": ".png",
|
61
65
|
"b_generate_record": False,
|
66
|
+
"b_show_plot": False,
|
67
|
+
"b_bgr_image": True,
|
62
68
|
"normalize": None, # "true", "pred", "all",
|
63
69
|
"b_return_cfm": False, # 是否输出混淆矩阵
|
64
70
|
}
|
@@ -105,12 +111,13 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
|
|
105
111
|
plt.ylabel(f'{gt_name}')
|
106
112
|
plt.title(f'{title}')
|
107
113
|
|
108
|
-
save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"]
|
114
|
+
res = save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
|
115
|
+
b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
|
109
116
|
|
110
117
|
if paras["b_return_cfm"]:
|
111
|
-
return
|
118
|
+
return res, cfm
|
112
119
|
else:
|
113
|
-
return
|
120
|
+
return res
|
114
121
|
|
115
122
|
|
116
123
|
if __name__ == '__main__':
|
@@ -0,0 +1,157 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import matplotlib.pyplot as plt
|
3
|
+
from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path, \
|
4
|
+
log_scaling
|
5
|
+
from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
|
6
|
+
|
7
|
+
__name = ":common_charts:plot_contour"
|
8
|
+
|
9
|
+
|
10
|
+
@COMMON_CHARTS.register(name=__name)
|
11
|
+
def plot_contour(data_s, title, x_name, y_name, z_name, type_=("contour", "contourf"),
|
12
|
+
output_dir=None, output_path=None, **kwargs):
|
13
|
+
"""
|
14
|
+
绘制等高线图
|
15
|
+
|
16
|
+
|
17
|
+
参数:
|
18
|
+
data_s: <dict> 数据。
|
19
|
+
形如 {<data_name>: <data list>, ...} 的字典
|
20
|
+
需要包含 x、y、z 三个键值对,分别对应 x、y、z 轴的数据,值可以是 2D 的矩阵或者"" 1D 数据。 数组。
|
21
|
+
title: <str> 绘图标题。
|
22
|
+
x_name: <str> x 轴的数据键名。
|
23
|
+
y_name: <str> y 轴的数据键名。
|
24
|
+
z_name: <str> z 轴的数据键名。
|
25
|
+
type_: <str/list of str> 图表类型。
|
26
|
+
目前支持以下取值,或者以下取值的列表:
|
27
|
+
- "contour" 等高线
|
28
|
+
- "contourf" 带颜色填充的等高线
|
29
|
+
当指定列表时,将会绘制多个图表的混合。
|
30
|
+
output_dir: <str> 图片输出目录。
|
31
|
+
output_path: <str> 图片输出路径。
|
32
|
+
以上两个只需指定一个即可,同时指定时以后者为准。
|
33
|
+
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
34
|
+
若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
|
35
|
+
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
36
|
+
|
37
|
+
其他可选参数:
|
38
|
+
dpi: <int> 保存图像的分辨率。
|
39
|
+
默认为 200。
|
40
|
+
suffix: <str> 图片保存后缀。
|
41
|
+
目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
|
42
|
+
b_generate_record: <boolean> 是否保存函数参数为档案。
|
43
|
+
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
44
|
+
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
45
|
+
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
46
|
+
b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
|
47
|
+
默认为 False
|
48
|
+
b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
|
49
|
+
默认为 True
|
50
|
+
contourf_cmap: <str> 带颜色填充的等高线的颜色映射,默认 "viridis"。
|
51
|
+
contourf_alpha: <float> 颜色填充的透明度。
|
52
|
+
linestyles: <str> 等高线线形。
|
53
|
+
可选值:
|
54
|
+
- 'solid' 实线
|
55
|
+
- 'dashed' 虚线(默认)
|
56
|
+
- 'dashdot' 点划线
|
57
|
+
- 'dotted' 点线
|
58
|
+
b_clabel: <boolean> 是否在等高线上显示数值。
|
59
|
+
x_log_scale,y_log_scale,z_log_scale: <int/float> 对 x,y,z 轴数据使用哪个底数进行对数显示。
|
60
|
+
默认为 None,此时表示不使用对数显示。
|
61
|
+
x_ticks,...: <int/list of float or int> 在哪个数字下添加坐标记号。
|
62
|
+
默认为 None,表示不添加记号。
|
63
|
+
当设置为 int 时,表示自动根据 x,y,z 数据的范围,选取等间隔选取多少个坐标作为记号。
|
64
|
+
特别地,可以通过 z_ticks 来指定等高线的数量和划分位置。
|
65
|
+
x_tick_labels,...: <int/list> 坐标记号的label。
|
66
|
+
"""
|
67
|
+
# 默认参数设置
|
68
|
+
paras = {
|
69
|
+
"dpi": 200,
|
70
|
+
"suffix": ".png",
|
71
|
+
"b_generate_record": False,
|
72
|
+
"b_show_plot": False,
|
73
|
+
"b_bgr_image": True,
|
74
|
+
"contourf_cmap": "viridis",
|
75
|
+
"contourf_alpha": 0.5,
|
76
|
+
"linestyles": "dashed",
|
77
|
+
"b_clabel": True,
|
78
|
+
"x_log_scale": None, "x_ticks": None, "x_tick_labels": None,
|
79
|
+
"y_log_scale": None, "y_ticks": None, "y_tick_labels": None,
|
80
|
+
"z_log_scale": None, "z_ticks": None, "z_tick_labels": None,
|
81
|
+
}
|
82
|
+
paras.update(kwargs)
|
83
|
+
#
|
84
|
+
_output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
|
85
|
+
save_record(_func=plot_contour, _name=__name,
|
86
|
+
_output_path=_output_path if paras["b_generate_record"] else None,
|
87
|
+
**paras)
|
88
|
+
data_s = data_s.copy()
|
89
|
+
if isinstance(type_, str):
|
90
|
+
type_ = [type_]
|
91
|
+
|
92
|
+
d_s = dict()
|
93
|
+
ticks_s = dict()
|
94
|
+
tick_labels_s = dict()
|
95
|
+
for k in ("x", "y", "z"):
|
96
|
+
d_s[k], ticks_s[k], tick_labels_s[k] = log_scaling(
|
97
|
+
x_ls=data_s[eval(f'{k}_name')], log_scale=paras[f"{k}_log_scale"],
|
98
|
+
ticks=paras[f"{k}_ticks"], tick_labels=paras[f"{k}_tick_labels"]
|
99
|
+
)
|
100
|
+
X, Y, Z = [d_s[i] for i in ("x", "y", "z")]
|
101
|
+
|
102
|
+
plt.clf()
|
103
|
+
fig = plt.figure(figsize=(10, 8))
|
104
|
+
ax = fig.add_subplot(111)
|
105
|
+
|
106
|
+
# 等高线
|
107
|
+
if "contour" in type_:
|
108
|
+
if X.ndim == 1:
|
109
|
+
contour = ax.tricontour(X, Y, Z, colors='black', linestyles=paras["linestyles"], levels=ticks_s["z"])
|
110
|
+
elif X.ndim == 2:
|
111
|
+
contour = ax.contour(X, Y, Z, colors='black', linestyles=paras["linestyles"], levels=ticks_s["z"])
|
112
|
+
else:
|
113
|
+
raise ValueError("The dimension of X, Y, Z must be 1 or 2.")
|
114
|
+
if paras["b_clabel"]:
|
115
|
+
ax.clabel(contour, inline=True, fontsize=10, fmt={k: v for k, v in zip(ticks_s["z"], tick_labels_s["z"])})
|
116
|
+
|
117
|
+
# 等高线颜色填充
|
118
|
+
if "contourf" in type_:
|
119
|
+
if X.ndim == 1:
|
120
|
+
contourf = ax.tricontourf(X, Y, Z, cmap=paras["contourf_cmap"], alpha=paras["contourf_alpha"],
|
121
|
+
levels=ticks_s["z"])
|
122
|
+
elif X.ndim == 2:
|
123
|
+
contourf = ax.contourf(X, Y, Z, cmap=paras["contourf_cmap"], alpha=paras["contourf_alpha"],
|
124
|
+
levels=ticks_s["z"])
|
125
|
+
else:
|
126
|
+
raise ValueError("The dimension of X, Y, Z must be 1 or 2.")
|
127
|
+
# 添加颜色条以展示平滑曲面颜色与 z 值的对应关系
|
128
|
+
cbar = fig.colorbar(contourf, ax=ax, shrink=0.5, aspect=10)
|
129
|
+
cbar.set_label(z_name, fontsize=12)
|
130
|
+
|
131
|
+
# 设置坐标轴标签和图形标题
|
132
|
+
ax.set_xlabel(x_name, fontsize=12)
|
133
|
+
ax.set_ylabel(y_name, fontsize=12)
|
134
|
+
ax.set_title(title, fontsize=14)
|
135
|
+
for i in ("x", "y",):
|
136
|
+
if ticks_s[i] is not None:
|
137
|
+
getattr(ax, f'set_{i}ticks')(ticks_s[i])
|
138
|
+
getattr(ax, f'set_{i}ticklabels')(tick_labels_s[i])
|
139
|
+
|
140
|
+
return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
|
141
|
+
b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
|
142
|
+
|
143
|
+
|
144
|
+
if __name__ == '__main__':
|
145
|
+
# 生成示例数据
|
146
|
+
x = np.linspace(1, 7, 100)
|
147
|
+
y = np.linspace(-3, 3, 100)
|
148
|
+
X, Y = np.meshgrid(x, y)
|
149
|
+
# 这里定义一个函数,使得 Z 值落在 0 到 1 之间
|
150
|
+
Z = np.exp(-(X ** 2 + Y ** 2))
|
151
|
+
data = {
|
152
|
+
'x': X,
|
153
|
+
'y': Y,
|
154
|
+
"z": Z,
|
155
|
+
}
|
156
|
+
plot_contour(data, x_name='x', y_name='y', z_name='z', title="Contour Plot", x_log_scale=None, z_ticks=10,
|
157
|
+
type_=("contour", "contourf"), output_dir="./temp")
|