kevin-toolbox-dev 1.4.11__py3-none-any.whl → 1.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/computer_science/algorithm/pareto_front/get_pareto_points_idx.py +2 -0
  3. kevin_toolbox/computer_science/algorithm/redirector/redirectable_sequence_fetcher.py +3 -3
  4. kevin_toolbox/computer_science/algorithm/sampler/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/sampler/recent_sampler.py +128 -0
  6. kevin_toolbox/computer_science/algorithm/sampler/reservoir_sampler.py +2 -2
  7. kevin_toolbox/computer_science/algorithm/statistician/__init__.py +2 -0
  8. kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +1 -1
  9. kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +1 -1
  10. kevin_toolbox/computer_science/algorithm/statistician/maximum_accumulator.py +80 -0
  11. kevin_toolbox/computer_science/algorithm/statistician/minimum_accumulator.py +34 -0
  12. kevin_toolbox/data_flow/file/markdown/table/find_tables.py +38 -12
  13. kevin_toolbox/developing/file_management/__init__.py +1 -0
  14. kevin_toolbox/developing/file_management/file_feature_extractor.py +263 -0
  15. kevin_toolbox/nested_dict_list/serializer/read.py +4 -1
  16. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +5 -0
  17. kevin_toolbox/patches/for_matplotlib/common_charts/plot_2d_matrix.py +134 -0
  18. kevin_toolbox/patches/for_matplotlib/common_charts/plot_3d.py +198 -0
  19. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +7 -4
  20. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +11 -4
  21. kevin_toolbox/patches/for_matplotlib/common_charts/plot_contour.py +157 -0
  22. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +19 -8
  23. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +72 -21
  24. kevin_toolbox/patches/for_matplotlib/common_charts/plot_mean_std_lines.py +135 -0
  25. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +9 -3
  26. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +9 -3
  27. kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +1 -0
  28. kevin_toolbox/patches/for_matplotlib/common_charts/utils/log_scaling.py +69 -0
  29. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +19 -3
  30. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_record.py +1 -1
  31. kevin_toolbox/patches/for_numpy/__init__.py +1 -0
  32. kevin_toolbox/patches/for_numpy/linalg/softmax.py +4 -1
  33. kevin_toolbox_dev-1.4.13.dist-info/METADATA +77 -0
  34. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/RECORD +36 -26
  35. kevin_toolbox_dev-1.4.11.dist-info/METADATA +0 -67
  36. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/WHEEL +0 -0
  37. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@ def read(input_path, **kwargs):
12
12
 
13
13
  参数:
14
14
  input_path: <path> 文件夹或者 .tar 文件,具体结构参考 write()
15
+ b_keep_identical_relations: <boolean> 覆盖 record.json 中记录的同名参数,该参数的作用详见 write() 中的介绍。
15
16
  """
16
17
  assert os.path.exists(input_path)
17
18
 
@@ -42,7 +43,7 @@ def _read_unpacked_ndl(input_path, **kwargs):
42
43
 
43
44
  # 读取被处理的节点
44
45
  processed_nodes = []
45
- if record_s:
46
+ if "processed" in record_s:
46
47
  for name, value in ndl.get_nodes(var=record_s["processed"], level=-1, b_strict=True):
47
48
  if value:
48
49
  processed_nodes.append(name)
@@ -68,6 +69,8 @@ def _read_unpacked_ndl(input_path, **kwargs):
68
69
  ndl.set_value(var=var, name=name, value=bk.read(**value))
69
70
 
70
71
  #
72
+ if "b_keep_identical_relations" in kwargs:
73
+ record_s["b_keep_identical_relations"] = kwargs["b_keep_identical_relations"]
71
74
  if record_s.get("b_keep_identical_relations", False):
72
75
  from kevin_toolbox.nested_dict_list import value_parser
73
76
  var = value_parser.replace_identical_with_reference(var=var, flag="same", b_reverse=True)
@@ -48,4 +48,9 @@ from .plot_distribution import plot_distribution
48
48
  from .plot_bars import plot_bars
49
49
  from .plot_scatters_matrix import plot_scatters_matrix
50
50
  from .plot_confusion_matrix import plot_confusion_matrix
51
+ from .plot_2d_matrix import plot_2d_matrix
52
+ from .plot_contour import plot_contour
53
+ from .plot_3d import plot_3d
51
54
  from .plot_from_record import plot_from_record
55
+ # from .plot_raincloud import plot_raincloud
56
+ from .plot_mean_std_lines import plot_mean_std_lines
@@ -0,0 +1,134 @@
1
+ import copy
2
+ import warnings
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
7
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
8
+ from kevin_toolbox.env_info.version import compare
9
+
10
+ __name = ":common_charts:plot_matrix"
11
+
12
+ if compare(v_0=sns.__version__, operator="<", v_1='0.13.0'):
13
+ warnings.warn("seaborn version is too low, it may cause the heat map to not be drawn properly,"
14
+ " please upgrade to 0.13.0 or higher")
15
+
16
+
17
+ @COMMON_CHARTS.register(name=__name)
18
+ def plot_2d_matrix(matrix, title, row_label="row", column_label="column", x_tick_labels=None, y_tick_labels=None,
19
+ output_dir=None, output_path=None, replace_zero_division_with=0, **kwargs):
20
+ """
21
+ 计算并绘制混淆矩阵
22
+
23
+ 参数:
24
+ matrix: <np.ndarray> 矩阵
25
+ row_label: <str> 行标签。
26
+ column_label: <str> 列标签。
27
+ title: <str> 绘图标题,同时用于保存图片的文件名。
28
+ output_dir: <str or None>
29
+ 图像保存的输出目录。如果同时指定了 output_path,则以 output_path 为准。
30
+ 若 output_dir 和 output_path 均未指定,则图像将直接通过 plt.show() 显示而不会保存到文件。
31
+
32
+ output_dir: <str> 图片输出目录。
33
+ output_path: <str> 图片输出路径。
34
+ 以上两个只需指定一个即可,同时指定时以后者为准。
35
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
36
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
37
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
38
+ replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
39
+ 建议使用 np.nan 或 0,默认值为 0。
40
+
41
+ 其他可选参数:
42
+ dpi: <int> 图像保存的分辨率。
43
+ suffix: <str> 图片保存后缀。
44
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
45
+ normalize: <str or None> 指定归一化方式。
46
+ 可选值包括:
47
+ "row"(按行归一化)
48
+ "column"(按列归一化)
49
+ "all"(整体归一化)
50
+ 默认为 None 表示不归一化。
51
+ value_fmt: <str> 矩阵元素数值的显示方式。
52
+ b_return_matrix: <bool> 是否在返回值中包含(当使用 normalize 操作时)修改后的矩阵。
53
+ 默认为 False。
54
+ b_generate_record: <boolean> 是否保存函数参数为档案。
55
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
56
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
57
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
58
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
59
+ 默认为 False
60
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
61
+ 默认为 True
62
+ """
63
+ paras = {
64
+ "dpi": 200,
65
+ "suffix": ".png",
66
+ "b_generate_record": False,
67
+ "b_show_plot": False,
68
+ "b_bgr_image": True,
69
+ "normalize": None, # "true", "pred", "all",
70
+ "b_return_matrix": False, # 是否输出混淆矩阵
71
+
72
+ }
73
+ paras.update(kwargs)
74
+ matrix = np.asarray(matrix)
75
+ paras.setdefault("value_fmt",
76
+ '.2%' if paras["normalize"] is not None or np.issubdtype(matrix.dtype, np.floating) else 'd')
77
+ #
78
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
79
+ save_record(_func=plot_2d_matrix, _name=__name,
80
+ _output_path=_output_path if paras["b_generate_record"] else None,
81
+ **paras)
82
+ matrix = copy.deepcopy(matrix)
83
+
84
+ # replace with nan
85
+ if paras["normalize"] is not None:
86
+ if paras["normalize"] == "all":
87
+ if matrix.sum() == 0:
88
+ matrix[matrix == 0] = replace_zero_division_with
89
+ matrix = matrix / matrix.sum()
90
+ else:
91
+ check_axis = 1 if paras["normalize"] == "row" else 0
92
+ temp = np.sum(matrix, axis=check_axis, keepdims=False)
93
+ for i in range(len(temp)):
94
+ if temp[i] == 0:
95
+ if check_axis == 0:
96
+ matrix[:, i] = replace_zero_division_with
97
+ else:
98
+ matrix[i, :] = replace_zero_division_with
99
+ matrix = matrix / np.sum(matrix, axis=check_axis, keepdims=True)
100
+
101
+ # 绘制混淆矩阵热力图
102
+ plt.clf()
103
+ plt.figure(figsize=(8, 6))
104
+ sns.heatmap(matrix, annot=True, fmt=paras["value_fmt"],
105
+ xticklabels=x_tick_labels if x_tick_labels is not None else "auto",
106
+ yticklabels=y_tick_labels if y_tick_labels is not None else "auto",
107
+ cmap='viridis')
108
+
109
+ plt.xlabel(f'{column_label}')
110
+ plt.ylabel(f'{row_label}')
111
+ plt.title(f'{title}')
112
+
113
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
114
+
115
+ if paras["b_return_matrix"]:
116
+ return _output_path, matrix
117
+ else:
118
+ return _output_path
119
+
120
+
121
+ if __name__ == '__main__':
122
+ import os
123
+
124
+ # 示例真实标签和预测标签
125
+ A = np.random.randint(0, 5, (5, 5))
126
+ print(A)
127
+
128
+ plot_2d_matrix(
129
+ matrix=np.random.randint(0, 5, (5, 5)),
130
+ title="2D Matrix",
131
+ output_dir=os.path.join(os.path.dirname(__file__), "temp"),
132
+ replace_zero_division_with=-1,
133
+ # normalize="row"
134
+ )
@@ -0,0 +1,198 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from mpl_toolkits.mplot3d import Axes3D # 兼容部分旧版 matplotlib
4
+ from scipy.interpolate import griddata
5
+ from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
6
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path, \
7
+ log_scaling
8
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
9
+
10
+ __name = ":common_charts:plot_3d"
11
+
12
+
13
+ @COMMON_CHARTS.register(name=__name)
14
+ def plot_3d(data_s, title, x_name, y_name, z_name, cate_name=None, type_=("scatter", "smooth_surf"), output_dir=None,
15
+ output_path=None, **kwargs):
16
+ """
17
+ 绘制3D图
18
+ 支持:散点图、三角剖分曲面及其平滑版本
19
+
20
+
21
+ 参数:
22
+ data_s: <dict> 数据。
23
+ 形如 {<data_name>: <data list>, ...} 的字典
24
+ 需要包含 x、y、z 三个键值对,分别对应 x、y、z 轴的数据。
25
+ title: <str> 绘图标题。
26
+ x_name: <str> x 轴的数据键名。
27
+ y_name: <str> y 轴的数据键名。
28
+ z_name: <str> z 轴的数据键名。
29
+ cate_name: <str> 以哪个 data_name 作为数据点的类别。
30
+ type_: <str/list of str> 图表类型。
31
+ 目前支持以下取值,或者以下取值的列表:
32
+ - "scatter" 散点图
33
+ - "tri_surf" 三角曲面
34
+ - "smooth_surf" 平滑曲面
35
+ 当指定列表时,将会绘制多个图表的混合。
36
+ output_dir: <str> 图片输出目录。
37
+ output_path: <str> 图片输出路径。
38
+ 以上两个只需指定一个即可,同时指定时以后者为准。
39
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
40
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
41
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
42
+
43
+ 其他可选参数:
44
+ dpi: <int> 保存图像的分辨率。
45
+ 默认为 200。
46
+ suffix: <str> 图片保存后缀。
47
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
48
+ b_generate_record: <boolean> 是否保存函数参数为档案。
49
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
50
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
51
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
52
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
53
+ 默认为 False
54
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
55
+ 默认为 True
56
+ scatter_size: <int> 散点大小,默认 30。
57
+ cate_of_surf: <str or list of str> 使用哪些类别的数据来绘制曲面。
58
+ 默认为 None,表示使用所有类别的数据来绘制曲面。
59
+ 仅当 cate_name 非 None 时该参数起效。
60
+ tri_surf_cmap: <str> 三角剖分曲面的颜色映射,默认 "viridis"。
61
+ tri_surf_alpha: <float> 三角剖分曲面的透明度,默认 0.6。
62
+ smooth_surf_cmap: <str> 平滑曲面的颜色映射,默认 "coolwarm"。
63
+ smooth_surf_alpha: <float> 平滑曲面的透明度,默认 0.6。
64
+ smooth_surf_method: <str> 平滑的方法。
65
+ 支持以下取值:
66
+ - "linear"
67
+ - "cubic"
68
+ view_elev: <float> 视角中的仰角,默认 30。
69
+ view_azim: <float> 视角中的方位角,默认 45。
70
+ x_log_scale,y_log_scale,z_log_scale: <int/float> 对 x,y,z 轴数据使用哪个底数进行对数显示。
71
+ 默认为 None,此时表示不使用对数显示。
72
+ x_ticks,...: <int/list of float or int> 在哪个数字下添加坐标记号。
73
+ 默认为 None,表示不添加记号。
74
+ 当设置为 int 时,表示自动根据 x,y,z 数据的范围,选取等间隔选取多少个坐标作为记号。
75
+ x_tick_labels,...: <int/list> 坐标记号的label。
76
+ """
77
+ # 默认参数设置
78
+ paras = {
79
+ "dpi": 200,
80
+ "suffix": ".png",
81
+ "b_generate_record": False,
82
+ "b_show_plot": False,
83
+ "b_bgr_image": True,
84
+ "scatter_size": 30,
85
+ "cate_of_surf": None,
86
+ "tri_surf_cmap": "viridis",
87
+ "tri_surf_alpha": 0.6,
88
+ "smooth_surf_cmap": "coolwarm",
89
+ "smooth_surf_alpha": 0.6,
90
+ "smooth_surf_method": "linear",
91
+ "view_elev": 30,
92
+ "view_azim": 45,
93
+ "x_log_scale": None,
94
+ "x_ticks": None,
95
+ "x_tick_labels": None,
96
+ "y_log_scale": None,
97
+ "y_ticks": None,
98
+ "y_tick_labels": None,
99
+ "z_log_scale": None,
100
+ "z_ticks": None,
101
+ "z_tick_labels": None,
102
+ }
103
+ paras.update(kwargs)
104
+ #
105
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
106
+ save_record(_func=plot_3d, _name=__name,
107
+ _output_path=_output_path if paras["b_generate_record"] else None,
108
+ **paras)
109
+ data_s = data_s.copy()
110
+ if isinstance(type_, str):
111
+ type_ = [type_]
112
+ #
113
+ d_s = dict()
114
+ ticks_s = dict()
115
+ tick_labels_s = dict()
116
+ for k in ("x", "y", "z"):
117
+ d_s[k], ticks_s[k], tick_labels_s[k] = log_scaling(
118
+ x_ls=data_s[eval(f'{k}_name')], log_scale=paras[f"{k}_log_scale"],
119
+ ticks=paras[f"{k}_ticks"], tick_labels=paras[f"{k}_tick_labels"]
120
+ )
121
+
122
+ x, y, z = [d_s[i].reshape(-1) for i in ("x", "y", "z")]
123
+ color_s = None
124
+ cate_of_surf = None
125
+ if cate_name is not None:
126
+ cates = list(set(data_s[cate_name]))
127
+ color_s = {i: j for i, j in zip(cates, generate_color_list(nums=len(cates)))}
128
+ c = [color_s[i] for i in data_s[cate_name]]
129
+ if paras["cate_of_surf"] is not None:
130
+ temp = [paras["cate_of_surf"], ] if isinstance(paras["cate_of_surf"], str) else paras[
131
+ "cate_of_surf"]
132
+ cate_of_surf = [i in temp for i in data_s[cate_name]]
133
+ else:
134
+ c = "red"
135
+
136
+ plt.clf()
137
+ fig = plt.figure(figsize=(10, 8))
138
+ ax = fig.add_subplot(111, projection='3d')
139
+
140
+ # 绘制数据点
141
+ if "scatter" in type_:
142
+ ax.scatter(x, y, z, s=paras["scatter_size"], c=c, depthshade=True)
143
+
144
+ if cate_of_surf is not None:
145
+ x, y, z = x[cate_of_surf], y[cate_of_surf], z[cate_of_surf]
146
+
147
+ # 绘制基于三角剖分的曲面(不平滑)
148
+ if "tri_surf" in type_:
149
+ tri_surf = ax.plot_trisurf(x, y, z, cmap=paras["tri_surf_cmap"], alpha=paras["tri_surf_alpha"])
150
+
151
+ # 构造规则网格,用于平滑曲面插值
152
+ if "smooth_surf" in type_:
153
+ grid_x, grid_y = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]
154
+ grid_z = griddata((x, y), z, (grid_x, grid_y), method=paras["smooth_surf_method"])
155
+ # 绘制平滑曲面
156
+ smooth_surf = ax.plot_surface(grid_x, grid_y, grid_z, cmap=paras["smooth_surf_cmap"],
157
+ edgecolor='none', alpha=paras["smooth_surf_alpha"])
158
+ # 添加颜色条以展示平滑曲面颜色与 z 值的对应关系
159
+ cbar = fig.colorbar(smooth_surf, ax=ax, shrink=0.5, aspect=10)
160
+ cbar.set_label(z_name, fontsize=12)
161
+
162
+ # 设置坐标轴标签和图形标题
163
+ ax.set_xlabel(x_name, fontsize=12)
164
+ ax.set_ylabel(y_name, fontsize=12)
165
+ ax.set_zlabel(z_name, fontsize=12)
166
+ ax.set_title(title, fontsize=14)
167
+ for i in ("x", "y", "z"):
168
+ if ticks_s[i] is not None:
169
+ getattr(ax, f'set_{i}ticks')(ticks_s[i])
170
+ getattr(ax, f'set_{i}ticklabels')(tick_labels_s[i])
171
+
172
+ # 调整视角
173
+ ax.view_init(elev=paras["view_elev"], azim=paras["view_azim"])
174
+
175
+ # 创建图例
176
+ if "scatter" in type_ and cate_name is not None:
177
+ plt.legend(handles=[
178
+ plt.Line2D([0], [0], marker='o', color='w', label=i, markerfacecolor=j,
179
+ markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
180
+ ])
181
+
182
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
183
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
184
+
185
+
186
+ if __name__ == '__main__':
187
+ # 示例用法:生成示例数据并绘制3D图像
188
+ np.random.seed(42)
189
+ num_points = 200
190
+ data = {
191
+ 'x': np.random.uniform(-5, 5, num_points),
192
+ 'y': np.random.uniform(-5, 5, num_points),
193
+ "c": np.random.uniform(-5, 5, num_points) > 0.3,
194
+ }
195
+ # 示例 z 值:例如 z = sin(sqrt(x^2+y^2))
196
+ data['z'] = np.sin(np.sqrt(data['x'] ** 2 + data['y'] ** 2)) + 1.1
197
+ plot_3d(data, x_name='x', y_name='y', z_name='z', cate_name="c", title="3D Surface Plot", z_log_scale=10, z_ticks=5,
198
+ type_=("scatter"), output_dir="./temp")
@@ -27,7 +27,7 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
27
27
  output_path: <str or None> 图片输出路径。
28
28
  以上两个只需指定一个即可,同时指定时以后者为准。
29
29
  当只有 output_dir 被指定时,将会以 title 作为图片名。
30
- 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
30
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
31
31
  在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
32
32
 
33
33
  其他可选参数:
@@ -39,6 +39,10 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
39
39
  默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
40
40
  后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
41
41
  该参数仅在 output_dir 和 output_path 非 None 时起效。
42
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
43
+ 默认为 False
44
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
45
+ 默认为 True
42
46
 
43
47
  返回值:
44
48
  若 output_dir 非 None,则返回图像保存的文件路径。
@@ -78,9 +82,8 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
78
82
  # 显示图例
79
83
  plt.legend()
80
84
 
81
- save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
82
-
83
- return _output_path
85
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
86
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
84
87
 
85
88
 
86
89
  if __name__ == '__main__':
@@ -30,7 +30,7 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
30
30
  output_path: <str> 图片输出路径。
31
31
  以上两个只需指定一个即可,同时指定时以后者为准。
32
32
  当只有 output_dir 被指定时,将会以 title 作为图片名。
33
- 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
33
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
34
34
  在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
35
35
  replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
36
36
  建议使用 np.nan 或 0,默认值为 0。
@@ -51,6 +51,10 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
51
51
  默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
52
52
  后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
53
53
  该参数仅在 output_dir 和 output_path 非 None 时起效。
54
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
55
+ 默认为 False
56
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
57
+ 默认为 True
54
58
 
55
59
  返回值:
56
60
  当 b_return_cfm 为 True 时,返回值可能为一个包含 (图像路径, 混淆矩阵数据) 的元组。
@@ -59,6 +63,8 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
59
63
  "dpi": 200,
60
64
  "suffix": ".png",
61
65
  "b_generate_record": False,
66
+ "b_show_plot": False,
67
+ "b_bgr_image": True,
62
68
  "normalize": None, # "true", "pred", "all",
63
69
  "b_return_cfm": False, # 是否输出混淆矩阵
64
70
  }
@@ -105,12 +111,13 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
105
111
  plt.ylabel(f'{gt_name}')
106
112
  plt.title(f'{title}')
107
113
 
108
- save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
114
+ res = save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
115
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
109
116
 
110
117
  if paras["b_return_cfm"]:
111
- return _output_path, cfm
118
+ return res, cfm
112
119
  else:
113
- return _output_path
120
+ return res
114
121
 
115
122
 
116
123
  if __name__ == '__main__':
@@ -0,0 +1,157 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path, \
4
+ log_scaling
5
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
6
+
7
+ __name = ":common_charts:plot_contour"
8
+
9
+
10
+ @COMMON_CHARTS.register(name=__name)
11
+ def plot_contour(data_s, title, x_name, y_name, z_name, type_=("contour", "contourf"),
12
+ output_dir=None, output_path=None, **kwargs):
13
+ """
14
+ 绘制等高线图
15
+
16
+
17
+ 参数:
18
+ data_s: <dict> 数据。
19
+ 形如 {<data_name>: <data list>, ...} 的字典
20
+ 需要包含 x、y、z 三个键值对,分别对应 x、y、z 轴的数据,值可以是 2D 的矩阵或者"" 1D 数据。 数组。
21
+ title: <str> 绘图标题。
22
+ x_name: <str> x 轴的数据键名。
23
+ y_name: <str> y 轴的数据键名。
24
+ z_name: <str> z 轴的数据键名。
25
+ type_: <str/list of str> 图表类型。
26
+ 目前支持以下取值,或者以下取值的列表:
27
+ - "contour" 等高线
28
+ - "contourf" 带颜色填充的等高线
29
+ 当指定列表时,将会绘制多个图表的混合。
30
+ output_dir: <str> 图片输出目录。
31
+ output_path: <str> 图片输出路径。
32
+ 以上两个只需指定一个即可,同时指定时以后者为准。
33
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
34
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
35
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
36
+
37
+ 其他可选参数:
38
+ dpi: <int> 保存图像的分辨率。
39
+ 默认为 200。
40
+ suffix: <str> 图片保存后缀。
41
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
42
+ b_generate_record: <boolean> 是否保存函数参数为档案。
43
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
44
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
45
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
46
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
47
+ 默认为 False
48
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
49
+ 默认为 True
50
+ contourf_cmap: <str> 带颜色填充的等高线的颜色映射,默认 "viridis"。
51
+ contourf_alpha: <float> 颜色填充的透明度。
52
+ linestyles: <str> 等高线线形。
53
+ 可选值:
54
+ - 'solid' 实线
55
+ - 'dashed' 虚线(默认)
56
+ - 'dashdot' 点划线
57
+ - 'dotted' 点线
58
+ b_clabel: <boolean> 是否在等高线上显示数值。
59
+ x_log_scale,y_log_scale,z_log_scale: <int/float> 对 x,y,z 轴数据使用哪个底数进行对数显示。
60
+ 默认为 None,此时表示不使用对数显示。
61
+ x_ticks,...: <int/list of float or int> 在哪个数字下添加坐标记号。
62
+ 默认为 None,表示不添加记号。
63
+ 当设置为 int 时,表示自动根据 x,y,z 数据的范围,选取等间隔选取多少个坐标作为记号。
64
+ 特别地,可以通过 z_ticks 来指定等高线的数量和划分位置。
65
+ x_tick_labels,...: <int/list> 坐标记号的label。
66
+ """
67
+ # 默认参数设置
68
+ paras = {
69
+ "dpi": 200,
70
+ "suffix": ".png",
71
+ "b_generate_record": False,
72
+ "b_show_plot": False,
73
+ "b_bgr_image": True,
74
+ "contourf_cmap": "viridis",
75
+ "contourf_alpha": 0.5,
76
+ "linestyles": "dashed",
77
+ "b_clabel": True,
78
+ "x_log_scale": None, "x_ticks": None, "x_tick_labels": None,
79
+ "y_log_scale": None, "y_ticks": None, "y_tick_labels": None,
80
+ "z_log_scale": None, "z_ticks": None, "z_tick_labels": None,
81
+ }
82
+ paras.update(kwargs)
83
+ #
84
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
85
+ save_record(_func=plot_contour, _name=__name,
86
+ _output_path=_output_path if paras["b_generate_record"] else None,
87
+ **paras)
88
+ data_s = data_s.copy()
89
+ if isinstance(type_, str):
90
+ type_ = [type_]
91
+
92
+ d_s = dict()
93
+ ticks_s = dict()
94
+ tick_labels_s = dict()
95
+ for k in ("x", "y", "z"):
96
+ d_s[k], ticks_s[k], tick_labels_s[k] = log_scaling(
97
+ x_ls=data_s[eval(f'{k}_name')], log_scale=paras[f"{k}_log_scale"],
98
+ ticks=paras[f"{k}_ticks"], tick_labels=paras[f"{k}_tick_labels"]
99
+ )
100
+ X, Y, Z = [d_s[i] for i in ("x", "y", "z")]
101
+
102
+ plt.clf()
103
+ fig = plt.figure(figsize=(10, 8))
104
+ ax = fig.add_subplot(111)
105
+
106
+ # 等高线
107
+ if "contour" in type_:
108
+ if X.ndim == 1:
109
+ contour = ax.tricontour(X, Y, Z, colors='black', linestyles=paras["linestyles"], levels=ticks_s["z"])
110
+ elif X.ndim == 2:
111
+ contour = ax.contour(X, Y, Z, colors='black', linestyles=paras["linestyles"], levels=ticks_s["z"])
112
+ else:
113
+ raise ValueError("The dimension of X, Y, Z must be 1 or 2.")
114
+ if paras["b_clabel"]:
115
+ ax.clabel(contour, inline=True, fontsize=10, fmt={k: v for k, v in zip(ticks_s["z"], tick_labels_s["z"])})
116
+
117
+ # 等高线颜色填充
118
+ if "contourf" in type_:
119
+ if X.ndim == 1:
120
+ contourf = ax.tricontourf(X, Y, Z, cmap=paras["contourf_cmap"], alpha=paras["contourf_alpha"],
121
+ levels=ticks_s["z"])
122
+ elif X.ndim == 2:
123
+ contourf = ax.contourf(X, Y, Z, cmap=paras["contourf_cmap"], alpha=paras["contourf_alpha"],
124
+ levels=ticks_s["z"])
125
+ else:
126
+ raise ValueError("The dimension of X, Y, Z must be 1 or 2.")
127
+ # 添加颜色条以展示平滑曲面颜色与 z 值的对应关系
128
+ cbar = fig.colorbar(contourf, ax=ax, shrink=0.5, aspect=10)
129
+ cbar.set_label(z_name, fontsize=12)
130
+
131
+ # 设置坐标轴标签和图形标题
132
+ ax.set_xlabel(x_name, fontsize=12)
133
+ ax.set_ylabel(y_name, fontsize=12)
134
+ ax.set_title(title, fontsize=14)
135
+ for i in ("x", "y",):
136
+ if ticks_s[i] is not None:
137
+ getattr(ax, f'set_{i}ticks')(ticks_s[i])
138
+ getattr(ax, f'set_{i}ticklabels')(tick_labels_s[i])
139
+
140
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
141
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
142
+
143
+
144
+ if __name__ == '__main__':
145
+ # 生成示例数据
146
+ x = np.linspace(1, 7, 100)
147
+ y = np.linspace(-3, 3, 100)
148
+ X, Y = np.meshgrid(x, y)
149
+ # 这里定义一个函数,使得 Z 值落在 0 到 1 之间
150
+ Z = np.exp(-(X ** 2 + Y ** 2))
151
+ data = {
152
+ 'x': X,
153
+ 'y': Y,
154
+ "z": Z,
155
+ }
156
+ plot_contour(data, x_name='x', y_name='y', z_name='z', title="Contour Plot", x_log_scale=None, z_ticks=10,
157
+ type_=("contour", "contourf"), output_dir="./temp")