kevin-toolbox-dev 1.4.6__py3-none-any.whl → 1.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/{developing → computer_science/algorithm}/decorator/__init__.py +2 -1
  3. kevin_toolbox/computer_science/algorithm/decorator/retry.py +62 -0
  4. kevin_toolbox/computer_science/algorithm/parallel_and_concurrent/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/parallel_and_concurrent/multi_process_execute.py +109 -0
  6. kevin_toolbox/computer_science/algorithm/parallel_and_concurrent/multi_thread_execute.py +50 -29
  7. kevin_toolbox/computer_science/algorithm/parallel_and_concurrent/utils/__init__.py +15 -0
  8. kevin_toolbox/computer_science/algorithm/parallel_and_concurrent/utils/wrapper_with_timeout_1.py +69 -0
  9. kevin_toolbox/computer_science/algorithm/parallel_and_concurrent/utils/wrapper_with_timeout_2.py +76 -0
  10. kevin_toolbox/computer_science/algorithm/registration/__init__.py +1 -0
  11. kevin_toolbox/computer_science/algorithm/registration/serializer_for_registry_execution.py +82 -0
  12. kevin_toolbox/computer_science/data_structure/executor.py +2 -2
  13. kevin_toolbox/data_flow/core/cache/cache_manager_for_iterator.py +1 -1
  14. kevin_toolbox/data_flow/file/json_/write_json.py +36 -3
  15. kevin_toolbox/env_info/variable_/env_vars_parser.py +17 -2
  16. kevin_toolbox/nested_dict_list/serializer/backends/_json_.py +2 -2
  17. kevin_toolbox/nested_dict_list/serializer/variable.py +14 -2
  18. kevin_toolbox/nested_dict_list/serializer/write.py +2 -0
  19. kevin_toolbox/network/__init__.py +10 -0
  20. kevin_toolbox/network/download_file.py +120 -0
  21. kevin_toolbox/network/fetch_content.py +55 -0
  22. kevin_toolbox/network/fetch_metadata.py +64 -0
  23. kevin_toolbox/network/get_response.py +50 -0
  24. kevin_toolbox/network/variable.py +6 -0
  25. kevin_toolbox/patches/for_logging/build_logger.py +1 -1
  26. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +45 -0
  27. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +63 -22
  28. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +67 -20
  29. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +66 -17
  30. kevin_toolbox/patches/for_matplotlib/common_charts/plot_from_record.py +21 -0
  31. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +63 -19
  32. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +61 -12
  33. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +57 -14
  34. kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +3 -0
  35. kevin_toolbox/patches/for_matplotlib/common_charts/utils/get_output_path.py +15 -0
  36. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +11 -0
  37. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_record.py +34 -0
  38. kevin_toolbox/patches/for_matplotlib/variable.py +20 -0
  39. kevin_toolbox_dev-1.4.8.dist-info/METADATA +86 -0
  40. {kevin_toolbox_dev-1.4.6.dist-info → kevin_toolbox_dev-1.4.8.dist-info}/RECORD +43 -25
  41. kevin_toolbox_dev-1.4.6.dist-info/METADATA +0 -76
  42. /kevin_toolbox/{developing → computer_science/algorithm}/decorator/restore_original_work_path.py +0 -0
  43. {kevin_toolbox_dev-1.4.6.dist-info → kevin_toolbox_dev-1.4.8.dist-info}/WHEEL +0 -0
  44. {kevin_toolbox_dev-1.4.6.dist-info → kevin_toolbox_dev-1.4.8.dist-info}/top_level.txt +0 -0
@@ -1,26 +1,74 @@
1
- import os
2
1
  import numpy as np
3
2
  from sklearn.metrics import confusion_matrix
4
3
  import matplotlib.pyplot as plt
5
4
  import seaborn as sns
6
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
5
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
6
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
7
7
 
8
+ __name = ":common_charts:plot_confusion_matrix"
8
9
 
9
- def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None, output_dir=None,
10
+
11
+ @COMMON_CHARTS.register(name=__name)
12
+ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None, output_dir=None, output_path=None,
10
13
  replace_zero_division_with=0, **kwargs):
11
14
  """
12
15
  计算并绘制混淆矩阵
13
16
 
14
17
  参数:
15
- replace_zero_division_with: <float> 对于在normalize时引发除0错误的矩阵元素,使用何种值进行替代
16
- 建议使用 np.nan 或者 0
18
+ data_s: <dict> 数据。
19
+ 形如 {<data_name>: <data list>, ...} 的字典
20
+ title: <str> 绘图标题,同时用于保存图片的文件名。
21
+ gt_name: <str> 在 data_s 中表示真实标签数据的键名。
22
+ pd_name: <str> 在 data_s 中表示预测标签数据的键名。
23
+ label_to_value_s: <dict> 标签-取值映射字典。
24
+ 如 {"cat": 0, "dog": 1})。
25
+ output_dir: <str or None>
26
+ 图像保存的输出目录。如果同时指定了 output_path,则以 output_path 为准。
27
+ 若 output_dir 和 output_path 均未指定,则图像将直接通过 plt.show() 显示而不会保存到文件。
28
+
29
+ output_dir: <str> 图片输出目录。
30
+ output_path: <str> 图片输出路径。
31
+ 以上两个只需指定一个即可,同时指定时以后者为准。
32
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
33
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
34
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
35
+ replace_zero_division_with: <float> 在归一化混淆矩阵时,如果遇到除0错误的情况,将使用该值进行替代。
36
+ 建议使用 np.nan 或 0,默认值为 0。
37
+
38
+ 其他可选参数:
39
+ dpi: <int> 图像保存的分辨率。
40
+ suffix: <str> 图片保存后缀。
41
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
42
+ normalize: <str or None> 指定归一化方式。
43
+ 可选值包括:
44
+ "true"(按真实标签归一化)
45
+ "pred"(按预测标签归一化)
46
+ "all"(整体归一化)
47
+ 默认为 None 表示不归一化。
48
+ b_return_cfm: <bool> 是否在返回值中包含计算得到的混淆矩阵数据。
49
+ 默认为 False。
50
+ b_generate_record: <boolean> 是否保存函数参数为档案。
51
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
52
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
53
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
54
+
55
+ 返回值:
56
+ 当 b_return_cfm 为 True 时,返回值可能为一个包含 (图像路径, 混淆矩阵数据) 的元组。
17
57
  """
18
58
  paras = {
19
59
  "dpi": 200,
60
+ "suffix": ".png",
61
+ "b_generate_record": False,
20
62
  "normalize": None, # "true", "pred", "all",
21
63
  "b_return_cfm": False, # 是否输出混淆矩阵
22
64
  }
23
65
  paras.update(kwargs)
66
+ #
67
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
68
+ save_record(_func=plot_confusion_matrix, _name=__name,
69
+ _output_path=_output_path if paras["b_generate_record"] else None,
70
+ **paras)
71
+ data_s = data_s.copy()
24
72
 
25
73
  value_set = set(data_s[gt_name]).union(set(data_s[pd_name]))
26
74
  if label_to_value_s is None:
@@ -57,28 +105,27 @@ def plot_confusion_matrix(data_s, title, gt_name, pd_name, label_to_value_s=None
57
105
  plt.ylabel(f'{gt_name}')
58
106
  plt.title(f'{title}')
59
107
 
60
- if output_dir is None:
61
- plt.show()
62
- output_path = None
63
- else:
64
- os.makedirs(output_dir, exist_ok=True)
65
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
66
- plt.savefig(output_path, dpi=paras["dpi"])
108
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
67
109
 
68
110
  if paras["b_return_cfm"]:
69
- return output_path, cfm
111
+ return _output_path, cfm
70
112
  else:
71
- return output_path
113
+ return _output_path
72
114
 
73
115
 
74
116
  if __name__ == '__main__':
117
+ import os
118
+
75
119
  # 示例真实标签和预测标签
76
120
  y_true = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 5])
77
121
  y_pred = np.array([0, 2, 1, 0, 2, 1, 0, 1, 1, 5])
78
122
 
79
- plot_confusion_matrix(data_s={'a': y_true, 'b': y_pred},
80
- title='test', gt_name='a', pd_name='b',
81
- label_to_value_s={"A": 5, "B": 0, "C": 1, "D": 2, "E": 3},
82
- # output_dir=os.path.join(os.path.dirname(__file__), "temp"),
83
- replace_zero_division_with=-1,
84
- normalize="all")
123
+ plot_confusion_matrix(
124
+ data_s={'a': y_true, 'b': y_pred},
125
+ title='test_plot_confusion_matrix', gt_name='a', pd_name='b',
126
+ label_to_value_s={"A": 5, "B": 0, "C": 1, "D": 2, "E": 3},
127
+ output_dir=os.path.join(os.path.dirname(__file__), "temp"),
128
+ replace_zero_division_with=-1,
129
+ normalize="all",
130
+ b_generate_record=True
131
+ )
@@ -1,15 +1,66 @@
1
- import os
2
1
  import math
3
2
  import matplotlib.pyplot as plt
4
3
  import numpy as np
5
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
4
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
5
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
6
6
 
7
+ __name = ":common_charts:plot_distribution"
7
8
 
8
- def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist", output_dir=None, **kwargs):
9
+
10
+ @COMMON_CHARTS.register(name=__name)
11
+ def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist", output_dir=None, output_path=None,
12
+ **kwargs):
13
+ """
14
+ 概率分布图
15
+ 支持以下几种绘图类型:
16
+ 1. 数字数据,绘制概率分布图: type_ 参数为 "hist" 或 "histogram" 时。
17
+ 2. 字符串数据,绘制概率直方图:type_ 参数为 "category" 或 "cate" 时。
18
+
19
+ 参数:
20
+ data_s: <dict> 数据。
21
+ 形如 {<data_name>: <data list>, ...} 的字典,
22
+ title: <str> 绘图标题,同时用于保存图片的文件名。
23
+ x_name: <str> 以哪个 data_name 作为待绘制数据。
24
+ x_name_ls: <list or tuple> 以多个 data_name 对应的多组数据在同一图中绘制多个概率分布图。
25
+ type_: <str> 指定绘图类型。
26
+ 支持的取值有:
27
+ - "hist" 或 "histogram": 需要 <data list> 为数值数据,将绘制概率分布图。
28
+ 需要进一步指定 steps 步长参数,
29
+ 或者 min、max、bin_nums 参数。
30
+ - "category" 或 "cate": 需要 <data list> 为字符串数据,将绘制概率直方图。
31
+ output_dir: <str> 图片输出目录。
32
+ output_path: <str> 图片输出路径。
33
+ 以上两个只需指定一个即可,同时指定时以后者为准。
34
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
35
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
36
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
37
+
38
+ 其他可选参数:
39
+ dpi: <int> 图像保存的分辨率。
40
+ suffix: <str> 图片保存后缀。
41
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
42
+ b_generate_record: <boolean> 是否保存函数参数为档案。
43
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
44
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
45
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
46
+
47
+ 返回值:
48
+ <str> 图像保存的完整文件路径。如果 output_dir 或 output_path 被指定,
49
+ 则图像会保存到对应位置并返回保存路径;否则可能直接显示图像,
50
+ 返回值依赖于 save_plot 函数的具体实现。
51
+ """
9
52
  paras = {
10
- "dpi": 200
53
+ "dpi": 200,
54
+ "suffix": ".png",
55
+ "b_generate_record": False
11
56
  }
12
57
  paras.update(kwargs)
58
+ #
59
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
60
+ save_record(_func=plot_distribution, _name=__name,
61
+ _output_path=_output_path if paras["b_generate_record"] else None,
62
+ **paras)
63
+ data_s = data_s.copy()
13
64
  if x_name is not None:
14
65
  x_name_ls = [x_name, ]
15
66
  assert isinstance(x_name_ls, (list, tuple)) and len(x_name_ls) > 0
@@ -47,19 +98,17 @@ def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist",
47
98
  # 显示图例
48
99
  plt.legend()
49
100
 
50
- if output_dir is None:
51
- plt.show()
52
- return None
53
- else:
54
- os.makedirs(output_dir, exist_ok=True)
55
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
56
- plt.savefig(output_path, dpi=paras["dpi"])
57
- return output_path
101
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
102
+ return _output_path
58
103
 
59
104
 
60
105
  if __name__ == '__main__':
61
- plot_distribution(data_s={
62
- 'a': [1, 2, 3, 4, 5, 3, 2, 1],
63
- 'c': [1, 2, 3, 4, 5, 0, 0, 0]},
64
- title='test', x_name_ls=['a', 'c'], type_="category",
65
- output_dir=os.path.join(os.path.dirname(__file__), "temp"))
106
+ import os
107
+
108
+ plot_distribution(
109
+ data_s={
110
+ 'a': [1, 2, 3, 4, 5, 3, 2, 1],
111
+ 'c': [1, 2, 3, 4, 5, 0, 0, 0]},
112
+ title='test_plot_distribution', x_name_ls=['a', 'c'], type_="category",
113
+ output_dir=os.path.join(os.path.dirname(__file__), "temp")
114
+ )
@@ -0,0 +1,21 @@
1
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
2
+
3
+
4
+ def plot_from_record(input_path, **kwargs):
5
+ """
6
+ 从 record 中恢复并绘制图像
7
+ 支持通过 **kwargs 对其中部分参数进行覆盖
8
+ """
9
+ from kevin_toolbox.computer_science.algorithm.registration import Serializer_for_Registry_Execution
10
+
11
+ serializer = Serializer_for_Registry_Execution()
12
+ serializer.load(input_path)
13
+ serializer.record_s["kwargs"].update(kwargs)
14
+ return serializer.recover()()
15
+
16
+
17
+ if __name__ == '__main__':
18
+ import os
19
+
20
+ plot_from_record(input_path=os.path.join(os.path.dirname(__file__), "temp/好-吧.png.record.tar"),
21
+ output_dir=os.path.join(os.path.dirname(__file__), "temp/recover"))
@@ -1,15 +1,51 @@
1
- import os
2
- import copy
3
1
  import matplotlib.pyplot as plt
4
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
5
2
  from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
3
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
4
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
6
5
 
6
+ __name = ":common_charts:plot_lines"
7
7
 
8
- def plot_lines(data_s, title, x_name, output_dir=None, **kwargs):
9
- data_s = copy.copy(data_s)
8
+
9
+ @COMMON_CHARTS.register(name=__name)
10
+ def plot_lines(data_s, title, x_name, x_ticklabels_name=None, output_dir=None, output_path=None, **kwargs):
11
+ """
12
+ 绘制折线图
13
+
14
+ 参数:
15
+ data_s: <dict> 数据。
16
+ 形如 {<data_name>: <data list>, ...} 的字典
17
+ title: <str> 绘图标题,同时用于保存图片的文件名。
18
+ x_name: <str> 以哪个 data_name 作为 x 轴。
19
+ 其余数据视为需要被绘制的数据点。
20
+ 例子: data_s={"step":[...], "acc_top1":[...], "acc_top3":[...]}
21
+ 当 x_name="step" 时,将会以 step 为 x 轴绘制 acc_top1 和 acc_top3 的 bar 图。
22
+ x_ticklabels_name: <str or None> 若提供则表示 x 轴刻度标签对应的键名,用于替换默认的 x 轴刻度值。
23
+ output_dir: <str> 图片输出目录。
24
+ output_path: <str> 图片输出路径。
25
+ 以上两个只需指定一个即可,同时指定时以后者为准。
26
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
27
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
28
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
29
+
30
+ 其他可选参数:
31
+ dpi: <int> 保存图像的分辨率。
32
+ 默认为 200。
33
+ suffix: <str> 图片保存后缀。
34
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
35
+ b_generate_record: <boolean> 是否保存函数参数为档案。
36
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
37
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
38
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
39
+ color_ls: <list> 用于绘图的颜色列表,默认根据数据序列个数自动生成。
40
+ marker_ls: <list of str> 折线图上各数据点的标记。
41
+ linestyle_ls: <list of str> 线型。
42
+ 默认值为 '-',表示直线。
43
+ """
10
44
  line_nums = len(data_s) - 1
11
45
  paras = {
12
46
  "dpi": 200,
47
+ "suffix": ".png",
48
+ "b_generate_record": False,
13
49
  "color_ls": generate_color_list(nums=line_nums),
14
50
  "marker_ls": None,
15
51
  "linestyle_ls": '-',
@@ -19,10 +55,20 @@ def plot_lines(data_s, title, x_name, output_dir=None, **kwargs):
19
55
  if k.endswith("_ls") and not isinstance(v, (list, tuple)):
20
56
  paras[k] = [v] * line_nums
21
57
  assert line_nums == len(paras["color_ls"]) == len(paras["marker_ls"]) == len(paras["linestyle_ls"])
58
+ #
59
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
60
+ save_record(_func=plot_lines, _name=__name,
61
+ _output_path=_output_path if paras["b_generate_record"] else None,
62
+ **paras)
63
+ data_s = data_s.copy()
22
64
 
23
65
  plt.clf()
24
66
  #
25
67
  x_all_ls = data_s.pop(x_name)
68
+ if x_ticklabels_name is not None:
69
+ x_ticklabels = data_s.pop(x_ticklabels_name)
70
+ assert len(x_all_ls) == len(x_ticklabels)
71
+ plt.xticks(x_all_ls, x_ticklabels)
26
72
  data_s, temp = dict(), data_s
27
73
  for k, v_ls in temp.items():
28
74
  y_ls, x_ls = [], []
@@ -44,20 +90,18 @@ def plot_lines(data_s, title, x_name, output_dir=None, **kwargs):
44
90
  # 显示图例
45
91
  plt.legend()
46
92
 
47
- if output_dir is None:
48
- plt.show()
49
- return None
50
- else:
51
- # 对非法字符进行替换
52
- os.makedirs(output_dir, exist_ok=True)
53
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
54
- plt.savefig(output_path, dpi=paras["dpi"])
55
- return output_path
93
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
94
+ return _output_path
56
95
 
57
96
 
58
97
  if __name__ == '__main__':
59
- plot_lines(data_s={
60
- 'a': [1, 2, 3, 4, 5],
61
- 'b': [5, 4, 3, 2, 1],
62
- 'c': [1, 2, 3, 4, 5]},
63
- title='test', x_name='a', output_dir=os.path.join(os.path.dirname(__file__), "temp"))
98
+ import os
99
+
100
+ plot_lines(
101
+ data_s={
102
+ 'a': [1, 2, 3, 4, 5],
103
+ 'b': [5, 4, 3, 2, 1],
104
+ 'c': [1, 2, 3, 4, 5]},
105
+ title='test_plot_lines',
106
+ x_name='a', output_dir=os.path.join(os.path.dirname(__file__), "temp")
107
+ )
@@ -1,15 +1,65 @@
1
- import os
2
1
  import matplotlib.pyplot as plt
3
2
  from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
4
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
3
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
4
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
5
5
 
6
+ __name = ":common_charts:plot_scatters"
6
7
 
7
- def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None, **kwargs):
8
+
9
+ @COMMON_CHARTS.register(name=__name)
10
+ def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None, output_path=None, **kwargs):
11
+ """
12
+ 绘制散点图
13
+ 不同类别的数据点
14
+
15
+ 参数:
16
+ data_s: <dict>
17
+ 数据字典,其中每个键对应一个数据列表。必须包含 x_name 和 y_name 对应的键,
18
+ 如果指定了 cate_name,还需要包含 cate_name 对应的键,用于按类别对散点图上数据点着色。
19
+ title: <str>
20
+ 图形标题,同时用于生成保存图像时的文件名(标题中的非法字符会被替换)。
21
+ x_name: <str> 以哪个 data_name 作为数据点的 x 轴。
22
+ y_name: <str> 以哪个 data_name 作为数据点的 y 轴。
23
+ cate_name: <str> 以哪个 data_name 作为数据点的类别。
24
+ output_dir: <str or None> 图片输出目录。
25
+ output_path: <str or None> 图片输出路径。
26
+ 以上两个只需指定一个即可,同时指定时以后者为准。
27
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
28
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
29
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
30
+
31
+ 其他可选参数:
32
+ dpi: <int> 保存图像的分辨率。
33
+ 默认为 200。
34
+ suffix: <str> 图片保存后缀。
35
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
36
+ b_generate_record: <boolean> 是否保存函数参数为档案。
37
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
38
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
39
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
40
+ scatter_size: <int> 散点的大小。
41
+ 默认 5。
42
+
43
+ 示例:
44
+ >>> data = {
45
+ ... "age": [25, 30, 22, 40],
46
+ ... "income": [50000, 60000, 45000, 80000],
47
+ ... "gender": ["M", "F", "M", "F"]
48
+ ... }
49
+ >>> path = plot_scatters(data, "Age vs Income", "age", "income", cate_name="gender", output_dir="./plots")
50
+ """
8
51
  paras = {
9
52
  "dpi": 200,
53
+ "suffix": ".png",
54
+ "b_generate_record": False,
10
55
  "scatter_size": 5
11
56
  }
12
57
  paras.update(kwargs)
58
+ #
59
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
60
+ save_record(_func=plot_scatters, _name=__name,
61
+ _output_path=_output_path if paras["b_generate_record"] else None,
62
+ **paras)
13
63
 
14
64
  plt.clf()
15
65
  #
@@ -33,21 +83,20 @@ def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None
33
83
  markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
34
84
  ])
35
85
 
36
- if output_dir is None:
37
- plt.show()
38
- return None
39
- else:
40
- os.makedirs(output_dir, exist_ok=True)
41
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
42
- plt.savefig(output_path, dpi=paras["dpi"])
43
- return output_path
86
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
87
+ return _output_path
44
88
 
45
89
 
46
90
  if __name__ == '__main__':
91
+ import os
92
+
47
93
  data_s_ = dict(
48
94
  x=[1, 2, 3, 4, 5],
49
95
  y=[2, 4, 6, 8, 10],
50
96
  categories=['A', 'B', 'A', 'B', 'A']
51
97
  )
52
98
 
53
- plot_scatters(data_s=data_s_, title='test', x_name='x', y_name='y', cate_name='categories')
99
+ plot_scatters(
100
+ data_s=data_s_, title='test_plot_scatters', x_name='x', y_name='y', cate_name='categories',
101
+ output_dir=os.path.join(os.path.dirname(__file__), "temp")
102
+ )
@@ -1,14 +1,54 @@
1
- import os
2
1
  import pandas as pd
3
2
  import seaborn as sns
4
3
  import matplotlib.pyplot as plt
5
4
  from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
6
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
5
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
6
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
7
7
 
8
+ __name = ":common_charts:plot_scatters_matrix"
8
9
 
9
- def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=None, cate_color_s=None, **kwargs):
10
+
11
+ @COMMON_CHARTS.register(name=__name)
12
+ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=None, output_path=None, cate_color_s=None,
13
+ **kwargs):
14
+ """
15
+ 绘制散点图矩阵
16
+ 该函数用于展示多个变量之间的两两关系。
17
+
18
+ 参数:
19
+ data_s: <dict> 数据。
20
+ 形如 {<data_name>: <data list>, ...} 的字典
21
+ title: <str> 绘图标题。
22
+ x_name_ls: <list of str> 指定用于需要绘制哪些变量的两两关系。
23
+ cate_name: <str or None> 使用哪个 data_name 对应的取值作为类别。
24
+ output_dir: <str or None> 图片输出目录。
25
+ output_path: <str or None> 图片输出路径。
26
+ 以上两个只需指定一个即可,同时指定时以后者为准。
27
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
28
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
29
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
30
+ cate_color_s: <dict or None> 类别-颜色映射字典,将 cate_name 中的每个类别映射到具体颜色。
31
+ 默认为 None,自动生成颜色列表。
32
+
33
+ 其他可选参数:
34
+ dpi: <int> 保存图像的分辨率。
35
+ 默认为 200。
36
+ suffix: <str> 图片保存后缀。
37
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
38
+ b_generate_record: <boolean> 是否保存函数参数为档案。
39
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
40
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
41
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
42
+ diag_kind: <str> 对角线图表的类型。
43
+ 支持:
44
+ - "hist"(直方图)
45
+ - "kde"(核密度图)
46
+ 默认为 "kde"。
47
+ """
10
48
  paras = {
11
49
  "dpi": 200,
50
+ "suffix": ".png",
51
+ "b_generate_record": False,
12
52
  "diag_kind": "kde" # 设置对角线图直方图/密度图 {‘hist’, ‘kde’}
13
53
  }
14
54
  assert cate_name in data_s and len(set(x_name_ls).difference(set(data_s.keys()))) == 0
@@ -17,6 +57,11 @@ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=No
17
57
  cate_color_s = {k: v for k, v in zip(temp, generate_color_list(len(temp)))}
18
58
  assert set(cate_color_s.keys()) == set(data_s[cate_name])
19
59
  paras.update(kwargs)
60
+ #
61
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
62
+ save_record(_func=plot_scatters_matrix, _name=__name,
63
+ _output_path=_output_path if paras["b_generate_record"] else None,
64
+ **paras)
20
65
 
21
66
  plt.clf()
22
67
  # 使用seaborn绘制散点图矩阵
@@ -29,19 +74,14 @@ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=No
29
74
  #
30
75
  plt.subplots_adjust(top=0.95)
31
76
  plt.suptitle(f'{title}', y=0.98, x=0.47)
32
- # g.fig.suptitle(f'{title}', y=1.05)
33
77
 
34
- if output_dir is None:
35
- plt.show()
36
- return None
37
- else:
38
- os.makedirs(output_dir, exist_ok=True)
39
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
40
- plt.savefig(output_path, dpi=paras["dpi"])
41
- return output_path
78
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
79
+ return _output_path
42
80
 
43
81
 
44
82
  if __name__ == '__main__':
83
+ import os
84
+
45
85
  data_s_ = dict(
46
86
  x=[1, 2, 3, 4, 5],
47
87
  y=[2, 4, 6, 8, 10],
@@ -50,5 +90,8 @@ if __name__ == '__main__':
50
90
  title='test',
51
91
  )
52
92
 
53
- plot_scatters_matrix(data_s=data_s_, title='test', x_name_ls=['y', 'x', 'z'], cate_name='categories',
54
- cate_color_s={'A': 'red', 'B': 'blue'})
93
+ plot_scatters_matrix(
94
+ data_s=data_s_, title='test_plot_scatters_matrix', x_name_ls=['y', 'x', 'z'], cate_name='categories',
95
+ cate_color_s={'A': 'red', 'B': 'blue'},
96
+ output_dir=os.path.join(os.path.dirname(__file__), "temp")
97
+ )
@@ -0,0 +1,3 @@
1
+ from .get_output_path import get_output_path
2
+ from .save_plot import save_plot
3
+ from .save_record import save_record
@@ -0,0 +1,15 @@
1
+ import os
2
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
3
+
4
+
5
+ def get_output_path(output_path=None, output_dir=None, title=None, suffix=".png", **kwargs):
6
+ if output_path is None:
7
+ if output_dir is None:
8
+ output_path = None
9
+ else:
10
+ assert title is not None
11
+ assert suffix in [".png", ".jpg", ".bmp"]
12
+ output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}{suffix}')
13
+ else:
14
+ output_path = os.path.join(os.path.dirname(output_path), replace_illegal_chars(os.path.basename(output_path)))
15
+ return output_path
@@ -0,0 +1,11 @@
1
+ import os
2
+
3
+
4
+ def save_plot(plt, output_path, dpi=200, suffix=".png", **kwargs):
5
+ assert suffix in [".png", ".jpg", ".bmp"]
6
+
7
+ if output_path is None:
8
+ plt.show()
9
+ else:
10
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
11
+ plt.savefig(output_path, dpi=dpi)
@@ -0,0 +1,34 @@
1
+ import inspect
2
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
3
+
4
+
5
+ def save_record(_name, _output_path,_func=None, **kwargs):
6
+ if _output_path is None:
7
+ return None
8
+
9
+ # 获取函数的参数列表
10
+ frame = inspect.currentframe().f_back
11
+ if _func is None:
12
+ func_name = inspect.getframeinfo(frame).function
13
+ _func = globals()[func_name]
14
+ sig = inspect.signature(_func)
15
+ # 获取参数值
16
+ args_info = inspect.getargvalues(frame)
17
+ arg_names = args_info.args
18
+ arg_values = args_info.locals
19
+ # 将参数值映射到函数签名中
20
+ kwargs_raw = {name: arg_values[name] for name in arg_names}
21
+ for param_name, param in sig.parameters.items():
22
+ if param.kind == param.VAR_KEYWORD:
23
+ kwargs_raw.update(arg_values[param_name])
24
+ elif param.kind == param.VAR_POSITIONAL:
25
+ kwargs_raw[param_name] = arg_values[param_name]
26
+
27
+ kwargs_raw.update(kwargs)
28
+
29
+ from kevin_toolbox.computer_science.algorithm.registration import Serializer_for_Registry_Execution
30
+ serializer = Serializer_for_Registry_Execution()
31
+ file_path = serializer.record_name(
32
+ _name=_name, _registry=COMMON_CHARTS
33
+ ).record_paras(**kwargs_raw).save(_output_path + ".record", b_pack_into_tar=True, b_allow_overwrite=True)
34
+ return file_path
@@ -0,0 +1,20 @@
1
+ import os
2
+ from kevin_toolbox.computer_science.algorithm.registration import Registry
3
+
4
+ COMMON_CHARTS = Registry(uid="COMMON_CHARTS")
5
+
6
+ # 导入时的默认过滤规则
7
+ ignore_s = [
8
+ {
9
+ "func": lambda _, __, path: os.path.basename(path) in ["temp", "test", "__pycache__",
10
+ "_old_version"],
11
+ "scope": ["root", "dirs"]
12
+ },
13
+ ]
14
+
15
+ # 从 kevin_toolbox/patches/for_matplotlib/common_charts 下收集被注册的方法
16
+ COMMON_CHARTS.collect_from_paths(
17
+ path_ls=[os.path.join(os.path.dirname(__file__), "common_charts"), ],
18
+ ignore_s=ignore_s,
19
+ b_execute_now=False
20
+ )