kevin-toolbox-dev 1.4.7__py3-none-any.whl → 1.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/{developing → computer_science/algorithm}/decorator/__init__.py +2 -1
  3. kevin_toolbox/computer_science/algorithm/decorator/retry.py +62 -0
  4. kevin_toolbox/computer_science/algorithm/registration/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/registration/serializer_for_registry_execution.py +82 -0
  6. kevin_toolbox/data_flow/core/cache/cache_manager_for_iterator.py +1 -1
  7. kevin_toolbox/data_flow/file/json_/write_json.py +2 -1
  8. kevin_toolbox/env_info/check_version_and_update.py +0 -1
  9. kevin_toolbox/env_info/variable_/env_vars_parser.py +17 -2
  10. kevin_toolbox/nested_dict_list/copy_.py +4 -2
  11. kevin_toolbox/nested_dict_list/get_nodes.py +4 -2
  12. kevin_toolbox/nested_dict_list/serializer/variable.py +14 -2
  13. kevin_toolbox/nested_dict_list/serializer/write.py +2 -0
  14. kevin_toolbox/nested_dict_list/traverse.py +75 -21
  15. kevin_toolbox/nested_dict_list/value_parser/replace_identical_with_reference.py +1 -4
  16. kevin_toolbox/network/__init__.py +10 -0
  17. kevin_toolbox/network/download_file.py +120 -0
  18. kevin_toolbox/network/fetch_content.py +55 -0
  19. kevin_toolbox/network/fetch_metadata.py +64 -0
  20. kevin_toolbox/network/get_response.py +50 -0
  21. kevin_toolbox/network/variable.py +6 -0
  22. kevin_toolbox/patches/for_logging/build_logger.py +1 -1
  23. kevin_toolbox/patches/for_matplotlib/color/convert_format.py +0 -2
  24. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +45 -0
  25. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +63 -22
  26. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +67 -20
  27. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +66 -17
  28. kevin_toolbox/patches/for_matplotlib/common_charts/plot_from_record.py +21 -0
  29. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +59 -19
  30. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +61 -12
  31. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +57 -14
  32. kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +3 -0
  33. kevin_toolbox/patches/for_matplotlib/common_charts/utils/get_output_path.py +15 -0
  34. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +12 -0
  35. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_record.py +34 -0
  36. kevin_toolbox/patches/for_matplotlib/variable.py +20 -0
  37. kevin_toolbox/patches/for_numpy/linalg/softmax.py +4 -1
  38. kevin_toolbox_dev-1.4.9.dist-info/METADATA +75 -0
  39. {kevin_toolbox_dev-1.4.7.dist-info → kevin_toolbox_dev-1.4.9.dist-info}/RECORD +42 -28
  40. kevin_toolbox_dev-1.4.7.dist-info/METADATA +0 -69
  41. /kevin_toolbox/{developing → computer_science/algorithm}/decorator/restore_original_work_path.py +0 -0
  42. {kevin_toolbox_dev-1.4.7.dist-info → kevin_toolbox_dev-1.4.9.dist-info}/WHEEL +0 -0
  43. {kevin_toolbox_dev-1.4.7.dist-info → kevin_toolbox_dev-1.4.9.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,51 @@
1
- import os
2
- import copy
3
1
  import matplotlib.pyplot as plt
4
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
5
2
  from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
3
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
4
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
6
5
 
6
+ __name = ":common_charts:plot_lines"
7
7
 
8
- def plot_lines(data_s, title, x_name, x_ticklabels_name=None, output_dir=None, **kwargs):
9
- data_s = copy.copy(data_s)
8
+
9
+ @COMMON_CHARTS.register(name=__name)
10
+ def plot_lines(data_s, title, x_name, x_ticklabels_name=None, output_dir=None, output_path=None, **kwargs):
11
+ """
12
+ 绘制折线图
13
+
14
+ 参数:
15
+ data_s: <dict> 数据。
16
+ 形如 {<data_name>: <data list>, ...} 的字典
17
+ title: <str> 绘图标题,同时用于保存图片的文件名。
18
+ x_name: <str> 以哪个 data_name 作为 x 轴。
19
+ 其余数据视为需要被绘制的数据点。
20
+ 例子: data_s={"step":[...], "acc_top1":[...], "acc_top3":[...]}
21
+ 当 x_name="step" 时,将会以 step 为 x 轴绘制 acc_top1 和 acc_top3 的 bar 图。
22
+ x_ticklabels_name: <str or None> 若提供则表示 x 轴刻度标签对应的键名,用于替换默认的 x 轴刻度值。
23
+ output_dir: <str> 图片输出目录。
24
+ output_path: <str> 图片输出路径。
25
+ 以上两个只需指定一个即可,同时指定时以后者为准。
26
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
27
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
28
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
29
+
30
+ 其他可选参数:
31
+ dpi: <int> 保存图像的分辨率。
32
+ 默认为 200。
33
+ suffix: <str> 图片保存后缀。
34
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
35
+ b_generate_record: <boolean> 是否保存函数参数为档案。
36
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
37
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
38
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
39
+ color_ls: <list> 用于绘图的颜色列表,默认根据数据序列个数自动生成。
40
+ marker_ls: <list of str> 折线图上各数据点的标记。
41
+ linestyle_ls: <list of str> 线型。
42
+ 默认值为 '-',表示直线。
43
+ """
10
44
  line_nums = len(data_s) - 1
11
45
  paras = {
12
46
  "dpi": 200,
47
+ "suffix": ".png",
48
+ "b_generate_record": False,
13
49
  "color_ls": generate_color_list(nums=line_nums),
14
50
  "marker_ls": None,
15
51
  "linestyle_ls": '-',
@@ -19,6 +55,12 @@ def plot_lines(data_s, title, x_name, x_ticklabels_name=None, output_dir=None, *
19
55
  if k.endswith("_ls") and not isinstance(v, (list, tuple)):
20
56
  paras[k] = [v] * line_nums
21
57
  assert line_nums == len(paras["color_ls"]) == len(paras["marker_ls"]) == len(paras["linestyle_ls"])
58
+ #
59
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
60
+ save_record(_func=plot_lines, _name=__name,
61
+ _output_path=_output_path if paras["b_generate_record"] else None,
62
+ **paras)
63
+ data_s = data_s.copy()
22
64
 
23
65
  plt.clf()
24
66
  #
@@ -48,20 +90,18 @@ def plot_lines(data_s, title, x_name, x_ticklabels_name=None, output_dir=None, *
48
90
  # 显示图例
49
91
  plt.legend()
50
92
 
51
- if output_dir is None:
52
- plt.show()
53
- return None
54
- else:
55
- # 对非法字符进行替换
56
- os.makedirs(output_dir, exist_ok=True)
57
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
58
- plt.savefig(output_path, dpi=paras["dpi"])
59
- return output_path
93
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
94
+ return _output_path
60
95
 
61
96
 
62
97
  if __name__ == '__main__':
63
- plot_lines(data_s={
64
- 'a': [1, 2, 3, 4, 5],
65
- 'b': [5, 4, 3, 2, 1],
66
- 'c': [1, 2, 3, 4, 5]},
67
- title='test', x_name='a', output_dir=os.path.join(os.path.dirname(__file__), "temp"))
98
+ import os
99
+
100
+ plot_lines(
101
+ data_s={
102
+ 'a': [1, 2, 3, 4, 5],
103
+ 'b': [5, 4, 3, 2, 1],
104
+ 'c': [1, 2, 3, 4, 5]},
105
+ title='test_plot_lines',
106
+ x_name='a', output_dir=os.path.join(os.path.dirname(__file__), "temp")
107
+ )
@@ -1,15 +1,65 @@
1
- import os
2
1
  import matplotlib.pyplot as plt
3
2
  from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
4
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
3
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
4
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
5
5
 
6
+ __name = ":common_charts:plot_scatters"
6
7
 
7
- def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None, **kwargs):
8
+
9
+ @COMMON_CHARTS.register(name=__name)
10
+ def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None, output_path=None, **kwargs):
11
+ """
12
+ 绘制散点图
13
+ 不同类别的数据点
14
+
15
+ 参数:
16
+ data_s: <dict>
17
+ 数据字典,其中每个键对应一个数据列表。必须包含 x_name 和 y_name 对应的键,
18
+ 如果指定了 cate_name,还需要包含 cate_name 对应的键,用于按类别对散点图上数据点着色。
19
+ title: <str>
20
+ 图形标题,同时用于生成保存图像时的文件名(标题中的非法字符会被替换)。
21
+ x_name: <str> 以哪个 data_name 作为数据点的 x 轴。
22
+ y_name: <str> 以哪个 data_name 作为数据点的 y 轴。
23
+ cate_name: <str> 以哪个 data_name 作为数据点的类别。
24
+ output_dir: <str or None> 图片输出目录。
25
+ output_path: <str or None> 图片输出路径。
26
+ 以上两个只需指定一个即可,同时指定时以后者为准。
27
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
28
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
29
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
30
+
31
+ 其他可选参数:
32
+ dpi: <int> 保存图像的分辨率。
33
+ 默认为 200。
34
+ suffix: <str> 图片保存后缀。
35
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
36
+ b_generate_record: <boolean> 是否保存函数参数为档案。
37
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
38
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
39
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
40
+ scatter_size: <int> 散点的大小。
41
+ 默认 5。
42
+
43
+ 示例:
44
+ >>> data = {
45
+ ... "age": [25, 30, 22, 40],
46
+ ... "income": [50000, 60000, 45000, 80000],
47
+ ... "gender": ["M", "F", "M", "F"]
48
+ ... }
49
+ >>> path = plot_scatters(data, "Age vs Income", "age", "income", cate_name="gender", output_dir="./plots")
50
+ """
8
51
  paras = {
9
52
  "dpi": 200,
53
+ "suffix": ".png",
54
+ "b_generate_record": False,
10
55
  "scatter_size": 5
11
56
  }
12
57
  paras.update(kwargs)
58
+ #
59
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
60
+ save_record(_func=plot_scatters, _name=__name,
61
+ _output_path=_output_path if paras["b_generate_record"] else None,
62
+ **paras)
13
63
 
14
64
  plt.clf()
15
65
  #
@@ -33,21 +83,20 @@ def plot_scatters(data_s, title, x_name, y_name, cate_name=None, output_dir=None
33
83
  markersize=min(paras["scatter_size"], 5)) for i, j in color_s.items()
34
84
  ])
35
85
 
36
- if output_dir is None:
37
- plt.show()
38
- return None
39
- else:
40
- os.makedirs(output_dir, exist_ok=True)
41
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
42
- plt.savefig(output_path, dpi=paras["dpi"])
43
- return output_path
86
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
87
+ return _output_path
44
88
 
45
89
 
46
90
  if __name__ == '__main__':
91
+ import os
92
+
47
93
  data_s_ = dict(
48
94
  x=[1, 2, 3, 4, 5],
49
95
  y=[2, 4, 6, 8, 10],
50
96
  categories=['A', 'B', 'A', 'B', 'A']
51
97
  )
52
98
 
53
- plot_scatters(data_s=data_s_, title='test', x_name='x', y_name='y', cate_name='categories')
99
+ plot_scatters(
100
+ data_s=data_s_, title='test_plot_scatters', x_name='x', y_name='y', cate_name='categories',
101
+ output_dir=os.path.join(os.path.dirname(__file__), "temp")
102
+ )
@@ -1,14 +1,54 @@
1
- import os
2
1
  import pandas as pd
3
2
  import seaborn as sns
4
3
  import matplotlib.pyplot as plt
5
4
  from kevin_toolbox.patches.for_matplotlib.color import generate_color_list
6
- from kevin_toolbox.patches.for_os.path import replace_illegal_chars
5
+ from kevin_toolbox.patches.for_matplotlib.common_charts.utils import save_plot, save_record, get_output_path
6
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
7
7
 
8
+ __name = ":common_charts:plot_scatters_matrix"
8
9
 
9
- def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=None, cate_color_s=None, **kwargs):
10
+
11
+ @COMMON_CHARTS.register(name=__name)
12
+ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=None, output_path=None, cate_color_s=None,
13
+ **kwargs):
14
+ """
15
+ 绘制散点图矩阵
16
+ 该函数用于展示多个变量之间的两两关系。
17
+
18
+ 参数:
19
+ data_s: <dict> 数据。
20
+ 形如 {<data_name>: <data list>, ...} 的字典
21
+ title: <str> 绘图标题。
22
+ x_name_ls: <list of str> 指定用于需要绘制哪些变量的两两关系。
23
+ cate_name: <str or None> 使用哪个 data_name 对应的取值作为类别。
24
+ output_dir: <str or None> 图片输出目录。
25
+ output_path: <str or None> 图片输出路径。
26
+ 以上两个只需指定一个即可,同时指定时以后者为准。
27
+ 当只有 output_dir 被指定时,将会以 title 作为图片名。
28
+ 若同时不指定,则直接调用 plt.show() 显示图像,而不进行保存。
29
+ 在保存为文件时,若文件名中存在路径不适宜的非法字符将会被进行替换。
30
+ cate_color_s: <dict or None> 类别-颜色映射字典,将 cate_name 中的每个类别映射到具体颜色。
31
+ 默认为 None,自动生成颜色列表。
32
+
33
+ 其他可选参数:
34
+ dpi: <int> 保存图像的分辨率。
35
+ 默认为 200。
36
+ suffix: <str> 图片保存后缀。
37
+ 目前支持的取值有 ".png", ".jpg", ".bmp",默认为第一个。
38
+ b_generate_record: <boolean> 是否保存函数参数为档案。
39
+ 默认为 False,当设置为 True 时将会把函数参数保存成 [output_path].record.tar。
40
+ 后续可以使用 plot_from_record() 函数或者 Serializer_for_Registry_Execution 读取该档案,并进行修改和重新绘制。
41
+ 该参数仅在 output_dir 和 output_path 非 None 时起效。
42
+ diag_kind: <str> 对角线图表的类型。
43
+ 支持:
44
+ - "hist"(直方图)
45
+ - "kde"(核密度图)
46
+ 默认为 "kde"。
47
+ """
10
48
  paras = {
11
49
  "dpi": 200,
50
+ "suffix": ".png",
51
+ "b_generate_record": False,
12
52
  "diag_kind": "kde" # 设置对角线图直方图/密度图 {‘hist’, ‘kde’}
13
53
  }
14
54
  assert cate_name in data_s and len(set(x_name_ls).difference(set(data_s.keys()))) == 0
@@ -17,6 +57,11 @@ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=No
17
57
  cate_color_s = {k: v for k, v in zip(temp, generate_color_list(len(temp)))}
18
58
  assert set(cate_color_s.keys()) == set(data_s[cate_name])
19
59
  paras.update(kwargs)
60
+ #
61
+ _output_path = get_output_path(output_path=output_path, output_dir=output_dir, title=title, **kwargs)
62
+ save_record(_func=plot_scatters_matrix, _name=__name,
63
+ _output_path=_output_path if paras["b_generate_record"] else None,
64
+ **paras)
20
65
 
21
66
  plt.clf()
22
67
  # 使用seaborn绘制散点图矩阵
@@ -29,19 +74,14 @@ def plot_scatters_matrix(data_s, title, x_name_ls, cate_name=None, output_dir=No
29
74
  #
30
75
  plt.subplots_adjust(top=0.95)
31
76
  plt.suptitle(f'{title}', y=0.98, x=0.47)
32
- # g.fig.suptitle(f'{title}', y=1.05)
33
77
 
34
- if output_dir is None:
35
- plt.show()
36
- return None
37
- else:
38
- os.makedirs(output_dir, exist_ok=True)
39
- output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}.png')
40
- plt.savefig(output_path, dpi=paras["dpi"])
41
- return output_path
78
+ save_plot(plt=plt, output_path=_output_path, dpi=paras["dpi"], suffix=paras["suffix"])
79
+ return _output_path
42
80
 
43
81
 
44
82
  if __name__ == '__main__':
83
+ import os
84
+
45
85
  data_s_ = dict(
46
86
  x=[1, 2, 3, 4, 5],
47
87
  y=[2, 4, 6, 8, 10],
@@ -50,5 +90,8 @@ if __name__ == '__main__':
50
90
  title='test',
51
91
  )
52
92
 
53
- plot_scatters_matrix(data_s=data_s_, title='test', x_name_ls=['y', 'x', 'z'], cate_name='categories',
54
- cate_color_s={'A': 'red', 'B': 'blue'})
93
+ plot_scatters_matrix(
94
+ data_s=data_s_, title='test_plot_scatters_matrix', x_name_ls=['y', 'x', 'z'], cate_name='categories',
95
+ cate_color_s={'A': 'red', 'B': 'blue'},
96
+ output_dir=os.path.join(os.path.dirname(__file__), "temp")
97
+ )
@@ -0,0 +1,3 @@
1
+ from .get_output_path import get_output_path
2
+ from .save_plot import save_plot
3
+ from .save_record import save_record
@@ -0,0 +1,15 @@
1
+ import os
2
+ from kevin_toolbox.patches.for_os.path import replace_illegal_chars
3
+
4
+
5
+ def get_output_path(output_path=None, output_dir=None, title=None, suffix=".png", **kwargs):
6
+ if output_path is None:
7
+ if output_dir is None:
8
+ output_path = None
9
+ else:
10
+ assert title is not None
11
+ assert suffix in [".png", ".jpg", ".bmp"]
12
+ output_path = os.path.join(output_dir, f'{replace_illegal_chars(title)}{suffix}')
13
+ else:
14
+ output_path = os.path.join(os.path.dirname(output_path), replace_illegal_chars(os.path.basename(output_path)))
15
+ return output_path
@@ -0,0 +1,12 @@
1
+ import os
2
+
3
+
4
+ def save_plot(plt, output_path, dpi=200, suffix=".png", **kwargs):
5
+ assert suffix in [".png", ".jpg", ".bmp"]
6
+
7
+ if output_path is None:
8
+ plt.show()
9
+ else:
10
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
11
+ plt.savefig(output_path, dpi=dpi)
12
+ plt.close()
@@ -0,0 +1,34 @@
1
+ import inspect
2
+ from kevin_toolbox.patches.for_matplotlib.variable import COMMON_CHARTS
3
+
4
+
5
+ def save_record(_name, _output_path,_func=None, **kwargs):
6
+ if _output_path is None:
7
+ return None
8
+
9
+ # 获取函数的参数列表
10
+ frame = inspect.currentframe().f_back
11
+ if _func is None:
12
+ func_name = inspect.getframeinfo(frame).function
13
+ _func = globals()[func_name]
14
+ sig = inspect.signature(_func)
15
+ # 获取参数值
16
+ args_info = inspect.getargvalues(frame)
17
+ arg_names = args_info.args
18
+ arg_values = args_info.locals
19
+ # 将参数值映射到函数签名中
20
+ kwargs_raw = {name: arg_values[name] for name in arg_names}
21
+ for param_name, param in sig.parameters.items():
22
+ if param.kind == param.VAR_KEYWORD:
23
+ kwargs_raw.update(arg_values[param_name])
24
+ elif param.kind == param.VAR_POSITIONAL:
25
+ kwargs_raw[param_name] = arg_values[param_name]
26
+
27
+ kwargs_raw.update(kwargs)
28
+
29
+ from kevin_toolbox.computer_science.algorithm.registration import Serializer_for_Registry_Execution
30
+ serializer = Serializer_for_Registry_Execution()
31
+ file_path = serializer.record_name(
32
+ _name=_name, _registry=COMMON_CHARTS
33
+ ).record_paras(**kwargs_raw).save(_output_path + ".record", b_pack_into_tar=True, b_allow_overwrite=True)
34
+ return file_path
@@ -0,0 +1,20 @@
1
+ import os
2
+ from kevin_toolbox.computer_science.algorithm.registration import Registry
3
+
4
+ COMMON_CHARTS = Registry(uid="COMMON_CHARTS")
5
+
6
+ # 导入时的默认过滤规则
7
+ ignore_s = [
8
+ {
9
+ "func": lambda _, __, path: os.path.basename(path) in ["temp", "test", "__pycache__",
10
+ "_old_version"],
11
+ "scope": ["root", "dirs"]
12
+ },
13
+ ]
14
+
15
+ # 从 kevin_toolbox/patches/for_matplotlib/common_charts 下收集被注册的方法
16
+ COMMON_CHARTS.collect_from_paths(
17
+ path_ls=[os.path.join(os.path.dirname(__file__), "common_charts"), ],
18
+ ignore_s=ignore_s,
19
+ b_execute_now=False
20
+ )
@@ -30,7 +30,7 @@ def softmax(x, axis=-1, temperature=None, b_use_log_over_x=False):
30
30
  else:
31
31
  # softmax(x)
32
32
  # 为了数值稳定,减去最大值
33
- x = x - np.max(x)
33
+ x = x - np.max(x, axis=axis, keepdims=True)
34
34
  #
35
35
  if temperature is not None:
36
36
  assert temperature > 0
@@ -46,3 +46,6 @@ if __name__ == '__main__':
46
46
  print(softmax(np.asarray([[[0], [0.1]]]), temperature=0.00001, axis=1))
47
47
  print(softmax(np.asarray([[[0], [0.1]]]), temperature=0, axis=1))
48
48
  print(softmax([0, 1, 2], temperature=0.1))
49
+ print(softmax([[5.0000e-01, 5.0000e-01],
50
+ [7.0000e-01, 3.0000e-01],
51
+ [0.0000e+00, 1.0000e+03]], axis=-1, temperature=None))
@@ -0,0 +1,75 @@
1
+ Metadata-Version: 2.1
2
+ Name: kevin-toolbox-dev
3
+ Version: 1.4.9
4
+ Summary: 一个常用的工具代码包集合
5
+ Home-page: https://github.com/cantbeblank96/kevin_toolbox
6
+ Download-URL: https://github.com/username/your-package/archive/refs/tags/v1.0.0.tar.gz
7
+ Author: kevin hsu
8
+ Author-email: xukaiming1996@163.com
9
+ License: MIT
10
+ Keywords: mathematics,pytorch,numpy,machine-learning,algorithm
11
+ Platform: UNKNOWN
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python
14
+ Classifier: Programming Language :: Python :: 3
15
+ Requires-Python: >=3.6
16
+ Description-Content-Type: text/markdown
17
+ Requires-Dist: torch (>=1.2.0)
18
+ Requires-Dist: numpy (>=1.19.0)
19
+ Provides-Extra: plot
20
+ Requires-Dist: matplotlib (>=3.0) ; extra == 'plot'
21
+ Provides-Extra: rest
22
+ Requires-Dist: pytest (>=6.2.5) ; extra == 'rest'
23
+ Requires-Dist: line-profiler (>=3.5) ; extra == 'rest'
24
+
25
+ # kevin_toolbox
26
+
27
+ 一个通用的工具代码包集合
28
+
29
+
30
+
31
+ 环境要求
32
+
33
+ ```shell
34
+ numpy>=1.19
35
+ pytorch>=1.2
36
+ ```
37
+
38
+ 安装方法:
39
+
40
+ ```shell
41
+ pip install kevin-toolbox --no-dependencies
42
+ ```
43
+
44
+
45
+
46
+ [项目地址 Repo](https://github.com/cantbeblank96/kevin_toolbox)
47
+
48
+ [使用指南 User_Guide](./notes/User_Guide.md)
49
+
50
+ [免责声明 Disclaimer](./notes/Disclaimer.md)
51
+
52
+ [版本更新记录](./notes/Release_Record.md):
53
+
54
+ - v 1.4.9 (2025-03-27)【new feature】【bug fix】
55
+
56
+ - patches.for_numpy.linalg
57
+ - 【bug fix】fix bug in softmax(),修改 33 行,从减去全局最大值改为减去各个分组内部的最大值,避免全局最大值过大导致某些分组全体数值过小导致计算溢出。
58
+ - patches.for_matplotlib.common_charts.utils
59
+ - modify save_plot(),在最后增加 plt.close() 用于及时销毁已使用完的画布,避免不必要的内存占用。
60
+ - nested_dict_list
61
+ - 【new feature】modify traverse(),增加以下参数以更加精确地控制遍历时的行为:
62
+ - b_skip_repeated_non_leaf_node: 是否跳过重复的非叶节点。
63
+ - 何为重复?在内存中的id相同。
64
+ - 默认为 None,此时将根据 action_mode 的来决定:
65
+ - 对于会对节点进行修改的模式,比如 "remove" 和 "replace",将设为 True,以避免预期外的重复转换和替换。
66
+ - 对于不会修改节点内容的模式,比如 "skip",将设为 False。
67
+ - cond_for_repeated_leaf_to_skip:函数列表。在叶节点位置上,遇到满足其中某个条件的重复的元素时需要跳过。
68
+ - 同步修改内部使用了 traverse() 的 get_nodes() 和 copy_() 等函数。
69
+ - 新增了对应的测试用例。
70
+ - data_flow.file.json_
71
+ - 【bug fix】fix bug in write()。
72
+ - bug 归因:在 json_.write() 中通过使用 ndl.traverse() 来找出待转换的元素并进行转换,但是在 v1.4.8 前,该函数默认不会跳过重复(在内存中的id相同)出现的内容。由于该内容的不同引用实际上指向的是同一个,因此对这些引用的分别多次操作实际上就是对该内容进行了多次操作。
73
+ - bug 解决:在后续 v1.4.9 中为 ndl.traverse() 新增了 b_skip_repeated_non_leaf_node 用于控制是否需要跳过重复的引用。我们只需要在使用该函数时,令参数 b_skip_repeated_non_leaf_node=True即可。
74
+
75
+