kevin-toolbox-dev 1.4.11__py3-none-any.whl → 1.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/computer_science/algorithm/pareto_front/get_pareto_points_idx.py +2 -0
  3. kevin_toolbox/computer_science/algorithm/redirector/redirectable_sequence_fetcher.py +3 -3
  4. kevin_toolbox/computer_science/algorithm/sampler/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/sampler/recent_sampler.py +128 -0
  6. kevin_toolbox/computer_science/algorithm/sampler/reservoir_sampler.py +2 -2
  7. kevin_toolbox/computer_science/algorithm/statistician/__init__.py +2 -0
  8. kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +1 -1
  9. kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +1 -1
  10. kevin_toolbox/computer_science/algorithm/statistician/maximum_accumulator.py +80 -0
  11. kevin_toolbox/computer_science/algorithm/statistician/minimum_accumulator.py +34 -0
  12. kevin_toolbox/data_flow/file/markdown/table/find_tables.py +38 -12
  13. kevin_toolbox/developing/file_management/__init__.py +1 -0
  14. kevin_toolbox/developing/file_management/file_feature_extractor.py +263 -0
  15. kevin_toolbox/nested_dict_list/serializer/read.py +4 -1
  16. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +5 -0
  17. kevin_toolbox/patches/for_matplotlib/common_charts/plot_2d_matrix.py +134 -0
  18. kevin_toolbox/patches/for_matplotlib/common_charts/plot_3d.py +198 -0
  19. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +7 -4
  20. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +11 -4
  21. kevin_toolbox/patches/for_matplotlib/common_charts/plot_contour.py +157 -0
  22. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +19 -8
  23. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +72 -21
  24. kevin_toolbox/patches/for_matplotlib/common_charts/plot_mean_std_lines.py +135 -0
  25. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +9 -3
  26. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +9 -3
  27. kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +1 -0
  28. kevin_toolbox/patches/for_matplotlib/common_charts/utils/log_scaling.py +69 -0
  29. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +19 -3
  30. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_record.py +1 -1
  31. kevin_toolbox/patches/for_numpy/__init__.py +1 -0
  32. kevin_toolbox/patches/for_numpy/linalg/softmax.py +4 -1
  33. kevin_toolbox_dev-1.4.13.dist-info/METADATA +77 -0
  34. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/RECORD +36 -26
  35. kevin_toolbox_dev-1.4.11.dist-info/METADATA +0 -67
  36. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/WHEEL +0 -0
  37. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/top_level.txt +0 -0
@@ -32,7 +32,7 @@ def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist",
32
32
  output_path: <str> 图片输出路径。
33
33
  以上两个只需指定一个即可,同时指定时以后者为准。
34
34
  当只有 output_dir 被指定时,将会以 title 作为图片名。
35
- 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
35
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
36
36
  在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
37
37
 
38
38
  其他可选参数:
@@ -43,6 +43,10 @@ def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist",
43
43
  默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
44
44
  后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
45
45
  该参数仅在 output_dir 和 output_path 非 None 时起效。
46
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
47
+ 默认为 False
48
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
49
+ 默认为 True
46
50
 
47
51
  返回值:
48
52
  <str> 图像保存的完整文件路径。如果 output_dir 或 output_path 被指定,
@@ -52,7 +56,9 @@ def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist",
52
56
  paras = {
53
57
  "dpi": 200,
54
58
  "suffix": ".png",
55
- "b_generate_record": False
59
+ "b_generate_record": False,
60
+ "b_show_plot": False,
61
+ "b_bgr_image": True
56
62
  }
57
63
  paras.update(kwargs)
58
64
  #
@@ -98,17 +104,22 @@ def plot_distribution(data_s, title, x_name=None, x_name_ls=None, type_="hist",
98
104
  # 显示图例
99
105
  plt.legend()
100
106
 
101
- save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
102
- return _output_path
107
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
108
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
103
109
 
104
110
 
105
111
  if __name__ == '__main__':
106
112
  import os
113
+ import cv2
107
114
 
108
- plot_distribution(
115
+ image_ = plot_distribution(
109
116
  data_s={
110
- 'a': [1, 2, 3, 4, 5, 3, 2, 1],
117
+ 'a': [0.1, 2, 3, 4, 5, 3, 2, 1],
111
118
  'c': [1, 2, 3, 4, 5, 0, 0, 0]},
112
- title='test_plot_distribution', x_name_ls=['a', 'c'], type_="category",
113
- output_dir=os.path.join(os.path.dirname(__file__), "temp")
119
+ title='test_plot_distribution', x_name_ls=['a', 'c'],
120
+ # type_="category",
121
+ # output_dir=os.path.join(os.path.dirname(__file__), "temp"),
122
+ # b_show_plot=True
123
+ type_="hist", steps=1,
114
124
  )
125
+ cv2.imwrite(os.path.join(os.path.dirname(__file__), "temp", "233.png"), image_)
@@ -1,13 +1,38 @@
1
+ import warnings
1
2
  import matplotlib.pyplot as plt
2
3
  from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
3
- from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
4
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path, \
5
+ log_scaling
4
6
  from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
5
7
 
6
8
  __name = ":common_charts:plot_lines"
7
9
 
8
10
 
11
+ def log_scaling_for_x_y(data_s, x_name, y_names, **kwargs):
12
+ d_s = dict()
13
+ ticks_s = dict()
14
+ tick_labels_s = dict()
15
+ d_s["x"] = data_s.pop(x_name)
16
+ d_s["y"] = []
17
+ for k in y_names:
18
+ d_s["y"].extend(data_s[k])
19
+ for k in ("x", "y"):
20
+ d_s[k], ticks_s[k], tick_labels_s[k] = log_scaling(
21
+ x_ls=d_s[k], log_scale=kwargs[f"{k}_log_scale"],
22
+ ticks=kwargs[f"{k}_ticks"], tick_labels=kwargs[f"{k}_tick_labels"],
23
+ label_formatter=kwargs[f"{k}_label_formatter"]
24
+ )
25
+ temp = d_s.pop("y")
26
+ count = 0
27
+ for k in y_names:
28
+ data_s[k] = temp[count:count + len(data_s[k])]
29
+ count += len(data_s[k])
30
+ data_s[x_name] = d_s["x"]
31
+ return data_s, ticks_s, tick_labels_s
32
+
33
+
9
34
  @COMMON_CHARTS.register(name=__name)
10
- def plot_lines(data_s, title, x_name, x_ticklabels_name=None, output_dir=None, output_path=None, **kwargs):
35
+ def plot_lines(data_s, title, x_name, y_name_ls=None, output_dir=None, output_path=None, **kwargs):
11
36
  """
12
37
  绘制折线图
13
38
 
@@ -16,15 +41,15 @@ def plot_lines(data_s, title, x_name, x_ticklabels_name=None, output_dir=None, o
16
41
  形如 {<data_name>: <data list>, ...} 的字典
17
42
  title: <str> 绘图标题,同时用于保存图片的文件名。
18
43
  x_name: <str> 以哪个 data_name 作为 x 轴。
19
- 其余数据视为需要被绘制的数据点。
44
+ y_name_ls: <list of str> 哪些数据视为需要被绘制的数据点。
45
+ 默认为 None,表示除 x_name 以外的数据都是需要绘制的。
20
46
  例子: data_s={"step":[...], "acc_top1":[...], "acc_top3":[...]}
21
47
  当 x_name="step" 时,将会以 step 为 x 轴绘制 acc_top1 和 acc_top3 的 bar 图。
22
- x_ticklabels_name: <str or None> 若提供则表示 x 轴刻度标签对应的键名,用于替换默认的 x 轴刻度值。
23
48
  output_dir: <str> 图片输出目录。
24
49
  output_path: <str> 图片输出路径。
25
50
  以上两个只需指定一个即可,同时指定时以后者为准。
26
51
  当只有 output_dir 被指定时,将会以 title 作为图片名。
27
- 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
52
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
28
53
  在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
29
54
 
30
55
  其他可选参数:
@@ -36,44 +61,65 @@ def plot_lines(data_s, title, x_name, x_ticklabels_name=None, output_dir=None, o
36
61
  默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
37
62
  后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
38
63
  该参数仅在 output_dir 和 output_path 非 None 时起效。
64
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
65
+ 默认为 False
66
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
67
+ 默认为 True
39
68
  color_ls: <list> 用于绘图的颜色列表,默认根据数据序列个数自动生成。
40
69
  marker_ls: <list of str> 折线图上各数据点的标记。
41
70
  linestyle_ls: <list of str> 线型。
42
71
  默认值为 '-',表示直线。
43
72
  """
44
- line_nums = len(data_s) - 1
73
+ y_names = y_name_ls if y_name_ls else list(i for i in data_s.keys() if i != x_name)
74
+ line_nums = len(y_names)
45
75
  paras = {
46
76
  "dpi": 200,
47
77
  "suffix": ".png",
48
78
  "b_generate_record": False,
79
+ "b_show_plot": False,
80
+ "b_bgr_image": True,
49
81
  "color_ls": generate_color_list(nums=line_nums),
50
82
  "marker_ls": None,
51
83
  "linestyle_ls": '-',
84
+ #
85
+ "x_label": f'{x_name}',
86
+ "y_label": "value",
87
+ "x_log_scale": None,
88
+ "x_ticks": None,
89
+ "x_tick_labels": None,
90
+ "x_label_formatter": None,
91
+ "y_log_scale": None,
92
+ "y_ticks": None,
93
+ "y_tick_labels": None,
94
+ "y_label_formatter": None,
52
95
  }
53
96
  paras.update(kwargs)
54
97
  for k, v in paras.items():
55
98
  if k.endswith("_ls") and not isinstance(v, (list, tuple)):
56
99
  paras[k] = [v] * line_nums
57
100
  assert line_nums == len(paras["color_ls"]) == len(paras["marker_ls"]) == len(paras["linestyle_ls"])
101
+ if "x_ticklabels_name" in paras:
102
+ warnings.warn(f"{__name}: 'x_ticklabels_name' is deprecated, please use 'x_ticks' and 'x_tick_labels' instead.")
58
103
  #
59
104
  _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
60
105
  save_record(_func=plot_lines, _name=__name,
61
106
  _output_path=_output_path if paras["b_generate_record"] else None,
62
107
  **paras)
63
108
  data_s = data_s.copy()
109
+ #
110
+ data_s, ticks_s, tick_labels_s = log_scaling_for_x_y(data_s=data_s, x_name=x_name, y_names=y_names, **paras)
64
111
 
65
112
  plt.clf()
113
+ fig = plt.figure(figsize=(10, 8))
114
+ ax = fig.add_subplot(111)
115
+
66
116
  #
67
117
  x_all_ls = data_s.pop(x_name)
68
- if x_ticklabels_name is not None:
69
- x_ticklabels = data_s.pop(x_ticklabels_name)
70
- assert len(x_all_ls) == len(x_ticklabels)
71
- plt.xticks(x_all_ls, x_ticklabels)
72
118
  data_s, temp = dict(), data_s
73
119
  for k, v_ls in temp.items():
74
120
  y_ls, x_ls = [], []
75
121
  for x, v in zip(x_all_ls, v_ls):
76
- if x is None or v is None:
122
+ if v is None:
77
123
  continue
78
124
  x_ls.append(x)
79
125
  y_ls.append(v)
@@ -81,17 +127,22 @@ def plot_lines(data_s, title, x_name, x_ticklabels_name=None, output_dir=None, o
81
127
  continue
82
128
  data_s[k] = (x_ls, y_ls)
83
129
  #
84
- for i, (k, (x_ls, y_ls)) in enumerate(data_s.items()):
85
- plt.plot(x_ls, y_ls, label=f'{k}', color=paras["color_ls"][i], marker=paras["marker_ls"][i],
86
- linestyle=paras["linestyle_ls"][i])
87
- plt.xlabel(f'{x_name}')
88
- plt.ylabel('value')
89
- plt.title(f'{title}')
130
+ for i, k in enumerate(y_names):
131
+ x_ls, y_ls = data_s[k]
132
+ ax.plot(x_ls, y_ls, label=f'{k}', color=paras["color_ls"][i], marker=paras["marker_ls"][i],
133
+ linestyle=paras["linestyle_ls"][i])
134
+ ax.set_xlabel(paras["x_label"])
135
+ ax.set_ylabel(paras["y_label"])
136
+ ax.set_title(f'{title}')
137
+ for i in ("x", "y",):
138
+ if ticks_s[i] is not None:
139
+ getattr(ax, f'set_{i}ticks')(ticks_s[i])
140
+ getattr(ax, f'set_{i}ticklabels')(tick_labels_s[i])
90
141
  # 显示图例
91
142
  plt.legend()
92
143
 
93
- save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
94
- return _output_path
144
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
145
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
95
146
 
96
147
 
97
148
  if __name__ == '__main__':
@@ -99,9 +150,9 @@ if __name__ == '__main__':
99
150
 
100
151
  plot_lines(
101
152
  data_s={
102
- 'a': [1, 2, 3, 4, 5],
153
+ 'a': [0, 2, 3, 4, 5],
103
154
  'b': [5, 4, 3, 2, 1],
104
155
  'c': [1, 2, 3, 4, 5]},
105
- title='test_plot_lines',
156
+ title='test_plot_lines', y_log_scale=2,
106
157
  x_name='a', output_dir=os.path.join(os.path.dirname(__file__), "temp")
107
158
  )
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
4
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
5
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
6
+ from kevin_toolbox.patches.for_matplotlib.common_charts.plot_lines import log_scaling_for_x_y
7
+
8
+ __name = ":common_charts:plot_mean_std_lines"
9
+
10
+
11
+ @COMMON_CHARTS.register(name=__name)
12
+ def plot_mean_std_lines(data_s, title, x_name, mean_name_ls, std_name_ls, output_dir=None, output_path=None, **kwargs):
13
+ """
14
+ 绘制均值和标准差折线图及其区域填充
15
+
16
+ 参数:
17
+ data_s: <dict> 数据。
18
+ 格式为:{
19
+ x_name: [...],
20
+ "name1_mean": [...], "name1_std": [...],
21
+ "name2_mean": [...], "name2_std": [...],
22
+ ...
23
+ }
24
+ title: <str> 绘图标题,同时用于保存图片的文件名。
25
+ x_name: <str> 以哪个 data_name 作为 x 轴。
26
+ mean_name_ls: <list of str> 哪些名字对应的数据作为均值。
27
+ std_name_ls: <list of str> 哪些名字对应的数据作为标准差。
28
+ 上面两参数要求大小相同,其同一位置表示该均值和标准差作为同一分组进行展示。
29
+ output_dir: <str> 图片输出目录。
30
+ output_path: <str> 图片输出路径。
31
+ 以上两个只需指定一个即可,同时指定时以后者为准。
32
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
33
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
34
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
35
+
36
+ 可选参数:
37
+ dpi, suffix, b_generate_record, b_show_plot, b_bgr_image, color_ls, marker_ls, linestyle_ls 等(参考 plot_lines 的说明)
38
+ """
39
+ assert x_name in data_s
40
+ y_names = set(mean_name_ls).union(set(std_name_ls))
41
+ assert y_names.issubset(data_s.keys())
42
+ assert len(mean_name_ls) == len(std_name_ls)
43
+ line_nums = len(mean_name_ls)
44
+ y_names = list(y_names)
45
+
46
+ paras = {
47
+ "dpi": 200,
48
+ "suffix": ".png",
49
+ "b_generate_record": False,
50
+ "b_show_plot": False,
51
+ "b_bgr_image": True,
52
+ "color_ls": generate_color_list(nums=line_nums),
53
+ "marker_ls": None,
54
+ "linestyle_ls": '-',
55
+ #
56
+ "x_label": f'{x_name}',
57
+ "y_label": "value",
58
+ "x_log_scale": None,
59
+ "x_ticks": None,
60
+ "x_tick_labels": None,
61
+ "x_label_formatter": None,
62
+ "y_log_scale": None,
63
+ "y_ticks": None,
64
+ "y_tick_labels": None,
65
+ "y_label_formatter": None,
66
+ }
67
+ paras.update(kwargs)
68
+ for k, v in paras.items():
69
+ if k.endswith("_ls") and not isinstance(v, (list, tuple)):
70
+ paras[k] = [v] * line_nums
71
+ assert line_nums == len(paras["color_ls"]) == len(paras["marker_ls"]) == len(paras["linestyle_ls"])
72
+
73
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
74
+ save_record(_func=plot_mean_std_lines, _name=__name,
75
+ _output_path=_output_path if paras["b_generate_record"] else None,
76
+ **paras)
77
+ data_s = data_s.copy()
78
+ #
79
+ data_s, ticks_s, tick_labels_s = log_scaling_for_x_y(data_s=data_s, x_name=x_name, y_names=y_names, **paras)
80
+
81
+ plt.clf()
82
+ fig = plt.figure(figsize=(10, 8))
83
+ ax = fig.add_subplot(111)
84
+
85
+ #
86
+ x_all_ls = data_s.pop(x_name)
87
+ for i, (mean_name, std_name) in enumerate(zip(mean_name_ls, std_name_ls)):
88
+ mean_ls, std_ls, x_ls = [], [], []
89
+ for mean, std, x in zip(data_s[mean_name], data_s[std_name], x_all_ls):
90
+ if mean is None or std is None or x is None:
91
+ continue
92
+ mean_ls.append(mean)
93
+ std_ls.append(std)
94
+ x_ls.append(x)
95
+ if len(x_ls) == 0:
96
+ continue
97
+ mean_ls = np.array(mean_ls)
98
+ std_ls = np.array(std_ls)
99
+ ax.plot(x_ls, mean_ls, label=f'{mean_name}', color=paras["color_ls"][i], marker=paras["marker_ls"][i],
100
+ linestyle=paras["linestyle_ls"][i])
101
+ ax.fill_between(x_ls, mean_ls - std_ls, mean_ls + std_ls, color=paras["color_ls"][i], alpha=0.2)
102
+
103
+ ax.set_xlabel(paras["x_label"])
104
+ ax.set_ylabel(paras["y_label"])
105
+ ax.set_title(f'{title}')
106
+ ax.grid(True)
107
+ for i in ("x", "y",):
108
+ if ticks_s[i] is not None:
109
+ getattr(ax, f'set_{i}ticks')(ticks_s[i])
110
+ getattr(ax, f'set_{i}ticklabels')(tick_labels_s[i])
111
+ # 显示图例
112
+ plt.legend()
113
+ plt.tight_layout()
114
+
115
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
116
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
117
+
118
+
119
+ if __name__ == '__main__':
120
+ import os
121
+
122
+ plot_mean_std_lines(data_s={
123
+ 'a': [0.1, 0.5, 1.0, 2.0, 5.0],
124
+ 'model1': [0.3, 0.45, 0.5, 0.55, 0.6],
125
+ 'model1_std': [0.05, 0.07, 0.08, 0.06, 0.04],
126
+ 'model2': [0.25, 0.4, 0.48, 0.52, 0.58],
127
+ 'model2_std': [0.04, 0.06, 0.07, 0.05, 0.03]
128
+ },
129
+ x_name='a',
130
+ mean_name_ls=['model1', 'model2'],
131
+ std_name_ls=['model1_std', 'model2_std'],
132
+ title='test_plot_mean_std_lines',
133
+ output_dir=os.path.join(os.path.dirname(__file__), "temp"),
134
+ b_generate_record=True, b_show_plot=True
135
+ )
@@ -25,7 +25,7 @@ def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None
25
25
  output_path: <str or None> 图片输出路径。
26
26
  以上两个只需指定一个即可,同时指定时以后者为准。
27
27
  当只有 output_dir 被指定时,将会以 title 作为图片名。
28
- 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
28
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
29
29
  在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
30
30
 
31
31
  其他可选参数:
@@ -37,6 +37,10 @@ def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None
37
37
  默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
38
38
  后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
39
39
  该参数仅在 output_dir 和 output_path 非 None 时起效。
40
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
41
+ 默认为 False
42
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
43
+ 默认为 True
40
44
  scatter_size: <int> 散点的大小。
41
45
  默认 5。
42
46
 
@@ -52,6 +56,8 @@ def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None
52
56
  "dpi": 200,
53
57
  "suffix": ".png",
54
58
  "b_generate_record": False,
59
+ "b_show_plot": False,
60
+ "b_bgr_image": True,
55
61
  "scatter_size": 5
56
62
  }
57
63
  paras.update(kwargs)
@@ -83,8 +89,8 @@ def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None
83
89
  markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
84
90
  ])
85
91
 
86
- save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
87
- return _output_path
92
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
93
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
88
94
 
89
95
 
90
96
  if __name__ == '__main__':
@@ -25,7 +25,7 @@ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=No
25
25
  output_path: <str or None> 图片输出路径。
26
26
  以上两个只需指定一个即可,同时指定时以后者为准。
27
27
  当只有 output_dir 被指定时,将会以 title 作为图片名。
28
- 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
28
+ 若同时不指定,则直接以 np.ndarray 形式返回图片,不进行保存。
29
29
  在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
30
30
  cate_color_s: <dict or None> 类别-颜色映射字典,将 cate_name 中的每个类别映射到具体颜色。
31
31
  默认为 None,自动生成颜色列表。
@@ -39,6 +39,10 @@ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=No
39
39
  默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
40
40
  后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
41
41
  该参数仅在 output_dir 和 output_path 非 None 时起效。
42
+ b_show_plot: <boolean> 是否使用 plt.show() 展示图片。
43
+ 默认为 False
44
+ b_bgr_image: <boolean> 以 np.ndarray 形式返回图片时,图片的channel顺序是采用 bgr 还是 rgb。
45
+ 默认为 True
42
46
  diag_kind: <str> 对角线图表的类型。
43
47
  支持:
44
48
  - "hist"(直方图)
@@ -49,6 +53,8 @@ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=No
49
53
  "dpi": 200,
50
54
  "suffix": ".png",
51
55
  "b_generate_record": False,
56
+ "b_show_plot": False,
57
+ "b_bgr_image": True,
52
58
  "diag_kind": "kde" # 设置对角线图直方图/密度图 {‘hist’, ‘kde’}
53
59
  }
54
60
  assert cate_name in data_s and len(set(x_name_ls).difference(set(data_s.keys()))) == 0
@@ -75,8 +81,8 @@ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=No
75
81
  plt.subplots_adjust(top=0.95)
76
82
  plt.suptitle(f'{title}', y=0.98, x=0.47)
77
83
 
78
- save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
79
- return _output_path
84
+ return save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"],
85
+ b_bgr_image=paras["b_bgr_image"], b_show_plot=paras["b_show_plot"])
80
86
 
81
87
 
82
88
  if __name__ == '__main__':
@@ -1,3 +1,4 @@
1
1
  from .get_output_path import get_output_path
2
2
  from .save_plot import save_plot
3
3
  from .save_record import save_record
4
+ from .log_scaling import log_scaling
@@ -0,0 +1,69 @@
1
+ import numpy as np
2
+
3
+
4
+ def log_scaling(x_ls, log_scale=None, ticks=None, tick_labels=None, b_replace_nan_inf_with_none=True, label_formatter=None):
5
+ original_x_ls = None
6
+ if isinstance(x_ls, np.ndarray):
7
+ original_x_ls = x_ls
8
+ x_ls = x_ls.reshape(-1)
9
+ label_formatter = label_formatter or (lambda x: f"{x:.2e}")
10
+ raw_x_ls, x_ls = x_ls, []
11
+ none_idx_ls = []
12
+ for idx, i in enumerate(raw_x_ls):
13
+ if i is None or np.isnan(i) or np.isinf(i):
14
+ none_idx_ls.append(idx)
15
+ else:
16
+ x_ls.append(i)
17
+ x_ls = np.asarray(x_ls)
18
+ if isinstance(ticks, int):
19
+ ticks = np.linspace(np.min(x_ls), np.max(x_ls), ticks
20
+ ) if log_scale is None else np.logspace(
21
+ np.log(np.min(x_ls)) / np.log(log_scale), np.log(np.max(x_ls)) / np.log(log_scale), ticks, base=log_scale)
22
+ #
23
+ if log_scale is not None:
24
+ assert log_scale > 0 and np.min(x_ls) > 0
25
+ if ticks is None:
26
+ ticks = sorted(list(set(x_ls.reshape(-1).tolist())))
27
+ if tick_labels is None:
28
+ tick_labels = [label_formatter(j) for j in ticks]
29
+ assert len(ticks) == len(tick_labels)
30
+ ticks = [np.log(j) / np.log(log_scale) for j in ticks]
31
+ x_ls = np.log(x_ls) / np.log(log_scale)
32
+ else:
33
+ if ticks is None:
34
+ from matplotlib.ticker import MaxNLocator
35
+ locator = MaxNLocator()
36
+ ticks = locator.tick_values(np.min(x_ls), np.max(x_ls))
37
+ if ticks is not None and tick_labels is None:
38
+ tick_labels = [label_formatter(j) for j in ticks]
39
+
40
+ if none_idx_ls:
41
+ x_ls = x_ls.tolist()
42
+ for idx in none_idx_ls:
43
+ if b_replace_nan_inf_with_none:
44
+ x_ls.insert(idx, None)
45
+ else:
46
+ x_ls.insert(idx, raw_x_ls[idx])
47
+
48
+ if original_x_ls is not None:
49
+ x_ls = np.asarray(x_ls, dtype=original_x_ls.dtype).reshape(original_x_ls.shape)
50
+
51
+ return x_ls, ticks, tick_labels
52
+
53
+
54
+ if __name__ == "__main__":
55
+ x_ls_ = np.linspace(0.1, 100, 10)
56
+
57
+ out_ls_, ticks_, tick_labels_ = log_scaling(x_ls_, log_scale=10, ticks=5)
58
+ print(out_ls_)
59
+ print(ticks_)
60
+ print(tick_labels_)
61
+
62
+ x_ls_[2] = np.inf
63
+ x_ls_[3] = -np.inf
64
+ x_ls_[4] = -np.nan
65
+ print(x_ls_)
66
+ out_ls_, ticks_, tick_labels_ = log_scaling(x_ls_, log_scale=10, ticks=5)
67
+ print(out_ls_)
68
+ print(ticks_)
69
+ print(tick_labels_)
@@ -1,12 +1,28 @@
1
1
  import os
2
+ import io
3
+ import numpy as np
4
+ from PIL import Image
2
5
 
3
6
 
4
- def save_plot(plt, output_path, dpi=200, suffix=".png", **kwargs):
7
+ def save_plot(plt, output_path, dpi=200, suffix=".png", b_bgr_image=False, b_show_plot=False, **kwargs):
5
8
  assert suffix in [".png", ".jpg", ".bmp"]
6
9
 
7
- if output_path is None:
10
+ if b_show_plot:
8
11
  plt.show()
12
+
13
+ if output_path is None:
14
+ buf = io.BytesIO()
15
+ plt.savefig(buf, format=suffix.split(".")[-1].lower(), dpi=dpi)
16
+ buf.seek(0)
17
+ image = Image.open(buf).convert("RGB")
18
+ image = np.array(image)
19
+ buf.close()
20
+ plt.close()
21
+ if b_bgr_image:
22
+ image = image[..., ::-1]
23
+ return image
9
24
  else:
10
25
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
11
26
  plt.savefig(output_path, dpi=dpi)
12
- plt.close()
27
+ plt.close()
28
+ return output_path
@@ -2,7 +2,7 @@ import inspect
2
2
  from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
3
3
 
4
4
 
5
- def save_record(_name, _output_path,_func=None, **kwargs):
5
+ def save_record(_name, _output_path, _func=None, **kwargs):
6
6
  if _output_path is None:
7
7
  return None
8
8
 
@@ -1 +1,2 @@
1
1
  from . import linalg
2
+ from . import random
@@ -26,7 +26,10 @@ def softmax(x, axis=-1, temperature=None, b_use_log_over_x=False):
26
26
  res = np.where(x == np.max(x, axis=axis), 1, res)
27
27
  elif b_use_log_over_x:
28
28
  # softmax(log(x))
29
- res = x ** (1 / temperature)
29
+ if temperature is not None:
30
+ res = x ** (1 / temperature)
31
+ else:
32
+ res = x
30
33
  else:
31
34
  # softmax(x)
32
35
  # 为了数值稳定,减去最大值
@@ -0,0 +1,77 @@
1
+ Metadata-Version: 2.1
2
+ Name: kevin-toolbox-dev
3
+ Version: 1.4.13
4
+ Summary: 一个常用的工具代码包集合
5
+ Home-page: https://github.com/cantbeblank96/kevin_toolbox
6
+ Download-URL: https://github.com/username/your-package/archive/refs/tags/v1.0.0.tar.gz
7
+ Author: kevin hsu
8
+ Author-email: xukaiming1996@163.com
9
+ License: MIT
10
+ Keywords: mathematics,pytorch,numpy,machine-learning,algorithm
11
+ Platform: UNKNOWN
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python
14
+ Classifier: Programming Language :: Python :: 3
15
+ Requires-Python: >=3.6
16
+ Description-Content-Type: text/markdown
17
+ Requires-Dist: torch (>=1.2.0)
18
+ Requires-Dist: numpy (>=1.19.0)
19
+ Provides-Extra: plot
20
+ Requires-Dist: matplotlib (>=3.0) ; extra == 'plot'
21
+ Provides-Extra: rest
22
+ Requires-Dist: pytest (>=6.2.5) ; extra == 'rest'
23
+ Requires-Dist: line-profiler (>=3.5) ; extra == 'rest'
24
+
25
+ # kevin_toolbox
26
+
27
+ 一个通用的工具代码包集合
28
+
29
+
30
+
31
+ 环境要求
32
+
33
+ ```shell
34
+ numpy>=1.19
35
+ pytorch>=1.2
36
+ ```
37
+
38
+ 安装方法:
39
+
40
+ ```shell
41
+ pip install kevin-toolbox --no-dependencies
42
+ ```
43
+
44
+
45
+
46
+ [项目地址 Repo](https://github.com/cantbeblank96/kevin_toolbox)
47
+
48
+ [使用指南 User_Guide](./notes/User_Guide.md)
49
+
50
+ [免责声明 Disclaimer](./notes/Disclaimer.md)
51
+
52
+ [版本更新记录](./notes/Release_Record.md):
53
+
54
+ - v 1.4.13 (2025-07-21)【bug fix】【new feature】
55
+ - data_flow.file.markdown
56
+ - modify find_tables(),完善读取表格函数,支持更多的表格格式,包括以梅花线作为标题栏分割线,表格最左侧和最右侧分割线省略等情况。
57
+ - nested_dict_list.serializer
58
+ - modify read(),支持在读取时通过参数 b_keep_identical_relations 对 record.json 中的同名参数进行覆盖。
59
+ - computer_science.algorithm
60
+ - redirector
61
+ - 【bug fix】fix bug in _randomly_idx_redirector() for Redirectable_Sequence_Fetcher,改正了 rng.randint(low, high) 中参数 high 的设置。
62
+ - pareto_front
63
+ - modify get_pareto_points_idx(),支持参数 directions 只输入单个值来表示所有方向都使用该值。
64
+ - sampler
65
+ - 【new feature】add Recent_Sampler,最近采样器:始终保留最近加入的 capacity 个样本。
66
+ - patches.for_matplotlib
67
+ - common_charts.utils
68
+ - modify .save_plot(),将原来在 output_path 为 None 时使用 plt.show() 展示图像的行为改为返回 np.array 形式的图像,并支持通过参数 b_show_plot 来单独控制是否展示图像。
69
+ - 【new feature】add log_scaling(),用于处理坐标系变换。
70
+ - common_charts
71
+ - 【new feature】add plot_3d(),绘制3D图,支持:散点图、三角剖分曲面及其平滑版本。
72
+ - 【new feature】add plot_contour(),绘制等高线图。
73
+ - 【new feature】add plot_mean_std_lines(),绘制均值和标准差折线图及其区域填充。
74
+ - 【new feature】add plot_2d_matrix(),计算并绘制混淆矩阵。
75
+
76
+
77
+