kevin-toolbox-dev 1.4.11__py3-none-any.whl → 1.4.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/computer_science/algorithm/statistician/__init__.py +2 -0
  3. kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +1 -1
  4. kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +1 -1
  5. kevin_toolbox/computer_science/algorithm/statistician/maximum_accumulator.py +80 -0
  6. kevin_toolbox/computer_science/algorithm/statistician/minimum_accumulator.py +34 -0
  7. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +3 -0
  8. kevin_toolbox/patches/for_matplotlib/common_charts/plot_2d_matrix.py +128 -0
  9. kevin_toolbox/patches/for_matplotlib/common_charts/plot_3d.py +198 -0
  10. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +7 -4
  11. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +11 -4
  12. kevin_toolbox/patches/for_matplotlib/common_charts/plot_contour.py +157 -0
  13. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +19 -8
  14. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +65 -21
  15. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +9 -3
  16. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +9 -3
  17. kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +1 -0
  18. kevin_toolbox/patches/for_matplotlib/common_charts/utils/log_scaling.py +62 -0
  19. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +19 -3
  20. kevin_toolbox/patches/for_numpy/__init__.py +1 -0
  21. kevin_toolbox/patches/for_numpy/linalg/softmax.py +4 -1
  22. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.12.dist-info}/METADATA +9 -12
  23. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.12.dist-info}/RECORD +25 -19
  24. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.12.dist-info}/WHEEL +0 -0
  25. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.12.dist-info}/top_level.txt +0 -0
kevin_toolbox/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "1.4.11"
1
+ __version__ = "1.4.12"
2
2
 
3
3
 
4
4
  import os
@@ -12,5 +12,5 @@ os.system(
12
12
  os.system(
13
13
  f'python {os.path.split(__file__)[0]}/env_info/check_validity_and_uninstall.py '
14
14
  f'--package_name kevin-toolbox-dev '
15
- f'--expiration_timestamp 1760340881 --verbose 0'
15
+ f'--expiration_timestamp 1763992975 --verbose 0'
16
16
  )
@@ -2,3 +2,5 @@ from .accumulator_base import Accumulator_Base
2
2
  from .exponential_moving_average import Exponential_Moving_Average
3
3
  from .average_accumulator import Average_Accumulator
4
4
  from .accumulator_for_ndl import Accumulator_for_Ndl
5
+ from .maximum_accumulator import Maximum_Accumulator
6
+ from .minimum_accumulator import Minimum_Accumulator
@@ -27,7 +27,7 @@ class Average_Accumulator(Accumulator_Base):
27
27
  以上三种方式,默认选用最后一种。
28
28
  如果三种方式同时被指定,则优先级与对应方式在上面的排名相同。
29
29
  """
30
- super(Average_Accumulator, self).__init__(**kwargs)
30
+ super().__init__(**kwargs)
31
31
 
32
32
  def add_sequence(self, var_ls, **kwargs):
33
33
  for var in var_ls:
@@ -56,7 +56,7 @@ class Exponential_Moving_Average(Accumulator_Base):
56
56
  # 校验参数
57
57
  assert isinstance(paras["keep_ratio"], (int, float,)) and 0 <= paras["keep_ratio"] <= 1
58
58
  #
59
- super(Exponential_Moving_Average, self).__init__(**paras)
59
+ super().__init__(**paras)
60
60
 
61
61
  def add_sequence(self, var_ls, **kwargs):
62
62
  for var in var_ls:
@@ -0,0 +1,80 @@
1
+ import numpy as np
2
+ import torch
3
+ from kevin_toolbox.computer_science.algorithm.statistician import Accumulator_Base
4
+
5
+
6
+ class Maximum_Accumulator(Accumulator_Base):
7
+ """
8
+ 用于计算最大值的累积器
9
+ """
10
+
11
+ def __init__(self, **kwargs):
12
+ """
13
+ 参数:
14
+ data_format: 指定数据格式
15
+ like: 指定数据格式
16
+ 指定输入数据的格式,有三种方式:
17
+ 1. 显式指定数据的形状和所在设备等。
18
+ data_format: <dict of paras>
19
+ 其中需要包含以下参数:
20
+ type_: <str>
21
+ "numpy": np.ndarray
22
+ "torch": torch.tensor
23
+ shape: <list of integers>
24
+ device: <torch.device>
25
+ dtype: <torch.dtype>
26
+ 2. 根据输入的数据,来推断出形状、设备等。
27
+ like: <torch.tensor / np.ndarray / int / float>
28
+ 3. 均不指定 data_format 和 like,此时将等到第一次调用 add()/add_sequence() 时再根据输入来自动推断。
29
+ 以上三种方式,默认选用最后一种。
30
+ 如果三种方式同时被指定,则优先级与对应方式在上面的排名相同。
31
+ """
32
+ super().__init__(**kwargs)
33
+
34
+ def add_sequence(self, var_ls, **kwargs):
35
+ for var in var_ls:
36
+ self.add(var, **kwargs)
37
+
38
+ def add(self, var, **kwargs):
39
+ """
40
+ 添加单个数据
41
+
42
+ 参数:
43
+ var: 数据
44
+ """
45
+ if self.var is None:
46
+ self.var = var
47
+ else:
48
+ # 统计
49
+ if torch.is_tensor(var):
50
+ self.var = torch.maximum(self.var, var)
51
+ else:
52
+ self.var = np.maximum(self.var, var)
53
+ self.state["total_nums"] += 1
54
+
55
+ def get(self, **kwargs):
56
+ """
57
+ 获取当前累加的平均值
58
+ 当未有累积时,返回 None
59
+ """
60
+ if len(self) == 0:
61
+ return None
62
+ return self.var
63
+
64
+ @staticmethod
65
+ def _init_state():
66
+ """
67
+ 初始化状态
68
+ """
69
+ return dict(
70
+ total_nums=0
71
+ )
72
+
73
+
74
+ if __name__ == '__main__':
75
+
76
+ seq = list(torch.tensor(range(1, 10))-5)
77
+ avg = Maximum_Accumulator()
78
+ for i, v in enumerate(seq):
79
+ avg.add(var=v)
80
+ print(i, v, avg.get())
@@ -0,0 +1,34 @@
1
+ import numpy as np
2
+ import torch
3
+ from kevin_toolbox.computer_science.algorithm.statistician import Maximum_Accumulator
4
+
5
+
6
+ class Minimum_Accumulator(Maximum_Accumulator):
7
+ """
8
+ 用于计算最小值的累积器
9
+ """
10
+
11
+ def add(self, var, **kwargs):
12
+ """
13
+ 添加单个数据
14
+
15
+ 参数:
16
+ var: 数据
17
+ """
18
+ if self.var is None:
19
+ self.var = var
20
+ else:
21
+ # 统计
22
+ if torch.is_tensor(var):
23
+ self.var = torch.minimum(self.var, var)
24
+ else:
25
+ self.var = np.minimum(self.var, var)
26
+ self.state["total_nums"] += 1
27
+
28
+
29
+ if __name__ == '__main__':
30
+ seq = list(torch.tensor(range(1, 10)) + 5)
31
+ avg = Minimum_Accumulator()
32
+ for i, v in enumerate(seq):
33
+ avg.add(var=v)
34
+ print(i, v, avg.get())
@@ -48,4 +48,7 @@ from .plot_distribution import plot_distribution
48
48
  from .plot_bars import plot_bars
49
49
  from .plot_scatters_matrix import plot_scatters_matrix
50
50
  from .plot_confusion_matrix import plot_confusion_matrix
51
+ from .plot_2d_matrix import plot_2d_matrix
52
+ from .plot_contour import plot_contour
53
+ from .plot_3d import plot_3d
51
54
  from .plot_from_record import plot_from_record
@@ -0,0 +1,128 @@
1
+ import copy
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
6
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
7
+
8
+ __name = ":common_charts:plot_matrix"
9
+
10
+
11
+ @COMMON_CHARTS.register(name=__name)
12
+ def plot_2d_matrix(matrix, title, row_label="row", column_label="column", x_tick_labels=None, y_tick_labels=None,
13
+ output_dir=None, output_path=None, replace_zero_division_with=0, **kwargs):
14
+ """
15
+ 计算并绘制混淆矩阵
16
+
17
+ 参数:
18
+ matrix: <np.ndarray> 矩阵
19
+ row_label: <str> 行标签。
20
+ column_label: <str> 列标签。
21
+ title: <str> 绘图标题,同时用于保存图片的文件名。
22
+ output_dir: <str or None>
23
+ 图像保存的输出目录。如果同时指定了 output_path,则以 output_path 为准。
24
+ 若 output_dir 和 output_path 均未指定,则图像将直接通过 plt.show() 显示而不会保存到文件。
25
+
26
+ output_dir: <str> 图片输出目录。
27
+ output_path: <str> 图片输出路径。
28
+ 以上两个只需指定一个即可,同时指定时以后者为准。
29
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
30
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
31
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
32
+ replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
33
+ 建议使用 np.nan 或 0,默认值为 0。
34
+
35
+ 其他可选参数:
36
+ dpi: <int> 图像保存的分辨率。
37
+ suffix: <str> 图片保存后缀。
38
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
39
+ normalize: <str or None> 指定归一化方式。
40
+ 可选值包括:
41
+ "row"(按行归一化)
42
+ "column"(按列归一化)
43
+ "all"(整体归一化)
44
+ 默认为 None 表示不归一化。
45
+ value_fmt: <str> 矩阵元素数值的显示方式。
46
+ b_return_matrix: <bool> 是否在返回值中包含(当使用 normalize 操作时)修改后的矩阵。
47
+ 默认为 False。
48
+ b_generate_record: <boolean> 是否保存函数参数为档案。
49
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
50
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
51
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
52
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
53
+ 默认为 False
54
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
55
+ 默认为 True
56
+ """
57
+ paras = {
58
+ "dpi": 200,
59
+ "suffix": ".png",
60
+ "b_generate_record": False,
61
+ "b_show_plot": False,
62
+ "b_bgr_image": True,
63
+ "normalize": None, # "true", "pred", "all",
64
+ "b_return_matrix": False, # 是否输出混淆矩阵
65
+
66
+ }
67
+ paras.update(kwargs)
68
+ matrix = np.asarray(matrix)
69
+ paras.setdefault("value_fmt",
70
+ '.2%' if paras["normalize"] is not None or np.issubdtype(matrix.dtype, np.floating) else 'd')
71
+ #
72
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
73
+ save_record(_func=plot_2d_matrix, _name=__name,
74
+ _output_path=_output_path if paras["b_generate_record"] else None,
75
+ **paras)
76
+ matrix = copy.deepcopy(matrix)
77
+
78
+ # replace with nan
79
+ if paras["normalize"] is not None:
80
+ if paras["normalize"] == "all":
81
+ if matrix.sum() == 0:
82
+ matrix[matrix == 0] = replace_zero_division_with
83
+ matrix = matrix / matrix.sum()
84
+ else:
85
+ check_axis = 1 if paras["normalize"] == "row" else 0
86
+ temp = np.sum(matrix, axis=check_axis, keepdims=False)
87
+ for i in range(len(temp)):
88
+ if temp[i] == 0:
89
+ if check_axis == 0:
90
+ matrix[:, i] = replace_zero_division_with
91
+ else:
92
+ matrix[i, :] = replace_zero_division_with
93
+ matrix = matrix / np.sum(matrix, axis=check_axis, keepdims=True)
94
+
95
+ # 绘制混淆矩阵热力图
96
+ plt.clf()
97
+ plt.figure(figsize=(8, 6))
98
+ sns.heatmap(matrix, annot=True, fmt=paras["value_fmt"],
99
+ xticklabels=x_tick_labels if x_tick_labels is not None else "auto",
100
+ yticklabels=y_tick_labels if y_tick_labels is not None else "auto",
101
+ cmap='viridis')
102
+
103
+ plt.xlabel(f'{column_label}')
104
+ plt.ylabel(f'{row_label}')
105
+ plt.title(f'{title}')
106
+
107
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
108
+
109
+ if paras["b_return_matrix"]:
110
+ return _output_path, matrix
111
+ else:
112
+ return _output_path
113
+
114
+
115
+ if __name__ == '__main__':
116
+ import os
117
+
118
+ # 示例真实标签和预测标签
119
+ A = np.random.randint(0, 5, (5, 5))
120
+ print(A)
121
+
122
+ plot_2d_matrix(
123
+ matrix=np.random.randint(0, 5, (5, 5)),
124
+ title="2D Matrix",
125
+ output_dir=os.path.join(os.path.dirname(__file__), "temp"),
126
+ replace_zero_division_with=-1,
127
+ normalize="row"
128
+ )
@@ -0,0 +1,198 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from mpl_toolkits.mplot3d import Axes3D # 兼容部分旧版 matplotlib
4
+ from scipy.interpolate import griddata
5
+ from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
6
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path, \
7
+ log_scaling
8
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
9
+
10
+ __name = ":common_charts:plot_3d"
11
+
12
+
13
+ @COMMON_CHARTS.register(name=__name)
14
+ def plot_3d(data_s, title, x_name, y_name, z_name, cate_name=None, type_=("scatter", "smooth_surf"), output_dir=None,
15
+ output_path=None, **kwargs):
16
+ """
17
+ 绘制3D图
18
+ 支持:散点图、三角剖分曲面及其平滑版本
19
+
20
+
21
+ 参数:
22
+ data_s: <dict> 数据。
23
+ 形如 {<data_name>: <data list>, ...} 的字典
24
+ 需要包含 x、y、z 三个键值对,分别对应 x、y、z 轴的数据。
25
+ title: <str> 绘图标题。
26
+ x_name: <str> x 轴的数据键名。
27
+ y_name: <str> y 轴的数据键名。
28
+ z_name: <str> z 轴的数据键名。
29
+ cate_name: <str> 以哪个 data_name 作为数据点的类别。
30
+ type_: <str/list of str> 图表类型。
31
+ 目前支持以下取值,或者以下取值的列表:
32
+ - "scatter" 散点图
33
+ - "tri_surf" 三角曲面
34
+ - "smooth_surf" 平滑曲面
35
+ 当指定列表时,将会绘制多个图表的混合。
36
+ output_dir: <str> 图片输出目录。
37
+ output_path: <str> 图片输出路径。
38
+ 以上两个只需指定一个即可,同时指定时以后者为准。
39
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
40
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
41
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
42
+
43
+ 其他可选参数:
44
+ dpi: <int> 保存图像的分辨率。
45
+ 默认为 200。
46
+ suffix: <str> 图片保存后缀。
47
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
48
+ b_generate_record: <boolean> 是否保存函数参数为档案。
49
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
50
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
51
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
52
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
53
+ 默认为 False
54
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
55
+ 默认为 True
56
+ scatter_size: <int> 散点大小,默认 30。
57
+ cate_of_surf: <str or list of str> 使用哪些类别的数据来绘制曲面。
58
+ 默认为 None,表示使用所有类别的数据来绘制曲面。
59
+ 仅当 cate_name 非 None 时该参数起效。
60
+ tri_surf_cmap: <str> 三角剖分曲面的颜色映射,默认 "viridis"。
61
+ tri_surf_alpha: <float> 三角剖分曲面的透明度,默认 0.6。
62
+ smooth_surf_cmap: <str> 平滑曲面的颜色映射,默认 "coolwarm"。
63
+ smooth_surf_alpha: <float> 平滑曲面的透明度,默认 0.6。
64
+ smooth_surf_method: <str> 平滑的方法。
65
+ 支持以下取值:
66
+ - "linear"
67
+ - "cubic"
68
+ view_elev: <float> 视角中的仰角,默认 30。
69
+ view_azim: <float> 视角中的方位角,默认 45。
70
+ x_log_scale,y_log_scale,z_log_scale: <int/float> 对 x,y,z 轴数据使用哪个底数进行对数显示。
71
+ 默认为 None,此时表示不使用对数显示。
72
+ x_ticks,...: <int/list of float or int> 在哪个数字下添加坐标记号。
73
+ 默认为 None,表示不添加记号。
74
+ 当设置为 int 时,表示自动根据 x,y,z 数据的范围,选取等间隔选取多少个坐标作为记号。
75
+ x_tick_labels,...: <int/list> 坐标记号的label。
76
+ """
77
+ # 默认参数设置
78
+ paras = {
79
+ "dpi": 200,
80
+ "suffix": ".png",
81
+ "b_generate_record": False,
82
+ "b_show_plot": False,
83
+ "b_bgr_image": True,
84
+ "scatter_size": 30,
85
+ "cate_of_surf": None,
86
+ "tri_surf_cmap": "viridis",
87
+ "tri_surf_alpha": 0.6,
88
+ "smooth_surf_cmap": "coolwarm",
89
+ "smooth_surf_alpha": 0.6,
90
+ "smooth_surf_method": "linear",
91
+ "view_elev": 30,
92
+ "view_azim": 45,
93
+ "x_log_scale": None,
94
+ "x_ticks": None,
95
+ "x_tick_labels": None,
96
+ "y_log_scale": None,
97
+ "y_ticks": None,
98
+ "y_tick_labels": None,
99
+ "z_log_scale": None,
100
+ "z_ticks": None,
101
+ "z_tick_labels": None,
102
+ }
103
+ paras.update(kwargs)
104
+ #
105
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
106
+ save_record(_func=plot_3d, _name=__name,
107
+ _output_path=_output_path if paras["b_generate_record"] else None,
108
+ **paras)
109
+ data_s = data_s.copy()
110
+ if isinstance(type_, str):
111
+ type_ = [type_]
112
+ #
113
+ d_s = dict()
114
+ ticks_s = dict()
115
+ tick_labels_s = dict()
116
+ for k in ("x", "y", "z"):
117
+ d_s[k], ticks_s[k], tick_labels_s[k] = log_scaling(
118
+ x_ls=data_s[eval(f'{k}_name')], log_scale=paras[f"{k}_log_scale"],
119
+ ticks=paras[f"{k}_ticks"], tick_labels=paras[f"{k}_tick_labels"]
120
+ )
121
+
122
+ x, y, z = [d_s[i].reshape(-1) for i in ("x", "y", "z")]
123
+ color_s = None
124
+ cate_of_surf = None
125
+ if cate_name is not None:
126
+ cates = list(set(data_s[cate_name]))
127
+ color_s = {i: j for i, j in zip(cates, generate_color_list(nums=len(cates)))}
128
+ c = [color_s[i] for i in data_s[cate_name]]
129
+ if paras["cate_of_surf"] is not None:
130
+ temp = [paras["cate_of_surf"], ] if isinstance(paras["cate_of_surf"], str) else paras[
131
+ "cate_of_surf"]
132
+ cate_of_surf = [i in temp for i in data_s[cate_name]]
133
+ else:
134
+ c = "red"
135
+
136
+ plt.clf()
137
+ fig = plt.figure(figsize=(10, 8))
138
+ ax = fig.add_subplot(111, projection='3d')
139
+
140
+ # 绘制数据点
141
+ if "scatter" in type_:
142
+ ax.scatter(x, y, z, s=paras["scatter_size"], c=c, depthshade=True)
143
+
144
+ if cate_of_surf is not None:
145
+ x, y, z = x[cate_of_surf], y[cate_of_surf], z[cate_of_surf]
146
+
147
+ # 绘制基于三角剖分的曲面(不平滑)
148
+ if "tri_surf" in type_:
149
+ tri_surf = ax.plot_trisurf(x, y, z, cmap=paras["tri_surf_cmap"], alpha=paras["tri_surf_alpha"])
150
+
151
+ # 构造规则网格,用于平滑曲面插值
152
+ if "smooth_surf" in type_:
153
+ grid_x, grid_y = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]
154
+ grid_z = griddata((x, y), z, (grid_x, grid_y), method=paras["smooth_surf_method"])
155
+ # 绘制平滑曲面
156
+ smooth_surf = ax.plot_surface(grid_x, grid_y, grid_z, cmap=paras["smooth_surf_cmap"],
157
+ edgecolor='none', alpha=paras["smooth_surf_alpha"])
158
+ # 添加颜色条以展示平滑曲面颜色与 z 值的对应关系
159
+ cbar = fig.colorbar(smooth_surf, ax=ax, shrink=0.5, aspect=10)
160
+ cbar.set_label(z_name, fontsize=12)
161
+
162
+ # 设置坐标轴标签和图形标题
163
+ ax.set_xlabel(x_name, fontsize=12)
164
+ ax.set_ylabel(y_name, fontsize=12)
165
+ ax.set_zlabel(z_name, fontsize=12)
166
+ ax.set_title(title, fontsize=14)
167
+ for i in ("x", "y", "z"):
168
+ if ticks_s[i] is not None:
169
+ getattr(ax, f'set_{i}ticks')(ticks_s[i])
170
+ getattr(ax, f'set_{i}ticklabels')(tick_labels_s[i])
171
+
172
+ # 调整视角
173
+ ax.view_init(elev=paras["view_elev"], azim=paras["view_azim"])
174
+
175
+ # 创建图例
176
+ if "scatter" in type_ and cate_name is not None:
177
+ plt.legend(handles=[
178
+ plt.Line2D([0], [0], marker='o', color='w', label=i, markerfacecolor=j,
179
+ markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
180
+ ])
181
+
182
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
183
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
184
+
185
+
186
+ if __name__ == '__main__':
187
+ # 示例用法:生成示例数据并绘制3D图像
188
+ np.random.seed(42)
189
+ num_points = 200
190
+ data = {
191
+ 'x': np.random.uniform(-5, 5, num_points),
192
+ 'y': np.random.uniform(-5, 5, num_points),
193
+ "c": np.random.uniform(-5, 5, num_points) > 0.3,
194
+ }
195
+ # 示例 z 值:例如 z = sin(sqrt(x^2+y^2))
196
+ data['z'] = np.sin(np.sqrt(data['x'] ** 2 + data['y'] ** 2)) + 1.1
197
+ plot_3d(data, x_name='x', y_name='y', z_name='z', cate_name="c", title="3D Surface Plot", z_log_scale=10, z_ticks=5,
198
+ type_=("scatter"), output_dir="./temp")
@@ -27,7 +27,7 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
27
27
  output_path: <str or None> 图片输出路径。
28
28
  以上两个只需指定一个即可,同时指定时以后者为准。
29
29
  当只有 output_dir 被指定时,将会以 title 作为图片名。
30
- 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
30
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
31
31
  在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
32
32
 
33
33
  其他可选参数:
@@ -39,6 +39,10 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
39
39
  默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
40
40
  后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
41
41
  该参数仅在 output_dir 和 output_path 非 None 时起效。
42
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
43
+ 默认为 False
44
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
45
+ 默认为 True
42
46
 
43
47
  返回值:
44
48
  若 output_dir 非 None,则返回图像保存的文件路径。
@@ -78,9 +82,8 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
78
82
  # 显示图例
79
83
  plt.legend()
80
84
 
81
- save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
82
-
83
- return _output_path
85
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
86
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
84
87
 
85
88
 
86
89
  if __name__ == '__main__':
@@ -30,7 +30,7 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
30
30
  output_path: <str> 图片输出路径。
31
31
  以上两个只需指定一个即可,同时指定时以后者为准。
32
32
  当只有 output_dir 被指定时,将会以 title 作为图片名。
33
- 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
33
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
34
34
  在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
35
35
  replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
36
36
  建议使用 np.nan 或 0,默认值为 0。
@@ -51,6 +51,10 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
51
51
  默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
52
52
  后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
53
53
  该参数仅在 output_dir 和 output_path 非 None 时起效。
54
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
55
+ 默认为 False
56
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
57
+ 默认为 True
54
58
 
55
59
  返回值:
56
60
  当 b_return_cfm 为 True 时,返回值可能为一个包含 (图像路径, 混淆矩阵数据) 的元组。
@@ -59,6 +63,8 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
59
63
  "dpi": 200,
60
64
  "suffix": ".png",
61
65
  "b_generate_record": False,
66
+ "b_show_plot": False,
67
+ "b_bgr_image": True,
62
68
  "normalize": None, # "true", "pred", "all",
63
69
  "b_return_cfm": False, # 是否输出混淆矩阵
64
70
  }
@@ -105,12 +111,13 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
105
111
  plt.ylabel(f'{gt_name}')
106
112
  plt.title(f'{title}')
107
113
 
108
- save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
114
+ res = save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
115
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
109
116
 
110
117
  if paras["b_return_cfm"]:
111
- return _output_path, cfm
118
+ return res, cfm
112
119
  else:
113
- return _output_path
120
+ return res
114
121
 
115
122
 
116
123
  if __name__ == '__main__':