kevin-toolbox-dev 1.4.7__py3-none-any.whl → 1.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/{developing → computer_science/algorithm}/decorator/__init__.py +2 -1
  3. kevin_toolbox/computer_science/algorithm/decorator/retry.py +62 -0
  4. kevin_toolbox/computer_science/algorithm/registration/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/registration/serializer_for_registry_execution.py +82 -0
  6. kevin_toolbox/data_flow/core/cache/cache_manager_for_iterator.py +1 -1
  7. kevin_toolbox/data_flow/file/json_/write_json.py +2 -1
  8. kevin_toolbox/env_info/check_version_and_update.py +0 -1
  9. kevin_toolbox/env_info/variable_/env_vars_parser.py +17 -2
  10. kevin_toolbox/nested_dict_list/copy_.py +4 -2
  11. kevin_toolbox/nested_dict_list/get_nodes.py +4 -2
  12. kevin_toolbox/nested_dict_list/serializer/variable.py +14 -2
  13. kevin_toolbox/nested_dict_list/serializer/write.py +2 -0
  14. kevin_toolbox/nested_dict_list/traverse.py +75 -21
  15. kevin_toolbox/nested_dict_list/value_parser/replace_identical_with_reference.py +1 -4
  16. kevin_toolbox/network/__init__.py +10 -0
  17. kevin_toolbox/network/download_file.py +120 -0
  18. kevin_toolbox/network/fetch_content.py +55 -0
  19. kevin_toolbox/network/fetch_metadata.py +64 -0
  20. kevin_toolbox/network/get_response.py +50 -0
  21. kevin_toolbox/network/variable.py +6 -0
  22. kevin_toolbox/patches/for_logging/build_logger.py +1 -1
  23. kevin_toolbox/patches/for_matplotlib/color/convert_format.py +0 -2
  24. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +45 -0
  25. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +63 -22
  26. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +67 -20
  27. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +66 -17
  28. kevin_toolbox/patches/for_matplotlib/common_charts/plot_from_record.py +21 -0
  29. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +59 -19
  30. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +61 -12
  31. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +57 -14
  32. kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +3 -0
  33. kevin_toolbox/patches/for_matplotlib/common_charts/utils/get_output_path.py +15 -0
  34. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +12 -0
  35. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_record.py +34 -0
  36. kevin_toolbox/patches/for_matplotlib/variable.py +20 -0
  37. kevin_toolbox/patches/for_numpy/linalg/softmax.py +4 -1
  38. kevin_toolbox_dev-1.4.9.dist-info/METADATA +75 -0
  39. {kevin_toolbox_dev-1.4.7.dist-info → kevin_toolbox_dev-1.4.9.dist-info}/RECORD +42 -28
  40. kevin_toolbox_dev-1.4.7.dist-info/METADATA +0 -69
  41. /kevin_toolbox/{developing → computer_science/algorithm}/decorator/restore_original_work_path.py +0 -0
  42. {kevin_toolbox_dev-1.4.7.dist-info → kevin_toolbox_dev-1.4.9.dist-info}/WHEEL +0 -0
  43. {kevin_toolbox_dev-1.4.7.dist-info → kevin_toolbox_dev-1.4.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,55 @@
1
+ import warnings
2
+ from kevin_toolbox.network import get_response
3
+
4
+
5
+ def fetch_content(url=None, response=None, decoding=None, chunk_size=None, **kwargs):
6
+ """
7
+ 从 URL/response 中获取内容
8
+
9
+ 参数:
10
+ url: <str> 请求的 URL 地址。
11
+ response: 响应。
12
+ 以上两个参数只需要指定其一即可,建议使用后者。
13
+ decoding: <str> 响应内容的解码方式
14
+ 默认为 None,返回原始的字节流。
15
+ 一些常用的可选值:
16
+ "utf-8"
17
+ chunk_size: <int> 采用分块方式读取内容时,块的大小
18
+ 默认为 None,此时不使用分块读取,而直接读取所有内容
19
+ 注意!当 chunk_size 非 None 时,decoding 将失效
20
+
21
+ 返回:
22
+ 当 chunk_size 为 None 时:
23
+ str: 请求成功后的响应内容。如果 decoding 为 None,则返回 bytes 类型数据。
24
+ 当 chunk_size 非 None 时:
25
+ 读取内容的生成器。
26
+ """
27
+ assert url is not None or response is not None
28
+ if url is not None:
29
+ response = response or get_response(url=url, **kwargs)
30
+ assert response is not None
31
+
32
+ if chunk_size is None:
33
+ content = response.data
34
+ if decoding:
35
+ content = content.decode(decoding)
36
+ return content
37
+ else:
38
+ if decoding:
39
+ warnings.warn(f'当 chunk_size 非 None 时,decoding 参数将失效。')
40
+ return __generator(response, chunk_size)
41
+
42
+
43
+ def __generator(response, chunk_size):
44
+ while True:
45
+ chunk = response.read(chunk_size)
46
+ if not chunk:
47
+ break
48
+ yield chunk
49
+
50
+
51
+ if __name__ == "__main__":
52
+ url_ = "https://i.pinimg.com/736x/28/6a/b1/286ab1eb816dc59a1c72374c75645d80.jpg" # "https://www.google.com/"
53
+ print(len(fetch_content(url=url_, decoding=None)))
54
+ for i, j in enumerate(fetch_content(url=url_, chunk_size=50000)):
55
+ print(i, len(j))
@@ -0,0 +1,64 @@
1
+ import os
2
+ import re
3
+ import time
4
+ import mimetypes
5
+ from urllib.parse import quote
6
+ from urllib.parse import urlsplit, unquote
7
+ from kevin_toolbox.network import get_response
8
+
9
+
10
+ def fetch_metadata(url=None, response=None, default_suffix=".bin", default_name=None, **kwargs):
11
+ """
12
+ 从 URL/response 中获取文件名、后缀(扩展名)、大小等元信息
13
+
14
+ 参数:
15
+ url: <str> 请求的 URL 地址。
16
+ response: 响应。
17
+ 以上两个参数建议同时指定。
18
+ default_suffix: <str> 默认后缀。
19
+ default_name: <str> 默认文件名。
20
+ 默认为 None,表示使用当前时间戳作为默认文件名。
21
+ 只有从 URL/response 中无法获取出 suffix 和 name 时才会使用上面的默认值作为填充
22
+
23
+ 返回:
24
+ dict with keys ['content_length', 'content_type', 'content_disp', 'suffix', 'name']
25
+ """
26
+ assert url is not None or response is not None
27
+ if url is not None:
28
+ response = response or get_response(url=url, **kwargs)
29
+ assert response is not None
30
+ default_name = f'{time.time()}' if default_name is None else default_name
31
+
32
+ metadata_s = {"content_length": None, "content_type": None, "content_disp": None, "suffix": None, "name": None}
33
+ name, suffix = None, None
34
+ # 尝试直接从url中获取文件名和后缀
35
+ if url is not None:
36
+ url = quote(url, safe='/:?=&')
37
+ basename = unquote(os.path.basename(urlsplit(url).path))
38
+ name, suffix = os.path.splitext(basename)
39
+ # 尝试从更加可信的响应头中获取文件名和后缀
40
+ content_length = response.headers.get("Content-Length", None)
41
+ metadata_s["content_length"] = int(content_length) if content_length and content_length.isdigit() else None
42
+ #
43
+ content_type = response.headers.get("Content-Type", None)
44
+ metadata_s["content_type"] = content_type
45
+ if content_type:
46
+ suffix = mimetypes.guess_extension(content_type.split(";")[0].strip()) or suffix
47
+ #
48
+ content_disp = response.headers.get("Content-Disposition", None)
49
+ metadata_s["content_disp"] = content_disp
50
+ if content_disp:
51
+ temp_ls = re.findall('filename="([^"]+)"', content_disp)
52
+ if temp_ls:
53
+ name, temp = os.path.splitext(temp_ls[0])
54
+ suffix = temp or suffix
55
+ metadata_s["name"] = name or default_name
56
+ metadata_s["suffix"] = suffix or default_suffix
57
+
58
+ return metadata_s
59
+
60
+
61
+ # 示例用法
62
+ if __name__ == "__main__":
63
+ url_ = "https://i.pinimg.com/736x/28/6a/b1/286ab1eb816dc59a1c72374c75645d80.jpg" # "https://www.google.com/"
64
+ print(fetch_metadata(url=url_))
@@ -0,0 +1,50 @@
1
+ import urllib3
2
+ from urllib.parse import quote
3
+ from kevin_toolbox.computer_science.algorithm.decorator import retry
4
+ from kevin_toolbox.network.variable import DEFAULT_HEADERS
5
+
6
+ # 全局 PoolManager 实例,设置 cert_reqs='CERT_NONE' 可关闭 SSL 证书验证
7
+ http = urllib3.PoolManager(cert_reqs='CERT_NONE')
8
+
9
+
10
+ def get_response(url, data=None, headers=None, method=None, retries=3, delay=0.5, b_verbose=False, stream=True,
11
+ **kwargs):
12
+ """
13
+ 获取 url 的响应
14
+
15
+ 参数:
16
+ url: <str> 请求的 URL 地址。
17
+ data: <bytes, optional> 请求发送的数据,如果需要传递数据,必须是字节类型。
18
+ headers: <dict> 请求头字典。
19
+ 默认为 DEFAULT_HEADERS。
20
+ method: <str> HTTP 请求方法,如 "GET", "POST" 等。
21
+ retries: <int> 重试次数
22
+ 默认重试3次
23
+ delay: <int/float> 每次重试前等待的秒数。
24
+ 默认0.5秒
25
+ b_verbose: <boolean> 进行多次重试时是否打印详细日志信息。
26
+ 默认为 False。
27
+
28
+ 返回:
29
+ 响应。 urllib3.response.HTTPResponse object
30
+ """
31
+ headers = headers or DEFAULT_HEADERS
32
+
33
+ url = quote(url, safe='/:?=&')
34
+ worker = retry(retries=retries, delay=delay, logger="default" if b_verbose else None)(func=__worker)
35
+ response = worker(url, data, headers, method, stream)
36
+ return response
37
+
38
+
39
+ def __worker(url, data, headers, method, stream):
40
+ method = method if method is not None else "GET"
41
+ response = http.request(method, url, body=data, headers=headers, preload_content=not stream)
42
+ if response.status >= 400:
43
+ raise Exception(f"HTTP 请求失败,状态码:{response.status}")
44
+ return response
45
+
46
+
47
+ if __name__ == "__main__":
48
+ url_ = "https://i.pinimg.com/736x/28/6a/b1/286ab1eb816dc59a1c72374c75645d80.jpg" # "https://www.google.com/"
49
+ a = get_response(url=url_, b_verbose=True)
50
+ print(a)
@@ -0,0 +1,6 @@
1
+ DEFAULT_HEADERS = {
2
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
3
+ "(KHTML, like Gecko) Chrome/70.0.3538.110 Safari/537.36",
4
+ "Upgrade-Insecure-Requests": "1",
5
+ # "Cookie": ''
6
+ }
@@ -49,7 +49,7 @@ def build_logger(name, handler_ls, level=logging.DEBUG,
49
49
  # 输出到控制台
50
50
  handler = logging.StreamHandler()
51
51
  else:
52
- raise ValueError(f'unexpected target {target}')
52
+ raise ValueError(f'unexpected target {details["target"]}')
53
53
  handler.setLevel(details.get("level", level))
54
54
  handler.setFormatter(logging.Formatter(details.get("formatter", formatter)))
55
55
  # 添加到logger中
@@ -6,8 +6,6 @@ def hex_to_rgba(hex_color):
6
6
  assert len(hex_color) in (6, 8), \
7
7
  f'hex_color should be 6 or 8 characters long (not including #). but got {len(hex_color)}'
8
8
  res = list(int(hex_color[i * 2:i * 2 + 2], 16) for i in range(len(hex_color) // 2))
9
- if len(res) not in (3, 4):
10
- breakpoint()
11
9
  if len(res) == 4:
12
10
  res[3] /= 255
13
11
  return tuple(res)
@@ -1,6 +1,51 @@
1
+ # 在非 windows 系统下,尝试自动下载中文字体,并尝试自动设置字体
2
+ import sys
3
+
4
+ if not sys.platform.startswith("win"):
5
+ import os
6
+ from kevin_toolbox.env_info.variable_ import env_vars_parser
7
+
8
+ font_setting_s = dict(
9
+ b_auto_download=True,
10
+ download_url="https://drive.usercontent.google.com/download?id=1wd-a4-AwAXkr7mHmB9BAIcGpAcqMrbqg&export=download&authuser=0",
11
+ save_path="~/.kvt_data/fonts/SimHei.ttf"
12
+ )
13
+ font_setting_s.update()
14
+ for k, v in list(font_setting_s.items()):
15
+ font_setting_s[k] = env_vars_parser.parse(
16
+ text=f"${{KVT_PATCHES:for_matplotlib:common_charts:font_settings:for_non-windows-platform:{k}}}",
17
+ default=v
18
+ )
19
+ save_path = os.path.expanduser(font_setting_s["save_path"])
20
+
21
+ if font_setting_s["b_auto_download"] and not os.path.isfile(save_path):
22
+ from kevin_toolbox.network import download_file
23
+
24
+ print(f'检测到当前系统非 Windows 系统,尝试自动下载中文字体...')
25
+ download_file(
26
+ output_dir=os.path.dirname(save_path),
27
+ file_name=os.path.basename(save_path),
28
+ url=font_setting_s["download_url"],
29
+ b_display_progress=True
30
+ )
31
+
32
+ if os.path.isfile(save_path):
33
+ import matplotlib.font_manager as fm
34
+ import matplotlib.pyplot as plt
35
+
36
+ # 注册字体
37
+ fm.fontManager.addfont(save_path)
38
+ # 获取字体名称
39
+ font_name = fm.FontProperties(fname=save_path).get_name()
40
+
41
+ # 全局设置默认字体
42
+ plt.rcParams['font.family'] = font_name
43
+ plt.rcParams['axes.unicode_minus'] = False # 解决负号 '-' 显示为方块的问题
44
+
1
45
  from .plot_lines import plot_lines
2
46
  from .plot_scatters import plot_scatters
3
47
  from .plot_distribution import plot_distribution
4
48
  from .plot_bars import plot_bars
5
49
  from .plot_scatters_matrix import plot_scatters_matrix
6
50
  from .plot_confusion_matrix import plot_confusion_matrix
51
+ from .plot_from_record import plot_from_record
@@ -1,20 +1,62 @@
1
- import os
2
- import copy
3
- from kevin_toolbox.computer_science.algorithm import for_seq
4
1
  import matplotlib.pyplot as plt
5
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
2
+ from kevin_toolbox.computer_science.algorithm import for_seq
3
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
4
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
5
+
6
+ __name = ":common_charts:plot_bars"
7
+
6
8
 
7
- # TODO 在 linux 系统下遇到中文时,尝试自动下载中文字体,并尝试自动设置字体
8
- # font_path = os.path.join(root_dir, "utils/SimHei.ttf")
9
- # font_name = FontProperties(fname=font_path)
9
+ @COMMON_CHARTS.register(name=__name)
10
+ def plot_bars(data_s, title, x_name, output_dir=None, output_path=None, **kwargs):
11
+ """
12
+ 绘制条形图
10
13
 
14
+ 参数:
15
+ data_s: <dict> 数据。
16
+ 形如 {<data_name>: <data list>, ...} 的字典
17
+ title: <str> 绘图标题。
18
+ x_name: <str> 以哪个 data_name 作为 x 轴。
19
+ 其余数据视为需要被绘制的数据点。
20
+ 例子: data_s={"step":[...], "acc_top1":[...], "acc_top3":[...]}
21
+ 当 x_name="step" 时,将会以 step 为 x 轴绘制 acc_top1 和 acc_top3 的 bar 图。
22
+ x_label: <str> x 轴的标签名称。
23
+ 默认与指定的 x_name 相同。
24
+ y_label: <str> y 轴的标签名称。
25
+ 默认为 "value"。
26
+ output_dir: <str or None> 图片输出目录。
27
+ output_path: <str or None> 图片输出路径。
28
+ 以上两个只需指定一个即可,同时指定时以后者为准。
29
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
30
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
31
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
11
32
 
12
- def plot_bars(data_s, title, x_name, y_label=None, output_dir=None, **kwargs):
13
- data_s = copy.deepcopy(data_s)
33
+ 其他可选参数:
34
+ dpi: <int> 保存图像的分辨率。
35
+ 默认为 200。
36
+ suffix: <str> 图片保存后缀。
37
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
38
+ b_generate_record: <boolean> 是否保存函数参数为档案。
39
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
40
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
41
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
42
+
43
+ 返回值:
44
+ 若 output_dir 非 None,则返回图像保存的文件路径。
45
+ """
46
+ assert x_name in data_s and len(data_s) >= 2
14
47
  paras = {
15
- "dpi": 200
48
+ "x_label": f'{x_name}',
49
+ "y_label": "value",
50
+ "dpi": 200,
51
+ "suffix": ".png",
52
+ "b_generate_record": False
16
53
  }
17
54
  paras.update(kwargs)
55
+ #
56
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
57
+ save_record(_func=plot_bars, _name=__name, _output_path=_output_path if paras["b_generate_record"] else None,
58
+ **paras)
59
+ data_s = data_s.copy()
18
60
 
19
61
  plt.clf()
20
62
  #
@@ -26,29 +68,28 @@ def plot_bars(data_s, title, x_name, y_label=None, output_dir=None, **kwargs):
26
68
  else:
27
69
  plt.bar([j + 0.1 for j in range(len(x_all_ls))], y_ls, width=0.2, align='center', label=k)
28
70
 
29
- plt.xlabel(f'{x_name}')
30
- plt.ylabel(f'{y_label if y_label else "value"}')
71
+ plt.xlabel(paras["x_label"])
72
+ plt.ylabel(paras["y_label"])
31
73
  temp = for_seq.flatten_list([list(i) for i in data_s.values()])
32
74
  y_min, y_max = min(temp), max(temp)
33
75
  plt.ylim(max(min(y_min, 0), y_min - (y_max - y_min) * 0.2), y_max + (y_max - y_min) * 0.1)
34
- plt.xticks(list(range(len(x_all_ls))), labels=x_all_ls) # , fontproperties=font_name
76
+ plt.xticks(list(range(len(x_all_ls))), labels=x_all_ls)
35
77
  plt.title(f'{title}')
36
78
  # 显示图例
37
79
  plt.legend()
38
80
 
39
- if output_dir is None:
40
- plt.show()
41
- return None
42
- else:
43
- os.makedirs(output_dir, exist_ok=True)
44
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
45
- plt.savefig(output_path, dpi=paras["dpi"])
46
- return output_path
81
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
82
+
83
+ return _output_path
47
84
 
48
85
 
49
86
  if __name__ == '__main__':
87
+ import os
88
+
50
89
  plot_bars(data_s={
51
90
  'a': [1.5, 2, 3, 4, 5],
52
91
  'b': [5, 4, 3, 2, 1],
53
92
  'c': [1, 2, 3, 4, 5]},
54
- title='test', x_name='a', output_dir=os.path.join(os.path.dirname(__file__), "temp"))
93
+ title='test_plot_bars', x_name='a', output_dir=os.path.join(os.path.dirname(__file__), "temp"),
94
+ b_generate_record=True
95
+ )
@@ -1,26 +1,74 @@
1
- import os
2
1
  import numpy as np
3
2
  from sklearn.metrics import confusion_matrix
4
3
  import matplotlib.pyplot as plt
5
4
  import seaborn as sns
6
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
5
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
6
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
7
7
 
8
+ __name = ":common_charts:plot_confusion_matrix"
8
9
 
9
- def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None, output_dir=None,
10
+
11
+ @COMMON_CHARTS.register(name=__name)
12
+ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None, output_dir=None, output_path=None,
10
13
  replace_zero_division_with=0, **kwargs):
11
14
  """
12
15
  计算并绘制混淆矩阵
13
16
 
14
17
  参数:
15
- replace_zero_division_with: <float> 对于在normalize时引发除0错误的矩阵元素,使用何种值进行替代
16
- 建议使用 np.nan 或者 0
18
+ data_s: <dict> 数据。
19
+ 形如 {<data_name>: <data list>, ...} 的字典
20
+ title: <str> 绘图标题,同时用于保存图片的文件名。
21
+ gt_name: <str> 在 data_s 中表示真实标签数据的键名。
22
+ pd_name: <str> 在 data_s 中表示预测标签数据的键名。
23
+ label_to_value_s: <dict> 标签-取值映射字典。
24
+ 如 {"cat": 0, "dog": 1})。
25
+ output_dir: <str or None>
26
+ 图像保存的输出目录。如果同时指定了 output_path,则以 output_path 为准。
27
+ 若 output_dir 和 output_path 均未指定,则图像将直接通过 plt.show() 显示而不会保存到文件。
28
+
29
+ output_dir: <str> 图片输出目录。
30
+ output_path: <str> 图片输出路径。
31
+ 以上两个只需指定一个即可,同时指定时以后者为准。
32
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
33
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
34
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
35
+ replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
36
+ 建议使用 np.nan 或 0,默认值为 0。
37
+
38
+ 其他可选参数:
39
+ dpi: <int> 图像保存的分辨率。
40
+ suffix: <str> 图片保存后缀。
41
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
42
+ normalize: <str or None> 指定归一化方式。
43
+ 可选值包括:
44
+ "true"(按真实标签归一化)
45
+ "pred"(按预测标签归一化)
46
+ "all"(整体归一化)
47
+ 默认为 None 表示不归一化。
48
+ b_return_cfm: <bool> 是否在返回值中包含计算得到的混淆矩阵数据。
49
+ 默认为 False。
50
+ b_generate_record: <boolean> 是否保存函数参数为档案。
51
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
52
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
53
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
54
+
55
+ 返回值:
56
+ 当 b_return_cfm 为 True 时,返回值可能为一个包含 (图像路径, 混淆矩阵数据) 的元组。
17
57
  """
18
58
  paras = {
19
59
  "dpi": 200,
60
+ "suffix": ".png",
61
+ "b_generate_record": False,
20
62
  "normalize": None, # "true", "pred", "all",
21
63
  "b_return_cfm": False, # 是否输出混淆矩阵
22
64
  }
23
65
  paras.update(kwargs)
66
+ #
67
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
68
+ save_record(_func=plot_confusion_matrix, _name=__name,
69
+ _output_path=_output_path if paras["b_generate_record"] else None,
70
+ **paras)
71
+ data_s = data_s.copy()
24
72
 
25
73
  value_set = set(data_s[gt_name]).union(set(data_s[pd_name]))
26
74
  if label_to_value_s is None:
@@ -57,28 +105,27 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
57
105
  plt.ylabel(f'{gt_name}')
58
106
  plt.title(f'{title}')
59
107
 
60
- if output_dir is None:
61
- plt.show()
62
- output_path = None
63
- else:
64
- os.makedirs(output_dir, exist_ok=True)
65
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
66
- plt.savefig(output_path, dpi=paras["dpi"])
108
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
67
109
 
68
110
  if paras["b_return_cfm"]:
69
- return output_path, cfm
111
+ return _output_path, cfm
70
112
  else:
71
- return output_path
113
+ return _output_path
72
114
 
73
115
 
74
116
  if __name__ == '__main__':
117
+ import os
118
+
75
119
  # 示例真实标签和预测标签
76
120
  y_true = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 5])
77
121
  y_pred = np.array([0, 2, 1, 0, 2, 1, 0, 1, 1, 5])
78
122
 
79
- plot_confusion_matrix(data_s={'a': y_true, 'b': y_pred},
80
- title='test', gt_name='a', pd_name='b',
81
- label_to_value_s={"A": 5, "B": 0, "C": 1, "D": 2, "E": 3},
82
- # output_dir=os.path.join(os.path.dirname(__file__), "temp"),
83
- replace_zero_division_with=-1,
84
- normalize="all")
123
+ plot_confusion_matrix(
124
+ data_s={'a': y_true, 'b': y_pred},
125
+ title='test_plot_confusion_matrix', gt_name='a', pd_name='b',
126
+ label_to_value_s={"A": 5, "B": 0, "C": 1, "D": 2, "E": 3},
127
+ output_dir=os.path.join(os.path.dirname(__file__), "temp"),
128
+ replace_zero_division_with=-1,
129
+ normalize="all",
130
+ b_generate_record=True
131
+ )
@@ -1,15 +1,66 @@
1
- import os
2
1
  import math
3
2
  import matplotlib.pyplot as plt
4
3
  import numpy as np
5
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
4
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
5
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
6
6
 
7
+ __name = ":common_charts:plot_distribution"
7
8
 
8
- def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist", output_dir=None, **kwargs):
9
+
10
+ @COMMON_CHARTS.register(name=__name)
11
+ def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist", output_dir=None, output_path=None,
12
+ **kwargs):
13
+ """
14
+ 概率分布图
15
+ 支持以下几种绘图类型:
16
+ 1. 数字数据,绘制概率分布图: type_ 参数为 "hist" 或 "histogram" 时。
17
+ 2. 字符串数据,绘制概率直方图:type_ 参数为 "category" 或 "cate" 时。
18
+
19
+ 参数:
20
+ data_s: <dict> 数据。
21
+ 形如 {<data_name>: <data list>, ...} 的字典,
22
+ title: <str> 绘图标题,同时用于保存图片的文件名。
23
+ x_name: <str> 以哪个 data_name 作为待绘制数据。
24
+ x_name_ls: <list or tuple> 以多个 data_name 对应的多组数据在同一图中绘制多个概率分布图。
25
+ type_: <str> 指定绘图类型。
26
+ 支持的取值有:
27
+ - "hist" 或 "histogram": 需要 <data list> 为数值数据,将绘制概率分布图。
28
+ 需要进一步指定 steps 步长参数,
29
+ 或者 min、max、bin_nums 参数。
30
+ - "category" 或 "cate": 需要 <data list> 为字符串数据,将绘制概率直方图。
31
+ output_dir: <str> 图片输出目录。
32
+ output_path: <str> 图片输出路径。
33
+ 以上两个只需指定一个即可,同时指定时以后者为准。
34
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
35
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
36
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
37
+
38
+ 其他可选参数:
39
+ dpi: <int> 图像保存的分辨率。
40
+ suffix: <str> 图片保存后缀。
41
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
42
+ b_generate_record: <boolean> 是否保存函数参数为档案。
43
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
44
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
45
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
46
+
47
+ 返回值:
48
+ <str> 图像保存的完整文件路径。如果 output_dir 或 output_path 被指定,
49
+ 则图像会保存到对应位置并返回保存路径;否则可能直接显示图像,
50
+ 返回值依赖于 save_plot 函数的具体实现。
51
+ """
9
52
  paras = {
10
- "dpi": 200
53
+ "dpi": 200,
54
+ "suffix": ".png",
55
+ "b_generate_record": False
11
56
  }
12
57
  paras.update(kwargs)
58
+ #
59
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
60
+ save_record(_func=plot_distribution, _name=__name,
61
+ _output_path=_output_path if paras["b_generate_record"] else None,
62
+ **paras)
63
+ data_s = data_s.copy()
13
64
  if x_name is not None:
14
65
  x_name_ls = [x_name, ]
15
66
  assert isinstance(x_name_ls, (list, tuple)) and len(x_name_ls) > 0
@@ -47,19 +98,17 @@ def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist",
47
98
  # 显示图例
48
99
  plt.legend()
49
100
 
50
- if output_dir is None:
51
- plt.show()
52
- return None
53
- else:
54
- os.makedirs(output_dir, exist_ok=True)
55
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
56
- plt.savefig(output_path, dpi=paras["dpi"])
57
- return output_path
101
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
102
+ return _output_path
58
103
 
59
104
 
60
105
  if __name__ == '__main__':
61
- plot_distribution(data_s={
62
- 'a': [1, 2, 3, 4, 5, 3, 2, 1],
63
- 'c': [1, 2, 3, 4, 5, 0, 0, 0]},
64
- title='test', x_name_ls=['a', 'c'], type_="category",
65
- output_dir=os.path.join(os.path.dirname(__file__), "temp"))
106
+ import os
107
+
108
+ plot_distribution(
109
+ data_s={
110
+ 'a': [1, 2, 3, 4, 5, 3, 2, 1],
111
+ 'c': [1, 2, 3, 4, 5, 0, 0, 0]},
112
+ title='test_plot_distribution', x_name_ls=['a', 'c'], type_="category",
113
+ output_dir=os.path.join(os.path.dirname(__file__), "temp")
114
+ )
@@ -0,0 +1,21 @@
1
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
2
+
3
+
4
+ def plot_from_record(input_path, **kwargs):
5
+ """
6
+ 从 record 中恢复并绘制图像
7
+ 支持通过 **kwargs 对其中部分参数进行覆盖
8
+ """
9
+ from kevin_toolbox.computer_science.algorithm.registration import Serializer_for_Registry_Execution
10
+
11
+ serializer = Serializer_for_Registry_Execution()
12
+ serializer.load(input_path)
13
+ serializer.record_s["kwargs"].update(kwargs)
14
+ return serializer.recover()()
15
+
16
+
17
+ if __name__ == '__main__':
18
+ import os
19
+
20
+ plot_from_record(input_path=os.path.join(os.path.dirname(__file__), "temp/好-吧.png.record.tar"),
21
+ output_dir=os.path.join(os.path.dirname(__file__), "temp/recover"))