kevin-toolbox-dev 1.4.7__py3-none-any.whl → 1.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kevin_toolbox/__init__.py +2 -2
- kevin_toolbox/{developing → computer_science/algorithm}/decorator/__init__.py +2 -1
- kevin_toolbox/computer_science/algorithm/decorator/retry.py +62 -0
- kevin_toolbox/computer_science/algorithm/registration/__init__.py +1 -0
- kevin_toolbox/computer_science/algorithm/registration/serializer_for_registry_execution.py +82 -0
- kevin_toolbox/data_flow/core/cache/cache_manager_for_iterator.py +1 -1
- kevin_toolbox/data_flow/file/json_/write_json.py +2 -1
- kevin_toolbox/env_info/check_version_and_update.py +0 -1
- kevin_toolbox/env_info/variable_/env_vars_parser.py +17 -2
- kevin_toolbox/nested_dict_list/copy_.py +4 -2
- kevin_toolbox/nested_dict_list/get_nodes.py +4 -2
- kevin_toolbox/nested_dict_list/serializer/variable.py +14 -2
- kevin_toolbox/nested_dict_list/serializer/write.py +2 -0
- kevin_toolbox/nested_dict_list/traverse.py +75 -21
- kevin_toolbox/nested_dict_list/value_parser/replace_identical_with_reference.py +1 -4
- kevin_toolbox/network/__init__.py +10 -0
- kevin_toolbox/network/download_file.py +120 -0
- kevin_toolbox/network/fetch_content.py +55 -0
- kevin_toolbox/network/fetch_metadata.py +64 -0
- kevin_toolbox/network/get_response.py +50 -0
- kevin_toolbox/network/variable.py +6 -0
- kevin_toolbox/patches/for_logging/build_logger.py +1 -1
- kevin_toolbox/patches/for_matplotlib/color/convert_format.py +0 -2
- kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +45 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +63 -22
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +67 -20
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +66 -17
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_from_record.py +21 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +59 -19
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +61 -12
- kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +57 -14
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +3 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/get_output_path.py +15 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +12 -0
- kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_record.py +34 -0
- kevin_toolbox/patches/for_matplotlib/variable.py +20 -0
- kevin_toolbox/patches/for_numpy/linalg/softmax.py +4 -1
- kevin_toolbox_dev-1.4.9.dist-info/METADATA +75 -0
- {kevin_toolbox_dev-1.4.7.dist-info → kevin_toolbox_dev-1.4.9.dist-info}/RECORD +42 -28
- kevin_toolbox_dev-1.4.7.dist-info/METADATA +0 -69
- /kevin_toolbox/{developing → computer_science/algorithm}/decorator/restore_original_work_path.py +0 -0
- {kevin_toolbox_dev-1.4.7.dist-info → kevin_toolbox_dev-1.4.9.dist-info}/WHEEL +0 -0
- {kevin_toolbox_dev-1.4.7.dist-info → kevin_toolbox_dev-1.4.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,55 @@
|
|
1
|
+
import warnings
|
2
|
+
from kevin_toolbox.network import get_response
|
3
|
+
|
4
|
+
|
5
|
+
def fetch_content(url=None, response=None, decoding=None, chunk_size=None, **kwargs):
|
6
|
+
"""
|
7
|
+
从 URL/response 中获取内容
|
8
|
+
|
9
|
+
参数:
|
10
|
+
url: <str> 请求的 URL 地址。
|
11
|
+
response: 响应。
|
12
|
+
以上两个参数只需要指定其一即可,建议使用后者。
|
13
|
+
decoding: <str> 响应内容的解码方式
|
14
|
+
默认为 None,返回原始的字节流。
|
15
|
+
一些常用的可选值:
|
16
|
+
"utf-8"
|
17
|
+
chunk_size: <int> 采用分块方式读取内容时,块的大小
|
18
|
+
默认为 None,此时不使用分块读取,而直接读取所有内容
|
19
|
+
注意!当 chunk_size 非 None 时,decoding 将失效
|
20
|
+
|
21
|
+
返回:
|
22
|
+
当 chunk_size 为 None 时:
|
23
|
+
str: 请求成功后的响应内容。如果 decoding 为 None,则返回 bytes 类型数据。
|
24
|
+
当 chunk_size 非 None 时:
|
25
|
+
读取内容的生成器。
|
26
|
+
"""
|
27
|
+
assert url is not None or response is not None
|
28
|
+
if url is not None:
|
29
|
+
response = response or get_response(url=url, **kwargs)
|
30
|
+
assert response is not None
|
31
|
+
|
32
|
+
if chunk_size is None:
|
33
|
+
content = response.data
|
34
|
+
if decoding:
|
35
|
+
content = content.decode(decoding)
|
36
|
+
return content
|
37
|
+
else:
|
38
|
+
if decoding:
|
39
|
+
warnings.warn(f'当 chunk_size 非 None 时,decoding 参数将失效。')
|
40
|
+
return __generator(response, chunk_size)
|
41
|
+
|
42
|
+
|
43
|
+
def __generator(response, chunk_size):
|
44
|
+
while True:
|
45
|
+
chunk = response.read(chunk_size)
|
46
|
+
if not chunk:
|
47
|
+
break
|
48
|
+
yield chunk
|
49
|
+
|
50
|
+
|
51
|
+
if __name__ == "__main__":
|
52
|
+
url_ = "https://i.pinimg.com/736x/28/6a/b1/286ab1eb816dc59a1c72374c75645d80.jpg" # "https://www.google.com/"
|
53
|
+
print(len(fetch_content(url=url_, decoding=None)))
|
54
|
+
for i, j in enumerate(fetch_content(url=url_, chunk_size=50000)):
|
55
|
+
print(i, len(j))
|
@@ -0,0 +1,64 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
import time
|
4
|
+
import mimetypes
|
5
|
+
from urllib.parse import quote
|
6
|
+
from urllib.parse import urlsplit, unquote
|
7
|
+
from kevin_toolbox.network import get_response
|
8
|
+
|
9
|
+
|
10
|
+
def fetch_metadata(url=None, response=None, default_suffix=".bin", default_name=None, **kwargs):
|
11
|
+
"""
|
12
|
+
从 URL/response 中获取文件名、后缀(扩展名)、大小等元信息
|
13
|
+
|
14
|
+
参数:
|
15
|
+
url: <str> 请求的 URL 地址。
|
16
|
+
response: 响应。
|
17
|
+
以上两个参数建议同时指定。
|
18
|
+
default_suffix: <str> 默认后缀。
|
19
|
+
default_name: <str> 默认文件名。
|
20
|
+
默认为 None,表示使用当前时间戳作为默认文件名。
|
21
|
+
只有从 URL/response 中无法获取出 suffix 和 name 时才会使用上面的默认值作为填充
|
22
|
+
|
23
|
+
返回:
|
24
|
+
dict with keys ['content_length', 'content_type', 'content_disp', 'suffix', 'name']
|
25
|
+
"""
|
26
|
+
assert url is not None or response is not None
|
27
|
+
if url is not None:
|
28
|
+
response = response or get_response(url=url, **kwargs)
|
29
|
+
assert response is not None
|
30
|
+
default_name = f'{time.time()}' if default_name is None else default_name
|
31
|
+
|
32
|
+
metadata_s = {"content_length": None, "content_type": None, "content_disp": None, "suffix": None, "name": None}
|
33
|
+
name, suffix = None, None
|
34
|
+
# 尝试直接从url中获取文件名和后缀
|
35
|
+
if url is not None:
|
36
|
+
url = quote(url, safe='/:?=&')
|
37
|
+
basename = unquote(os.path.basename(urlsplit(url).path))
|
38
|
+
name, suffix = os.path.splitext(basename)
|
39
|
+
# 尝试从更加可信的响应头中获取文件名和后缀
|
40
|
+
content_length = response.headers.get("Content-Length", None)
|
41
|
+
metadata_s["content_length"] = int(content_length) if content_length and content_length.isdigit() else None
|
42
|
+
#
|
43
|
+
content_type = response.headers.get("Content-Type", None)
|
44
|
+
metadata_s["content_type"] = content_type
|
45
|
+
if content_type:
|
46
|
+
suffix = mimetypes.guess_extension(content_type.split(";")[0].strip()) or suffix
|
47
|
+
#
|
48
|
+
content_disp = response.headers.get("Content-Disposition", None)
|
49
|
+
metadata_s["content_disp"] = content_disp
|
50
|
+
if content_disp:
|
51
|
+
temp_ls = re.findall('filename="([^"]+)"', content_disp)
|
52
|
+
if temp_ls:
|
53
|
+
name, temp = os.path.splitext(temp_ls[0])
|
54
|
+
suffix = temp or suffix
|
55
|
+
metadata_s["name"] = name or default_name
|
56
|
+
metadata_s["suffix"] = suffix or default_suffix
|
57
|
+
|
58
|
+
return metadata_s
|
59
|
+
|
60
|
+
|
61
|
+
# 示例用法
|
62
|
+
if __name__ == "__main__":
|
63
|
+
url_ = "https://i.pinimg.com/736x/28/6a/b1/286ab1eb816dc59a1c72374c75645d80.jpg" # "https://www.google.com/"
|
64
|
+
print(fetch_metadata(url=url_))
|
@@ -0,0 +1,50 @@
|
|
1
|
+
import urllib3
|
2
|
+
from urllib.parse import quote
|
3
|
+
from kevin_toolbox.computer_science.algorithm.decorator import retry
|
4
|
+
from kevin_toolbox.network.variable import DEFAULT_HEADERS
|
5
|
+
|
6
|
+
# 全局 PoolManager 实例,设置 cert_reqs='CERT_NONE' 可关闭 SSL 证书验证
|
7
|
+
http = urllib3.PoolManager(cert_reqs='CERT_NONE')
|
8
|
+
|
9
|
+
|
10
|
+
def get_response(url, data=None, headers=None, method=None, retries=3, delay=0.5, b_verbose=False, stream=True,
|
11
|
+
**kwargs):
|
12
|
+
"""
|
13
|
+
获取 url 的响应
|
14
|
+
|
15
|
+
参数:
|
16
|
+
url: <str> 请求的 URL 地址。
|
17
|
+
data: <bytes, optional> 请求发送的数据,如果需要传递数据,必须是字节类型。
|
18
|
+
headers: <dict> 请求头字典。
|
19
|
+
默认为 DEFAULT_HEADERS。
|
20
|
+
method: <str> HTTP 请求方法,如 "GET", "POST" 等。
|
21
|
+
retries: <int> 重试次数
|
22
|
+
默认重试3次
|
23
|
+
delay: <int/float> 每次重试前等待的秒数。
|
24
|
+
默认0.5秒
|
25
|
+
b_verbose: <boolean> 进行多次重试时是否打印详细日志信息。
|
26
|
+
默认为 False。
|
27
|
+
|
28
|
+
返回:
|
29
|
+
响应。 urllib3.response.HTTPResponse object
|
30
|
+
"""
|
31
|
+
headers = headers or DEFAULT_HEADERS
|
32
|
+
|
33
|
+
url = quote(url, safe='/:?=&')
|
34
|
+
worker = retry(retries=retries, delay=delay, logger="default" if b_verbose else None)(func=__worker)
|
35
|
+
response = worker(url, data, headers, method, stream)
|
36
|
+
return response
|
37
|
+
|
38
|
+
|
39
|
+
def __worker(url, data, headers, method, stream):
|
40
|
+
method = method if method is not None else "GET"
|
41
|
+
response = http.request(method, url, body=data, headers=headers, preload_content=not stream)
|
42
|
+
if response.status >= 400:
|
43
|
+
raise Exception(f"HTTP 请求失败,状态码:{response.status}")
|
44
|
+
return response
|
45
|
+
|
46
|
+
|
47
|
+
if __name__ == "__main__":
|
48
|
+
url_ = "https://i.pinimg.com/736x/28/6a/b1/286ab1eb816dc59a1c72374c75645d80.jpg" # "https://www.google.com/"
|
49
|
+
a = get_response(url=url_, b_verbose=True)
|
50
|
+
print(a)
|
@@ -49,7 +49,7 @@ def build_logger(name, handler_ls, level=logging.DEBUG,
|
|
49
49
|
# 输出到控制台
|
50
50
|
handler = logging.StreamHandler()
|
51
51
|
else:
|
52
|
-
raise ValueError(f'unexpected target {target}')
|
52
|
+
raise ValueError(f'unexpected target {details["target"]}')
|
53
53
|
handler.setLevel(details.get("level", level))
|
54
54
|
handler.setFormatter(logging.Formatter(details.get("formatter", formatter)))
|
55
55
|
# 添加到logger中
|
@@ -6,8 +6,6 @@ def hex_to_rgba(hex_color):
|
|
6
6
|
assert len(hex_color) in (6, 8), \
|
7
7
|
f'hex_color should be 6 or 8 characters long (not including #). but got {len(hex_color)}'
|
8
8
|
res = list(int(hex_color[i * 2:i * 2 + 2], 16) for i in range(len(hex_color) // 2))
|
9
|
-
if len(res) not in (3, 4):
|
10
|
-
breakpoint()
|
11
9
|
if len(res) == 4:
|
12
10
|
res[3] /= 255
|
13
11
|
return tuple(res)
|
@@ -1,6 +1,51 @@
|
|
1
|
+
# 在非 windows 系统下,尝试自动下载中文字体,并尝试自动设置字体
|
2
|
+
import sys
|
3
|
+
|
4
|
+
if not sys.platform.startswith("win"):
|
5
|
+
import os
|
6
|
+
from kevin_toolbox.env_info.variable_ import env_vars_parser
|
7
|
+
|
8
|
+
font_setting_s = dict(
|
9
|
+
b_auto_download=True,
|
10
|
+
download_url="https://drive.usercontent.google.com/download?id=1wd-a4-AwAXkr7mHmB9BAIcGpAcqMrbqg&export=download&authuser=0",
|
11
|
+
save_path="~/.kvt_data/fonts/SimHei.ttf"
|
12
|
+
)
|
13
|
+
font_setting_s.update()
|
14
|
+
for k, v in list(font_setting_s.items()):
|
15
|
+
font_setting_s[k] = env_vars_parser.parse(
|
16
|
+
text=f"${{KVT_PATCHES:for_matplotlib:common_charts:font_settings:for_non-windows-platform:{k}}}",
|
17
|
+
default=v
|
18
|
+
)
|
19
|
+
save_path = os.path.expanduser(font_setting_s["save_path"])
|
20
|
+
|
21
|
+
if font_setting_s["b_auto_download"] and not os.path.isfile(save_path):
|
22
|
+
from kevin_toolbox.network import download_file
|
23
|
+
|
24
|
+
print(f'检测到当前系统非 Windows 系统,尝试自动下载中文字体...')
|
25
|
+
download_file(
|
26
|
+
output_dir=os.path.dirname(save_path),
|
27
|
+
file_name=os.path.basename(save_path),
|
28
|
+
url=font_setting_s["download_url"],
|
29
|
+
b_display_progress=True
|
30
|
+
)
|
31
|
+
|
32
|
+
if os.path.isfile(save_path):
|
33
|
+
import matplotlib.font_manager as fm
|
34
|
+
import matplotlib.pyplot as plt
|
35
|
+
|
36
|
+
# 注册字体
|
37
|
+
fm.fontManager.addfont(save_path)
|
38
|
+
# 获取字体名称
|
39
|
+
font_name = fm.FontProperties(fname=save_path).get_name()
|
40
|
+
|
41
|
+
# 全局设置默认字体
|
42
|
+
plt.rcParams['font.family'] = font_name
|
43
|
+
plt.rcParams['axes.unicode_minus'] = False # 解决负号 '-' 显示为方块的问题
|
44
|
+
|
1
45
|
from .plot_lines import plot_lines
|
2
46
|
from .plot_scatters import plot_scatters
|
3
47
|
from .plot_distribution import plot_distribution
|
4
48
|
from .plot_bars import plot_bars
|
5
49
|
from .plot_scatters_matrix import plot_scatters_matrix
|
6
50
|
from .plot_confusion_matrix import plot_confusion_matrix
|
51
|
+
from .plot_from_record import plot_from_record
|
@@ -1,20 +1,62 @@
|
|
1
|
-
import os
|
2
|
-
import copy
|
3
|
-
from kevin_toolbox.computer_science.algorithm import for_seq
|
4
1
|
import matplotlib.pyplot as plt
|
5
|
-
from kevin_toolbox.
|
2
|
+
from kevin_toolbox.computer_science.algorithm import for_seq
|
3
|
+
from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
|
4
|
+
from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
|
5
|
+
|
6
|
+
__name = ":common_charts:plot_bars"
|
7
|
+
|
6
8
|
|
7
|
-
|
8
|
-
|
9
|
-
|
9
|
+
@COMMON_CHARTS.register(name=__name)
|
10
|
+
def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs):
|
11
|
+
"""
|
12
|
+
绘制条形图
|
10
13
|
|
14
|
+
参数:
|
15
|
+
data_s: <dict> 数据。
|
16
|
+
形如 {<data_name>: <data list>, ...} 的字典
|
17
|
+
title: <str> 绘图标题。
|
18
|
+
x_name: <str> 以哪个 data_name 作为 x 轴。
|
19
|
+
其余数据视为需要被绘制的数据点。
|
20
|
+
例子: data_s={"step":[...], "acc_top1":[...], "acc_top3":[...]}
|
21
|
+
当 x_name="step" 时,将会以 step 为 x 轴绘制 acc_top1 和 acc_top3 的 bar 图。
|
22
|
+
x_label: <str> x 轴的标签名称。
|
23
|
+
默认与指定的 x_name 相同。
|
24
|
+
y_label: <str> y 轴的标签名称。
|
25
|
+
默认为 "value"。
|
26
|
+
output_dir: <str or None> 图片输出目录。
|
27
|
+
output_path: <str or None> 图片输出路径。
|
28
|
+
以上两个只需指定一个即可,同时指定时以后者为准。
|
29
|
+
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
30
|
+
若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
|
31
|
+
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
11
32
|
|
12
|
-
|
13
|
-
|
33
|
+
其他可选参数:
|
34
|
+
dpi: <int> 保存图像的分辨率。
|
35
|
+
默认为 200。
|
36
|
+
suffix: <str> 图片保存后缀。
|
37
|
+
目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
|
38
|
+
b_generate_record: <boolean> 是否保存函数参数为档案。
|
39
|
+
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
40
|
+
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
41
|
+
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
42
|
+
|
43
|
+
返回值:
|
44
|
+
若 output_dir 非 None,则返回图像保存的文件路径。
|
45
|
+
"""
|
46
|
+
assert x_name in data_s and len(data_s) >= 2
|
14
47
|
paras = {
|
15
|
-
"
|
48
|
+
"x_label": f'{x_name}',
|
49
|
+
"y_label": "value",
|
50
|
+
"dpi": 200,
|
51
|
+
"suffix": ".png",
|
52
|
+
"b_generate_record": False
|
16
53
|
}
|
17
54
|
paras.update(kwargs)
|
55
|
+
#
|
56
|
+
_output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
|
57
|
+
save_record(_func=plot_bars, _name=__name, _output_path=_output_path if paras["b_generate_record"] else None,
|
58
|
+
**paras)
|
59
|
+
data_s = data_s.copy()
|
18
60
|
|
19
61
|
plt.clf()
|
20
62
|
#
|
@@ -26,29 +68,28 @@ def plot_bars(data_s, title, x_name, y_label=None, output_dir=None, **kwargs):
|
|
26
68
|
else:
|
27
69
|
plt.bar([j + 0.1 for j in range(len(x_all_ls))], y_ls, width=0.2, align='center', label=k)
|
28
70
|
|
29
|
-
plt.xlabel(
|
30
|
-
plt.ylabel(
|
71
|
+
plt.xlabel(paras["x_label"])
|
72
|
+
plt.ylabel(paras["y_label"])
|
31
73
|
temp = for_seq.flatten_list([list(i) for i in data_s.values()])
|
32
74
|
y_min, y_max = min(temp), max(temp)
|
33
75
|
plt.ylim(max(min(y_min, 0), y_min - (y_max - y_min) * 0.2), y_max + (y_max - y_min) * 0.1)
|
34
|
-
plt.xticks(list(range(len(x_all_ls))), labels=x_all_ls)
|
76
|
+
plt.xticks(list(range(len(x_all_ls))), labels=x_all_ls)
|
35
77
|
plt.title(f'{title}')
|
36
78
|
# 显示图例
|
37
79
|
plt.legend()
|
38
80
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
else:
|
43
|
-
os.makedirs(output_dir, exist_ok=True)
|
44
|
-
output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
|
45
|
-
plt.savefig(output_path, dpi=paras["dpi"])
|
46
|
-
return output_path
|
81
|
+
save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
|
82
|
+
|
83
|
+
return _output_path
|
47
84
|
|
48
85
|
|
49
86
|
if __name__ == '__main__':
|
87
|
+
import os
|
88
|
+
|
50
89
|
plot_bars(data_s={
|
51
90
|
'a': [1.5, 2, 3, 4, 5],
|
52
91
|
'b': [5, 4, 3, 2, 1],
|
53
92
|
'c': [1, 2, 3, 4, 5]},
|
54
|
-
title='
|
93
|
+
title='test_plot_bars', x_name='a', output_dir=os.path.join(os.path.dirname(__file__), "temp"),
|
94
|
+
b_generate_record=True
|
95
|
+
)
|
@@ -1,26 +1,74 @@
|
|
1
|
-
import os
|
2
1
|
import numpy as np
|
3
2
|
from sklearn.metrics import confusion_matrix
|
4
3
|
import matplotlib.pyplot as plt
|
5
4
|
import seaborn as sns
|
6
|
-
from kevin_toolbox.patches.
|
5
|
+
from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
|
6
|
+
from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
|
7
7
|
|
8
|
+
__name = ":common_charts:plot_confusion_matrix"
|
8
9
|
|
9
|
-
|
10
|
+
|
11
|
+
@COMMON_CHARTS.register(name=__name)
|
12
|
+
def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None, output_dir=None, output_path=None,
|
10
13
|
replace_zero_division_with=0, **kwargs):
|
11
14
|
"""
|
12
15
|
计算并绘制混淆矩阵
|
13
16
|
|
14
17
|
参数:
|
15
|
-
|
16
|
-
|
18
|
+
data_s: <dict> 数据。
|
19
|
+
形如 {<data_name>: <data list>, ...} 的字典
|
20
|
+
title: <str> 绘图标题,同时用于保存图片的文件名。
|
21
|
+
gt_name: <str> 在 data_s 中表示真实标签数据的键名。
|
22
|
+
pd_name: <str> 在 data_s 中表示预测标签数据的键名。
|
23
|
+
label_to_value_s: <dict> 标签-取值映射字典。
|
24
|
+
如 {"cat": 0, "dog": 1})。
|
25
|
+
output_dir: <str or None>
|
26
|
+
图像保存的输出目录。如果同时指定了 output_path,则以 output_path 为准。
|
27
|
+
若 output_dir 和 output_path 均未指定,则图像将直接通过 plt.show() 显示而不会保存到文件。
|
28
|
+
|
29
|
+
output_dir: <str> 图片输出目录。
|
30
|
+
output_path: <str> 图片输出路径。
|
31
|
+
以上两个只需指定一个即可,同时指定时以后者为准。
|
32
|
+
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
33
|
+
若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
|
34
|
+
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
35
|
+
replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
|
36
|
+
建议使用 np.nan 或 0,默认值为 0。
|
37
|
+
|
38
|
+
其他可选参数:
|
39
|
+
dpi: <int> 图像保存的分辨率。
|
40
|
+
suffix: <str> 图片保存后缀。
|
41
|
+
目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
|
42
|
+
normalize: <str or None> 指定归一化方式。
|
43
|
+
可选值包括:
|
44
|
+
"true"(按真实标签归一化)
|
45
|
+
"pred"(按预测标签归一化)
|
46
|
+
"all"(整体归一化)
|
47
|
+
默认为 None 表示不归一化。
|
48
|
+
b_return_cfm: <bool> 是否在返回值中包含计算得到的混淆矩阵数据。
|
49
|
+
默认为 False。
|
50
|
+
b_generate_record: <boolean> 是否保存函数参数为档案。
|
51
|
+
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
52
|
+
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
53
|
+
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
54
|
+
|
55
|
+
返回值:
|
56
|
+
当 b_return_cfm 为 True 时,返回值可能为一个包含 (图像路径, 混淆矩阵数据) 的元组。
|
17
57
|
"""
|
18
58
|
paras = {
|
19
59
|
"dpi": 200,
|
60
|
+
"suffix": ".png",
|
61
|
+
"b_generate_record": False,
|
20
62
|
"normalize": None, # "true", "pred", "all",
|
21
63
|
"b_return_cfm": False, # 是否输出混淆矩阵
|
22
64
|
}
|
23
65
|
paras.update(kwargs)
|
66
|
+
#
|
67
|
+
_output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
|
68
|
+
save_record(_func=plot_confusion_matrix, _name=__name,
|
69
|
+
_output_path=_output_path if paras["b_generate_record"] else None,
|
70
|
+
**paras)
|
71
|
+
data_s = data_s.copy()
|
24
72
|
|
25
73
|
value_set = set(data_s[gt_name]).union(set(data_s[pd_name]))
|
26
74
|
if label_to_value_s is None:
|
@@ -57,28 +105,27 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
|
|
57
105
|
plt.ylabel(f'{gt_name}')
|
58
106
|
plt.title(f'{title}')
|
59
107
|
|
60
|
-
|
61
|
-
plt.show()
|
62
|
-
output_path = None
|
63
|
-
else:
|
64
|
-
os.makedirs(output_dir, exist_ok=True)
|
65
|
-
output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
|
66
|
-
plt.savefig(output_path, dpi=paras["dpi"])
|
108
|
+
save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
|
67
109
|
|
68
110
|
if paras["b_return_cfm"]:
|
69
|
-
return
|
111
|
+
return _output_path, cfm
|
70
112
|
else:
|
71
|
-
return
|
113
|
+
return _output_path
|
72
114
|
|
73
115
|
|
74
116
|
if __name__ == '__main__':
|
117
|
+
import os
|
118
|
+
|
75
119
|
# 示例真实标签和预测标签
|
76
120
|
y_true = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 5])
|
77
121
|
y_pred = np.array([0, 2, 1, 0, 2, 1, 0, 1, 1, 5])
|
78
122
|
|
79
|
-
plot_confusion_matrix(
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
123
|
+
plot_confusion_matrix(
|
124
|
+
data_s={'a': y_true, 'b': y_pred},
|
125
|
+
title='test_plot_confusion_matrix', gt_name='a', pd_name='b',
|
126
|
+
label_to_value_s={"A": 5, "B": 0, "C": 1, "D": 2, "E": 3},
|
127
|
+
output_dir=os.path.join(os.path.dirname(__file__), "temp"),
|
128
|
+
replace_zero_division_with=-1,
|
129
|
+
normalize="all",
|
130
|
+
b_generate_record=True
|
131
|
+
)
|
@@ -1,15 +1,66 @@
|
|
1
|
-
import os
|
2
1
|
import math
|
3
2
|
import matplotlib.pyplot as plt
|
4
3
|
import numpy as np
|
5
|
-
from kevin_toolbox.patches.
|
4
|
+
from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
|
5
|
+
from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
|
6
6
|
|
7
|
+
__name = ":common_charts:plot_distribution"
|
7
8
|
|
8
|
-
|
9
|
+
|
10
|
+
@COMMON_CHARTS.register(name=__name)
|
11
|
+
def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist", output_dir=None, output_path=None,
|
12
|
+
**kwargs):
|
13
|
+
"""
|
14
|
+
概率分布图
|
15
|
+
支持以下几种绘图类型:
|
16
|
+
1. 数字数据,绘制概率分布图: type_ 参数为 "hist" 或 "histogram" 时。
|
17
|
+
2. 字符串数据,绘制概率直方图:type_ 参数为 "category" 或 "cate" 时。
|
18
|
+
|
19
|
+
参数:
|
20
|
+
data_s: <dict> 数据。
|
21
|
+
形如 {<data_name>: <data list>, ...} 的字典,
|
22
|
+
title: <str> 绘图标题,同时用于保存图片的文件名。
|
23
|
+
x_name: <str> 以哪个 data_name 作为待绘制数据。
|
24
|
+
x_name_ls: <list or tuple> 以多个 data_name 对应的多组数据在同一图中绘制多个概率分布图。
|
25
|
+
type_: <str> 指定绘图类型。
|
26
|
+
支持的取值有:
|
27
|
+
- "hist" 或 "histogram": 需要 <data list> 为数值数据,将绘制概率分布图。
|
28
|
+
需要进一步指定 steps 步长参数,
|
29
|
+
或者 min、max、bin_nums 参数。
|
30
|
+
- "category" 或 "cate": 需要 <data list> 为字符串数据,将绘制概率直方图。
|
31
|
+
output_dir: <str> 图片输出目录。
|
32
|
+
output_path: <str> 图片输出路径。
|
33
|
+
以上两个只需指定一个即可,同时指定时以后者为准。
|
34
|
+
当只有 output_dir 被指定时,将会以 title 作为图片名。
|
35
|
+
若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
|
36
|
+
在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
|
37
|
+
|
38
|
+
其他可选参数:
|
39
|
+
dpi: <int> 图像保存的分辨率。
|
40
|
+
suffix: <str> 图片保存后缀。
|
41
|
+
目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
|
42
|
+
b_generate_record: <boolean> 是否保存函数参数为档案。
|
43
|
+
默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
|
44
|
+
后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
|
45
|
+
该参数仅在 output_dir 和 output_path 非 None 时起效。
|
46
|
+
|
47
|
+
返回值:
|
48
|
+
<str> 图像保存的完整文件路径。如果 output_dir 或 output_path 被指定,
|
49
|
+
则图像会保存到对应位置并返回保存路径;否则可能直接显示图像,
|
50
|
+
返回值依赖于 save_plot 函数的具体实现。
|
51
|
+
"""
|
9
52
|
paras = {
|
10
|
-
"dpi": 200
|
53
|
+
"dpi": 200,
|
54
|
+
"suffix": ".png",
|
55
|
+
"b_generate_record": False
|
11
56
|
}
|
12
57
|
paras.update(kwargs)
|
58
|
+
#
|
59
|
+
_output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
|
60
|
+
save_record(_func=plot_distribution, _name=__name,
|
61
|
+
_output_path=_output_path if paras["b_generate_record"] else None,
|
62
|
+
**paras)
|
63
|
+
data_s = data_s.copy()
|
13
64
|
if x_name is not None:
|
14
65
|
x_name_ls = [x_name, ]
|
15
66
|
assert isinstance(x_name_ls, (list, tuple)) and len(x_name_ls) > 0
|
@@ -47,19 +98,17 @@ def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist",
|
|
47
98
|
# 显示图例
|
48
99
|
plt.legend()
|
49
100
|
|
50
|
-
|
51
|
-
|
52
|
-
return None
|
53
|
-
else:
|
54
|
-
os.makedirs(output_dir, exist_ok=True)
|
55
|
-
output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
|
56
|
-
plt.savefig(output_path, dpi=paras["dpi"])
|
57
|
-
return output_path
|
101
|
+
save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
|
102
|
+
return _output_path
|
58
103
|
|
59
104
|
|
60
105
|
if __name__ == '__main__':
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
106
|
+
import os
|
107
|
+
|
108
|
+
plot_distribution(
|
109
|
+
data_s={
|
110
|
+
'a': [1, 2, 3, 4, 5, 3, 2, 1],
|
111
|
+
'c': [1, 2, 3, 4, 5, 0, 0, 0]},
|
112
|
+
title='test_plot_distribution', x_name_ls=['a', 'c'], type_="category",
|
113
|
+
output_dir=os.path.join(os.path.dirname(__file__), "temp")
|
114
|
+
)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
|
2
|
+
|
3
|
+
|
4
|
+
def plot_from_record(input_path, **kwargs):
|
5
|
+
"""
|
6
|
+
从 record 中恢复并绘制图像
|
7
|
+
支持通过 **kwargs 对其中部分参数进行覆盖
|
8
|
+
"""
|
9
|
+
from kevin_toolbox.computer_science.algorithm.registration import Serializer_for_Registry_Execution
|
10
|
+
|
11
|
+
serializer = Serializer_for_Registry_Execution()
|
12
|
+
serializer.load(input_path)
|
13
|
+
serializer.record_s["kwargs"].update(kwargs)
|
14
|
+
return serializer.recover()()
|
15
|
+
|
16
|
+
|
17
|
+
if __name__ == '__main__':
|
18
|
+
import os
|
19
|
+
|
20
|
+
plot_from_record(input_path=os.path.join(os.path.dirname(__file__), "temp/好-吧.png.record.tar"),
|
21
|
+
output_dir=os.path.join(os.path.dirname(__file__), "temp/recover"))
|