kevin-toolbox-dev 1.4.10__py3-none-any.whl → 1.4.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kevin_toolbox/__init__.py +2 -2
- kevin_toolbox/computer_science/algorithm/for_seq/sample_subset_most_evenly.py +3 -1
- kevin_toolbox/computer_science/algorithm/redirector/redirectable_sequence_fetcher.py +2 -2
- kevin_toolbox/computer_science/algorithm/statistician/__init__.py +2 -0
- kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +1 -1
- kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +1 -1
- kevin_toolbox/computer_science/algorithm/statistician/maximum_accumulator.py +80 -0
- kevin_toolbox/computer_science/algorithm/statistician/minimum_accumulator.py +34 -0
- kevin_toolbox/nested_dict_list/get_nodes.py +9 -0
- kevin_toolbox/nested_dict_list/value_parser/replace_identical_with_reference.py +2 -4
- kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +3 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_2d_matrix.py +128 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_3d.py +198 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +7 -4
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +11 -4
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_contour.py +157 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +19 -8
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +65 -21
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +9 -3
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +9 -3
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +1 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/log_scaling.py +62 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +19 -3
- kevin_toolbox/patches/for_numpy/__init__.py +1 -0
- kevin_toolbox/patches/for_numpy/linalg/softmax.py +4 -1
- kevin_toolbox_dev-1.4.12.dist-info/METADATA +64 -0
- {kevin_toolbox_dev-1.4.10.dist-info → kevin_toolbox_dev-1.4.12.dist-info}/RECORD +29 -23
- kevin_toolbox_dev-1.4.10.dist-info/METADATA +0 -106
- {kevin_toolbox_dev-1.4.10.dist-info → kevin_toolbox_dev-1.4.12.dist-info}/WHEEL +0 -0
- {kevin_toolbox_dev-1.4.10.dist-info → kevin_toolbox_dev-1.4.12.dist-info}/top_level.txt +0 -0
kevin_toolbox/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
__version__ = "1.4.
|
1
|
+
__version__ = "1.4.12"
|
2
2
|
|
3
3
|
|
4
4
|
import os
|
@@ -12,5 +12,5 @@ os.system(
|
|
12
12
|
os.system(
|
13
13
|
f'python {os.path.split(__file__)[0]}/env_info/check_validity_and_uninstall.py '
|
14
14
|
f'--package_name kevin-toolbox-dev '
|
15
|
-
f'--expiration_timestamp
|
15
|
+
f'--expiration_timestamp 1763992975 --verbose 0'
|
16
16
|
)
|
@@ -26,7 +26,9 @@ def sample_subset_most_evenly(inputs, ratio=None, nums=None, seed=None, rng=None
|
|
26
26
|
仅在b_shuffle_the_tail=True时,以上两个参数起效,且仅需指定一个即可。
|
27
27
|
|
28
28
|
"""
|
29
|
-
|
29
|
+
if nums is None:
|
30
|
+
assert ratio is not None
|
31
|
+
nums = math.ceil(len(inputs) * ratio)
|
30
32
|
assert nums >= 0
|
31
33
|
if len(inputs) == 0 or nums == 0:
|
32
34
|
return []
|
@@ -10,8 +10,8 @@ def _randomly_idx_redirector(idx, seq_len, attempts, rng, *args):
|
|
10
10
|
elif idx == seq_len - 1:
|
11
11
|
return rng.randint(0, seq_len - 2)
|
12
12
|
else:
|
13
|
-
return rng.
|
14
|
-
|
13
|
+
return rng.choice([rng.randint(0, idx - 1), rng.randint(idx + 1, seq_len - 1)], size=1,
|
14
|
+
p=[idx / (seq_len - 1), (seq_len - idx - 1) / (seq_len - 1)])[0]
|
15
15
|
|
16
16
|
|
17
17
|
idx_redirector_s = {
|
@@ -2,3 +2,5 @@ from .accumulator_base import Accumulator_Base
|
|
2
2
|
from .exponential_moving_average import Exponential_Moving_Average
|
3
3
|
from .average_accumulator import Average_Accumulator
|
4
4
|
from .accumulator_for_ndl import Accumulator_for_Ndl
|
5
|
+
from .maximum_accumulator import Maximum_Accumulator
|
6
|
+
from .minimum_accumulator import Minimum_Accumulator
|
@@ -27,7 +27,7 @@ class Average_Accumulator(Accumulator_Base):
|
|
27
27
|
以上三种方式,默认选用最后一种。
|
28
28
|
如果三种方式同时被指定,则优先级与对应方式在上面的排名相同。
|
29
29
|
"""
|
30
|
-
super(
|
30
|
+
super().__init__(**kwargs)
|
31
31
|
|
32
32
|
def add_sequence(self, var_ls, **kwargs):
|
33
33
|
for var in var_ls:
|
@@ -56,7 +56,7 @@ class Exponential_Moving_Average(Accumulator_Base):
|
|
56
56
|
# 校验参数
|
57
57
|
assert isinstance(paras["keep_ratio"], (int, float,)) and 0 <= paras["keep_ratio"] <= 1
|
58
58
|
#
|
59
|
-
super(
|
59
|
+
super().__init__(**paras)
|
60
60
|
|
61
61
|
def add_sequence(self, var_ls, **kwargs):
|
62
62
|
for var in var_ls:
|
@@ -0,0 +1,80 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
3
|
+
from kevin_toolbox.computer_science.algorithm.statistician import Accumulator_Base
|
4
|
+
|
5
|
+
|
6
|
+
class Maximum_Accumulator(Accumulator_Base):
|
7
|
+
"""
|
8
|
+
用于计算最大值的累积器
|
9
|
+
"""
|
10
|
+
|
11
|
+
def __init__(self, **kwargs):
|
12
|
+
"""
|
13
|
+
参数:
|
14
|
+
data_format: 指定数据格式
|
15
|
+
like: 指定数据格式
|
16
|
+
指定输入数据的格式,有三种方式:
|
17
|
+
1. 显式指定数据的形状和所在设备等。
|
18
|
+
data_format: <dict of paras>
|
19
|
+
其中需要包含以下参数:
|
20
|
+
type_: <str>
|
21
|
+
"numpy": np.ndarray
|
22
|
+
"torch": torch.tensor
|
23
|
+
shape: <list of integers>
|
24
|
+
device: <torch.device>
|
25
|
+
dtype: <torch.dtype>
|
26
|
+
2. 根据输入的数据,来推断出形状、设备等。
|
27
|
+
like: <torch.tensor / np.ndarray / int / float>
|
28
|
+
3. 均不指定 data_format 和 like,此时将等到第一次调用 add()/add_sequence() 时再根据输入来自动推断。
|
29
|
+
以上三种方式,默认选用最后一种。
|
30
|
+
如果三种方式同时被指定,则优先级与对应方式在上面的排名相同。
|
31
|
+
"""
|
32
|
+
super().__init__(**kwargs)
|
33
|
+
|
34
|
+
def add_sequence(self, var_ls, **kwargs):
|
35
|
+
for var in var_ls:
|
36
|
+
self.add(var, **kwargs)
|
37
|
+
|
38
|
+
def add(self, var, **kwargs):
|
39
|
+
"""
|
40
|
+
添加单个数据
|
41
|
+
|
42
|
+
参数:
|
43
|
+
var: 数据
|
44
|
+
"""
|
45
|
+
if self.var is None:
|
46
|
+
self.var = var
|
47
|
+
else:
|
48
|
+
# 统计
|
49
|
+
if torch.is_tensor(var):
|
50
|
+
self.var = torch.maximum(self.var, var)
|
51
|
+
else:
|
52
|
+
self.var = np.maximum(self.var, var)
|
53
|
+
self.state["total_nums"] += 1
|
54
|
+
|
55
|
+
def get(self, **kwargs):
|
56
|
+
"""
|
57
|
+
获取当前累加的平均值
|
58
|
+
当未有累积时,返回 None
|
59
|
+
"""
|
60
|
+
if len(self) == 0:
|
61
|
+
return None
|
62
|
+
return self.var
|
63
|
+
|
64
|
+
@staticmethod
|
65
|
+
def _init_state():
|
66
|
+
"""
|
67
|
+
初始化状态
|
68
|
+
"""
|
69
|
+
return dict(
|
70
|
+
total_nums=0
|
71
|
+
)
|
72
|
+
|
73
|
+
|
74
|
+
if __name__ == '__main__':
|
75
|
+
|
76
|
+
seq = list(torch.tensor(range(1, 10))-5)
|
77
|
+
avg = Maximum_Accumulator()
|
78
|
+
for i, v in enumerate(seq):
|
79
|
+
avg.add(var=v)
|
80
|
+
print(i, v, avg.get())
|
@@ -0,0 +1,34 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
3
|
+
from kevin_toolbox.computer_science.algorithm.statistician import Maximum_Accumulator
|
4
|
+
|
5
|
+
|
6
|
+
class Minimum_Accumulator(Maximum_Accumulator):
|
7
|
+
"""
|
8
|
+
用于计算最小值的累积器
|
9
|
+
"""
|
10
|
+
|
11
|
+
def add(self, var, **kwargs):
|
12
|
+
"""
|
13
|
+
添加单个数据
|
14
|
+
|
15
|
+
参数:
|
16
|
+
var: 数据
|
17
|
+
"""
|
18
|
+
if self.var is None:
|
19
|
+
self.var = var
|
20
|
+
else:
|
21
|
+
# 统计
|
22
|
+
if torch.is_tensor(var):
|
23
|
+
self.var = torch.minimum(self.var, var)
|
24
|
+
else:
|
25
|
+
self.var = np.minimum(self.var, var)
|
26
|
+
self.state["total_nums"] += 1
|
27
|
+
|
28
|
+
|
29
|
+
if __name__ == '__main__':
|
30
|
+
seq = list(torch.tensor(range(1, 10)) + 5)
|
31
|
+
avg = Minimum_Accumulator()
|
32
|
+
for i, v in enumerate(seq):
|
33
|
+
avg.add(var=v)
|
34
|
+
print(i, v, avg.get())
|
@@ -21,6 +21,15 @@ def get_nodes(var, level=-1, b_strict=True, **kwargs):
|
|
21
21
|
对于 level=-10 返回的是 [('', {'d': {'c': 4}, 'c': 4}), ]
|
22
22
|
对于 level=10 返回的是 [(':c', 4), (':d:c', 4)]
|
23
23
|
默认为 True,不添加。
|
24
|
+
|
25
|
+
注意:
|
26
|
+
当 level 为负数(表示从叶节点往上计起)时,某些节点可能同时属于多个 level,比如对于:
|
27
|
+
{'d': {'c': [1, ], 'e': 4}},
|
28
|
+
其中:
|
29
|
+
level=-1: :d:e, :d:c@0
|
30
|
+
level=-2: :d, :d:c
|
31
|
+
level=-3: "", :d
|
32
|
+
可以看到由于 :d 下面有两个不等长的到不同叶节点的路径,因此该节点属于 level -2 和 -3
|
24
33
|
"""
|
25
34
|
assert isinstance(level, (int,))
|
26
35
|
kwargs.setdefault("b_skip_repeated_non_leaf_node", False)
|
@@ -44,7 +44,7 @@ def _forward(var, flag, match_cond):
|
|
44
44
|
if not match_cond(name, value):
|
45
45
|
continue
|
46
46
|
id_to_name_s[id(value)].add(name)
|
47
|
-
id_to_height_s[id(value)]
|
47
|
+
id_to_height_s[id(value)] = height
|
48
48
|
height += 1
|
49
49
|
|
50
50
|
#
|
@@ -54,10 +54,8 @@ def _forward(var, flag, match_cond):
|
|
54
54
|
id_to_height_s.pop(k)
|
55
55
|
id_to_name_s.pop(k)
|
56
56
|
continue
|
57
|
-
# 具有相同 id 的节点所处的高度应该相同
|
58
|
-
assert len(v) == 1, f'nodes {id_to_name_s[k]} have different heights: {v}'
|
59
57
|
# 按高度排序
|
60
|
-
id_vs_height = sorted([(k, v
|
58
|
+
id_vs_height = sorted([(k, v) for k, v in id_to_height_s.items()], key=lambda x: x[1], reverse=True)
|
61
59
|
|
62
60
|
# 从高到低,依次将具有相同 id 的节点替换为 单个节点和多个引用 的形式
|
63
61
|
temp = []
|
@@ -48,4 +48,7 @@ from .plot_distribution import plot_distribution
|
|
48
48
|
from .plot_bars import plot_bars
|
49
49
|
from .plot_scatters_matrix import plot_scatters_matrix
|
50
50
|
from .plot_confusion_matrix import plot_confusion_matrix
|
51
|
+
from .plot_2d_matrix import plot_2d_matrix
|
52
|
+
from .plot_contour import plot_contour
|
53
|
+
from .plot_3d import plot_3d
|
51
54
|
from .plot_from_record import plot_from_record
|
@@ -0,0 +1,128 @@
|
|
1
|
+
import copy
|
2
|
+
import numpy as np
|
3
|
+
import matplotlib.pyplot as plt
|
4
|
+
import seaborn as sns
|
5
|
+
from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
|
6
|
+
from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
|
7
|
+
|
8
|
+
__name = ":common_charts:plot_matrix"
|
9
|
+
|
10
|
+
|
11
|
+
@COMMON_CHARTS.register(name=__name)
|
12
|
+
def plot_2d_matrix(matrix, title, row_label="row", column_label="column", x_tick_labels=None, y_tick_labels=None,
|
13
|
+
output_dir=None, output_path=None, replace_zero_division_with=0, **kwargs):
|
14
|
+
"""
|
15
|
+
计算并绘制混淆矩阵
|
16
|
+
|
17
|
+
参数:
|
18
|
+
matrix: <np.ndarray> 矩阵
|
19
|
+
row_label: <str> 行标签。
|
20
|
+
column_label: <str> 列标签。
|
21
|
+
title: <str> 绘图标题,同时用于保存图片的文件名。
|
22
|
+
output_dir: <str or None>
|
23
|
+
图像保存的输出目录。如果同时指定了 output_path,则以 output_path 为准。
|
24
|
+
若 output_dir 和 output_path 均未指定,则图像将直接通过 plt.show() 显示而不会保存到文件。
|
25
|
+
|
26
|
+
output_dir: <str> 图片输出目录。
|
27
|
+
output_path: <str> 图片输出路径。
|
28
|
+
以上两个只需指定一个即可,同时指定时以后者为准。
|
29
|
+
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
30
|
+
若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
|
31
|
+
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
32
|
+
replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
|
33
|
+
建议使用 np.nan 或 0,默认值为 0。
|
34
|
+
|
35
|
+
其他可选参数:
|
36
|
+
dpi: <int> 图像保存的分辨率。
|
37
|
+
suffix: <str> 图片保存后缀。
|
38
|
+
目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
|
39
|
+
normalize: <str or None> 指定归一化方式。
|
40
|
+
可选值包括:
|
41
|
+
"row"(按行归一化)
|
42
|
+
"column"(按列归一化)
|
43
|
+
"all"(整体归一化)
|
44
|
+
默认为 None 表示不归一化。
|
45
|
+
value_fmt: <str> 矩阵元素数值的显示方式。
|
46
|
+
b_return_matrix: <bool> 是否在返回值中包含(当使用 normalize 操作时)修改后的矩阵。
|
47
|
+
默认为 False。
|
48
|
+
b_generate_record: <boolean> 是否保存函数参数为档案。
|
49
|
+
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
50
|
+
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
51
|
+
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
52
|
+
b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
|
53
|
+
默认为 False
|
54
|
+
b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
|
55
|
+
默认为 True
|
56
|
+
"""
|
57
|
+
paras = {
|
58
|
+
"dpi": 200,
|
59
|
+
"suffix": ".png",
|
60
|
+
"b_generate_record": False,
|
61
|
+
"b_show_plot": False,
|
62
|
+
"b_bgr_image": True,
|
63
|
+
"normalize": None, # "true", "pred", "all",
|
64
|
+
"b_return_matrix": False, # 是否输出混淆矩阵
|
65
|
+
|
66
|
+
}
|
67
|
+
paras.update(kwargs)
|
68
|
+
matrix = np.asarray(matrix)
|
69
|
+
paras.setdefault("value_fmt",
|
70
|
+
'.2%' if paras["normalize"] is not None or np.issubdtype(matrix.dtype, np.floating) else 'd')
|
71
|
+
#
|
72
|
+
_output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
|
73
|
+
save_record(_func=plot_2d_matrix, _name=__name,
|
74
|
+
_output_path=_output_path if paras["b_generate_record"] else None,
|
75
|
+
**paras)
|
76
|
+
matrix = copy.deepcopy(matrix)
|
77
|
+
|
78
|
+
# replace with nan
|
79
|
+
if paras["normalize"] is not None:
|
80
|
+
if paras["normalize"] == "all":
|
81
|
+
if matrix.sum() == 0:
|
82
|
+
matrix[matrix == 0] = replace_zero_division_with
|
83
|
+
matrix = matrix / matrix.sum()
|
84
|
+
else:
|
85
|
+
check_axis = 1 if paras["normalize"] == "row" else 0
|
86
|
+
temp = np.sum(matrix, axis=check_axis, keepdims=False)
|
87
|
+
for i in range(len(temp)):
|
88
|
+
if temp[i] == 0:
|
89
|
+
if check_axis == 0:
|
90
|
+
matrix[:, i] = replace_zero_division_with
|
91
|
+
else:
|
92
|
+
matrix[i, :] = replace_zero_division_with
|
93
|
+
matrix = matrix / np.sum(matrix, axis=check_axis, keepdims=True)
|
94
|
+
|
95
|
+
# 绘制混淆矩阵热力图
|
96
|
+
plt.clf()
|
97
|
+
plt.figure(figsize=(8, 6))
|
98
|
+
sns.heatmap(matrix, annot=True, fmt=paras["value_fmt"],
|
99
|
+
xticklabels=x_tick_labels if x_tick_labels is not None else "auto",
|
100
|
+
yticklabels=y_tick_labels if y_tick_labels is not None else "auto",
|
101
|
+
cmap='viridis')
|
102
|
+
|
103
|
+
plt.xlabel(f'{column_label}')
|
104
|
+
plt.ylabel(f'{row_label}')
|
105
|
+
plt.title(f'{title}')
|
106
|
+
|
107
|
+
save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
|
108
|
+
|
109
|
+
if paras["b_return_matrix"]:
|
110
|
+
return _output_path, matrix
|
111
|
+
else:
|
112
|
+
return _output_path
|
113
|
+
|
114
|
+
|
115
|
+
if __name__ == '__main__':
|
116
|
+
import os
|
117
|
+
|
118
|
+
# 示例真实标签和预测标签
|
119
|
+
A = np.random.randint(0, 5, (5, 5))
|
120
|
+
print(A)
|
121
|
+
|
122
|
+
plot_2d_matrix(
|
123
|
+
matrix=np.random.randint(0, 5, (5, 5)),
|
124
|
+
title="2D Matrix",
|
125
|
+
output_dir=os.path.join(os.path.dirname(__file__), "temp"),
|
126
|
+
replace_zero_division_with=-1,
|
127
|
+
normalize="row"
|
128
|
+
)
|
@@ -0,0 +1,198 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import matplotlib.pyplot as plt
|
3
|
+
from mpl_toolkits.mplot3d import Axes3D # 兼容部分旧版 matplotlib
|
4
|
+
from scipy.interpolate import griddata
|
5
|
+
from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
|
6
|
+
from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path, \
|
7
|
+
log_scaling
|
8
|
+
from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
|
9
|
+
|
10
|
+
__name = ":common_charts:plot_3d"
|
11
|
+
|
12
|
+
|
13
|
+
@COMMON_CHARTS.register(name=__name)
|
14
|
+
def plot_3d(data_s, title, x_name, y_name, z_name, cate_name=None, type_=("scatter", "smooth_surf"), output_dir=None,
|
15
|
+
output_path=None, **kwargs):
|
16
|
+
"""
|
17
|
+
绘制3D图
|
18
|
+
支持:散点图、三角剖分曲面及其平滑版本
|
19
|
+
|
20
|
+
|
21
|
+
参数:
|
22
|
+
data_s: <dict> 数据。
|
23
|
+
形如 {<data_name>: <data list>, ...} 的字典
|
24
|
+
需要包含 x、y、z 三个键值对,分别对应 x、y、z 轴的数据。
|
25
|
+
title: <str> 绘图标题。
|
26
|
+
x_name: <str> x 轴的数据键名。
|
27
|
+
y_name: <str> y 轴的数据键名。
|
28
|
+
z_name: <str> z 轴的数据键名。
|
29
|
+
cate_name: <str> 以哪个 data_name 作为数据点的类别。
|
30
|
+
type_: <str/list of str> 图表类型。
|
31
|
+
目前支持以下取值,或者以下取值的列表:
|
32
|
+
- "scatter" 散点图
|
33
|
+
- "tri_surf" 三角曲面
|
34
|
+
- "smooth_surf" 平滑曲面
|
35
|
+
当指定列表时,将会绘制多个图表的混合。
|
36
|
+
output_dir: <str> 图片输出目录。
|
37
|
+
output_path: <str> 图片输出路径。
|
38
|
+
以上两个只需指定一个即可,同时指定时以后者为准。
|
39
|
+
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
40
|
+
若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
|
41
|
+
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
42
|
+
|
43
|
+
其他可选参数:
|
44
|
+
dpi: <int> 保存图像的分辨率。
|
45
|
+
默认为 200。
|
46
|
+
suffix: <str> 图片保存后缀。
|
47
|
+
目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
|
48
|
+
b_generate_record: <boolean> 是否保存函数参数为档案。
|
49
|
+
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
50
|
+
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
51
|
+
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
52
|
+
b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
|
53
|
+
默认为 False
|
54
|
+
b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
|
55
|
+
默认为 True
|
56
|
+
scatter_size: <int> 散点大小,默认 30。
|
57
|
+
cate_of_surf: <str or list of str> 使用哪些类别的数据来绘制曲面。
|
58
|
+
默认为 None,表示使用所有类别的数据来绘制曲面。
|
59
|
+
仅当 cate_name 非 None 时该参数起效。
|
60
|
+
tri_surf_cmap: <str> 三角剖分曲面的颜色映射,默认 "viridis"。
|
61
|
+
tri_surf_alpha: <float> 三角剖分曲面的透明度,默认 0.6。
|
62
|
+
smooth_surf_cmap: <str> 平滑曲面的颜色映射,默认 "coolwarm"。
|
63
|
+
smooth_surf_alpha: <float> 平滑曲面的透明度,默认 0.6。
|
64
|
+
smooth_surf_method: <str> 平滑的方法。
|
65
|
+
支持以下取值:
|
66
|
+
- "linear"
|
67
|
+
- "cubic"
|
68
|
+
view_elev: <float> 视角中的仰角,默认 30。
|
69
|
+
view_azim: <float> 视角中的方位角,默认 45。
|
70
|
+
x_log_scale,y_log_scale,z_log_scale: <int/float> 对 x,y,z 轴数据使用哪个底数进行对数显示。
|
71
|
+
默认为 None,此时表示不使用对数显示。
|
72
|
+
x_ticks,...: <int/list of float or int> 在哪个数字下添加坐标记号。
|
73
|
+
默认为 None,表示不添加记号。
|
74
|
+
当设置为 int 时,表示自动根据 x,y,z 数据的范围,选取等间隔选取多少个坐标作为记号。
|
75
|
+
x_tick_labels,...: <int/list> 坐标记号的label。
|
76
|
+
"""
|
77
|
+
# 默认参数设置
|
78
|
+
paras = {
|
79
|
+
"dpi": 200,
|
80
|
+
"suffix": ".png",
|
81
|
+
"b_generate_record": False,
|
82
|
+
"b_show_plot": False,
|
83
|
+
"b_bgr_image": True,
|
84
|
+
"scatter_size": 30,
|
85
|
+
"cate_of_surf": None,
|
86
|
+
"tri_surf_cmap": "viridis",
|
87
|
+
"tri_surf_alpha": 0.6,
|
88
|
+
"smooth_surf_cmap": "coolwarm",
|
89
|
+
"smooth_surf_alpha": 0.6,
|
90
|
+
"smooth_surf_method": "linear",
|
91
|
+
"view_elev": 30,
|
92
|
+
"view_azim": 45,
|
93
|
+
"x_log_scale": None,
|
94
|
+
"x_ticks": None,
|
95
|
+
"x_tick_labels": None,
|
96
|
+
"y_log_scale": None,
|
97
|
+
"y_ticks": None,
|
98
|
+
"y_tick_labels": None,
|
99
|
+
"z_log_scale": None,
|
100
|
+
"z_ticks": None,
|
101
|
+
"z_tick_labels": None,
|
102
|
+
}
|
103
|
+
paras.update(kwargs)
|
104
|
+
#
|
105
|
+
_output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
|
106
|
+
save_record(_func=plot_3d, _name=__name,
|
107
|
+
_output_path=_output_path if paras["b_generate_record"] else None,
|
108
|
+
**paras)
|
109
|
+
data_s = data_s.copy()
|
110
|
+
if isinstance(type_, str):
|
111
|
+
type_ = [type_]
|
112
|
+
#
|
113
|
+
d_s = dict()
|
114
|
+
ticks_s = dict()
|
115
|
+
tick_labels_s = dict()
|
116
|
+
for k in ("x", "y", "z"):
|
117
|
+
d_s[k], ticks_s[k], tick_labels_s[k] = log_scaling(
|
118
|
+
x_ls=data_s[eval(f'{k}_name')], log_scale=paras[f"{k}_log_scale"],
|
119
|
+
ticks=paras[f"{k}_ticks"], tick_labels=paras[f"{k}_tick_labels"]
|
120
|
+
)
|
121
|
+
|
122
|
+
x, y, z = [d_s[i].reshape(-1) for i in ("x", "y", "z")]
|
123
|
+
color_s = None
|
124
|
+
cate_of_surf = None
|
125
|
+
if cate_name is not None:
|
126
|
+
cates = list(set(data_s[cate_name]))
|
127
|
+
color_s = {i: j for i, j in zip(cates, generate_color_list(nums=len(cates)))}
|
128
|
+
c = [color_s[i] for i in data_s[cate_name]]
|
129
|
+
if paras["cate_of_surf"] is not None:
|
130
|
+
temp = [paras["cate_of_surf"], ] if isinstance(paras["cate_of_surf"], str) else paras[
|
131
|
+
"cate_of_surf"]
|
132
|
+
cate_of_surf = [i in temp for i in data_s[cate_name]]
|
133
|
+
else:
|
134
|
+
c = "red"
|
135
|
+
|
136
|
+
plt.clf()
|
137
|
+
fig = plt.figure(figsize=(10, 8))
|
138
|
+
ax = fig.add_subplot(111, projection='3d')
|
139
|
+
|
140
|
+
# 绘制数据点
|
141
|
+
if "scatter" in type_:
|
142
|
+
ax.scatter(x, y, z, s=paras["scatter_size"], c=c, depthshade=True)
|
143
|
+
|
144
|
+
if cate_of_surf is not None:
|
145
|
+
x, y, z = x[cate_of_surf], y[cate_of_surf], z[cate_of_surf]
|
146
|
+
|
147
|
+
# 绘制基于三角剖分的曲面(不平滑)
|
148
|
+
if "tri_surf" in type_:
|
149
|
+
tri_surf = ax.plot_trisurf(x, y, z, cmap=paras["tri_surf_cmap"], alpha=paras["tri_surf_alpha"])
|
150
|
+
|
151
|
+
# 构造规则网格,用于平滑曲面插值
|
152
|
+
if "smooth_surf" in type_:
|
153
|
+
grid_x, grid_y = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]
|
154
|
+
grid_z = griddata((x, y), z, (grid_x, grid_y), method=paras["smooth_surf_method"])
|
155
|
+
# 绘制平滑曲面
|
156
|
+
smooth_surf = ax.plot_surface(grid_x, grid_y, grid_z, cmap=paras["smooth_surf_cmap"],
|
157
|
+
edgecolor='none', alpha=paras["smooth_surf_alpha"])
|
158
|
+
# 添加颜色条以展示平滑曲面颜色与 z 值的对应关系
|
159
|
+
cbar = fig.colorbar(smooth_surf, ax=ax, shrink=0.5, aspect=10)
|
160
|
+
cbar.set_label(z_name, fontsize=12)
|
161
|
+
|
162
|
+
# 设置坐标轴标签和图形标题
|
163
|
+
ax.set_xlabel(x_name, fontsize=12)
|
164
|
+
ax.set_ylabel(y_name, fontsize=12)
|
165
|
+
ax.set_zlabel(z_name, fontsize=12)
|
166
|
+
ax.set_title(title, fontsize=14)
|
167
|
+
for i in ("x", "y", "z"):
|
168
|
+
if ticks_s[i] is not None:
|
169
|
+
getattr(ax, f'set_{i}ticks')(ticks_s[i])
|
170
|
+
getattr(ax, f'set_{i}ticklabels')(tick_labels_s[i])
|
171
|
+
|
172
|
+
# 调整视角
|
173
|
+
ax.view_init(elev=paras["view_elev"], azim=paras["view_azim"])
|
174
|
+
|
175
|
+
# 创建图例
|
176
|
+
if "scatter" in type_ and cate_name is not None:
|
177
|
+
plt.legend(handles=[
|
178
|
+
plt.Line2D([0], [0], marker='o', color='w', label=i, markerfacecolor=j,
|
179
|
+
markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
|
180
|
+
])
|
181
|
+
|
182
|
+
return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
|
183
|
+
b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
|
184
|
+
|
185
|
+
|
186
|
+
if __name__ == '__main__':
|
187
|
+
# 示例用法:生成示例数据并绘制3D图像
|
188
|
+
np.random.seed(42)
|
189
|
+
num_points = 200
|
190
|
+
data = {
|
191
|
+
'x': np.random.uniform(-5, 5, num_points),
|
192
|
+
'y': np.random.uniform(-5, 5, num_points),
|
193
|
+
"c": np.random.uniform(-5, 5, num_points) > 0.3,
|
194
|
+
}
|
195
|
+
# 示例 z 值:例如 z = sin(sqrt(x^2+y^2))
|
196
|
+
data['z'] = np.sin(np.sqrt(data['x'] ** 2 + data['y'] ** 2)) + 1.1
|
197
|
+
plot_3d(data, x_name='x', y_name='y', z_name='z', cate_name="c", title="3D Surface Plot", z_log_scale=10, z_ticks=5,
|
198
|
+
type_=("scatter"), output_dir="./temp")
|
@@ -27,7 +27,7 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
|
|
27
27
|
output_path: <str or None> 图片输出路径。
|
28
28
|
以上两个只需指定一个即可,同时指定时以后者为准。
|
29
29
|
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
30
|
-
|
30
|
+
若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
|
31
31
|
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
32
32
|
|
33
33
|
其他可选参数:
|
@@ -39,6 +39,10 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
|
|
39
39
|
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
40
40
|
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
41
41
|
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
42
|
+
b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
|
43
|
+
默认为 False
|
44
|
+
b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
|
45
|
+
默认为 True
|
42
46
|
|
43
47
|
返回值:
|
44
48
|
若 output_dir 非 None,则返回图像保存的文件路径。
|
@@ -78,9 +82,8 @@ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs
|
|
78
82
|
# 显示图例
|
79
83
|
plt.legend()
|
80
84
|
|
81
|
-
save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"]
|
82
|
-
|
83
|
-
return _output_path
|
85
|
+
return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
|
86
|
+
b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
|
84
87
|
|
85
88
|
|
86
89
|
if __name__ == '__main__':
|
@@ -30,7 +30,7 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
|
|
30
30
|
output_path: <str> 图片输出路径。
|
31
31
|
以上两个只需指定一个即可,同时指定时以后者为准。
|
32
32
|
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
33
|
-
|
33
|
+
若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
|
34
34
|
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
35
35
|
replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
|
36
36
|
建议使用 np.nan 或 0,默认值为 0。
|
@@ -51,6 +51,10 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
|
|
51
51
|
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
52
52
|
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
53
53
|
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
54
|
+
b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
|
55
|
+
默认为 False
|
56
|
+
b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
|
57
|
+
默认为 True
|
54
58
|
|
55
59
|
返回值:
|
56
60
|
当 b_return_cfm 为 True 时,返回值可能为一个包含 (图像路径, 混淆矩阵数据) 的元组。
|
@@ -59,6 +63,8 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
|
|
59
63
|
"dpi": 200,
|
60
64
|
"suffix": ".png",
|
61
65
|
"b_generate_record": False,
|
66
|
+
"b_show_plot": False,
|
67
|
+
"b_bgr_image": True,
|
62
68
|
"normalize": None, # "true", "pred", "all",
|
63
69
|
"b_return_cfm": False, # 是否输出混淆矩阵
|
64
70
|
}
|
@@ -105,12 +111,13 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
|
|
105
111
|
plt.ylabel(f'{gt_name}')
|
106
112
|
plt.title(f'{title}')
|
107
113
|
|
108
|
-
save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"]
|
114
|
+
res = save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
|
115
|
+
b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
|
109
116
|
|
110
117
|
if paras["b_return_cfm"]:
|
111
|
-
return
|
118
|
+
return res, cfm
|
112
119
|
else:
|
113
|
-
return
|
120
|
+
return res
|
114
121
|
|
115
122
|
|
116
123
|
if __name__ == '__main__':
|