kevin-toolbox-dev 1.3.1__py3-none-any.whl → 1.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kevin_toolbox/__init__.py +2 -2
- kevin_toolbox/computer_science/algorithm/cache_manager/__init__.py +1 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/cache/__init__.py +2 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/cache/cache_base.py +89 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/cache/memo_cache.py +36 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/cache_manager.py +218 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/strategy/__init__.py +5 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/strategy/fifo_strategy.py +21 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/strategy/lfu_strategy.py +80 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/strategy/lru_strategy.py +43 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/strategy/lst_strategy.py +26 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/strategy/strategy_base.py +45 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/test/__init__.py +0 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/test/test_cache_builder.py +37 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/test/test_cache_manager.py +197 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/test/test_cache_strategy.py +129 -0
- kevin_toolbox/computer_science/algorithm/cache_manager/variable.py +28 -0
- kevin_toolbox/computer_science/algorithm/registration/registry.py +38 -16
- kevin_toolbox/data_flow/core/cache/__init__.py +1 -1
- kevin_toolbox/data_flow/core/cache/cache_manager_for_iterator.py +36 -168
- kevin_toolbox/data_flow/core/cache/test/__init__.py +0 -0
- kevin_toolbox/data_flow/core/cache/test/test_cache_manager_for_iterator.py +34 -0
- kevin_toolbox/data_flow/core/reader/file_iterative_reader.py +44 -9
- kevin_toolbox/data_flow/core/reader/unified_reader.py +2 -2
- kevin_toolbox/data_flow/core/reader/unified_reader_base.py +4 -5
- kevin_toolbox/data_flow/file/json_/converter/__init__.py +2 -2
- kevin_toolbox/data_flow/file/json_/converter/escape_tuple_and_set.py +23 -0
- kevin_toolbox/data_flow/file/json_/converter/{unescape_tuple.py → unescape_tuple_and_set.py} +7 -5
- kevin_toolbox/data_flow/file/json_/read_json.py +3 -3
- kevin_toolbox/data_flow/file/json_/write_json.py +3 -3
- kevin_toolbox/data_flow/file/kevin_notation/kevin_notation_reader.py +6 -5
- kevin_toolbox/data_flow/file/kevin_notation/read.py +4 -2
- kevin_toolbox/data_flow/file/kevin_notation/test/test_kevin_notation.py +15 -3
- kevin_toolbox/data_flow/file/markdown/generate_table.py +2 -2
- kevin_toolbox/math/utils/__init__.py +1 -1
- kevin_toolbox/math/utils/{spilt_integer_most_evenly.py → split_integer_most_evenly.py} +2 -2
- kevin_toolbox/nested_dict_list/get_nodes.py +9 -4
- kevin_toolbox/nested_dict_list/name_handler/build_name.py +1 -1
- kevin_toolbox/nested_dict_list/name_handler/parse_name.py +1 -1
- kevin_toolbox/nested_dict_list/set_default.py +44 -28
- kevin_toolbox/patches/for_matplotlib/__init__.py +1 -0
- kevin_toolbox/patches/for_matplotlib/generate_color_list.py +33 -0
- kevin_toolbox/patches/for_numpy/linalg/__init__.py +1 -0
- kevin_toolbox/patches/for_numpy/linalg/entropy.py +26 -0
- kevin_toolbox/patches/for_numpy/random/__init__.py +3 -0
- kevin_toolbox/patches/for_numpy/random/get_rng.py +64 -0
- kevin_toolbox/patches/for_numpy/random/truncated_multivariate_normal.py +129 -0
- kevin_toolbox/patches/for_numpy/random/truncated_normal.py +89 -0
- kevin_toolbox/patches/for_numpy/random/variable.py +10 -0
- kevin_toolbox/patches/for_optuna/serialize/for_study/dump.py +10 -2
- kevin_toolbox_dev-1.3.3.dist-info/METADATA +75 -0
- {kevin_toolbox_dev-1.3.1.dist-info → kevin_toolbox_dev-1.3.3.dist-info}/RECORD +54 -29
- kevin_toolbox/data_flow/file/json_/converter/escape_tuple.py +0 -20
- kevin_toolbox_dev-1.3.1.dist-info/METADATA +0 -91
- {kevin_toolbox_dev-1.3.1.dist-info → kevin_toolbox_dev-1.3.3.dist-info}/WHEEL +0 -0
- {kevin_toolbox_dev-1.3.1.dist-info → kevin_toolbox_dev-1.3.3.dist-info}/top_level.txt +0 -0
@@ -134,11 +134,23 @@ def test_writer_1(expected_metadata, expected_content, file_path):
|
|
134
134
|
zip(metadata_ls, content_ls, file_path_ls))
|
135
135
|
def test_read(expected_metadata, expected_content, file_path):
|
136
136
|
print("test read()")
|
137
|
-
# 读取
|
137
|
+
# 使用 file_path 读取
|
138
138
|
metadata, content = kevin_notation.read(file_path=file_path)
|
139
|
+
|
140
|
+
# 使用 file_obj 读取
|
141
|
+
file_obj = open(file_path, "r")
|
142
|
+
metadata_1, content_1 = kevin_notation.read(file_obj=file_obj)
|
143
|
+
assert len(file_obj.read()) > 0 # 不影响输入的 file_obj
|
144
|
+
|
145
|
+
# 使用字符串构造 file_obj 读取
|
146
|
+
from io import StringIO
|
147
|
+
file_obj = StringIO(initial_value=open(file_path, "r").read())
|
148
|
+
metadata_2, content_2 = kevin_notation.read(file_obj=file_obj)
|
149
|
+
assert len(file_obj.read()) > 0
|
150
|
+
|
139
151
|
# 检验
|
140
|
-
check_consistency(expected_metadata, metadata)
|
141
|
-
check_consistency(expected_content, content)
|
152
|
+
check_consistency(expected_metadata, metadata, metadata_1, metadata_2)
|
153
|
+
check_consistency(expected_content, content, content_1, content_2)
|
142
154
|
|
143
155
|
|
144
156
|
@pytest.mark.parametrize("expected_metadata, expected_content, file_path",
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from kevin_toolbox.math.utils import
|
1
|
+
from kevin_toolbox.math.utils import split_integer_most_evenly
|
2
2
|
|
3
3
|
|
4
4
|
def generate_table(content_s, orientation="vertical", chunk_nums=None, chunk_size=None, b_allow_misaligned_values=False,
|
@@ -60,7 +60,7 @@ def generate_table(content_s, orientation="vertical", chunk_nums=None, chunk_siz
|
|
60
60
|
# 按照 chunk_nums 或者 chunk_size 对表格进行分割
|
61
61
|
if chunk_nums is not None or chunk_size is not None:
|
62
62
|
if chunk_nums is not None:
|
63
|
-
split_len_ls =
|
63
|
+
split_len_ls = split_integer_most_evenly(x=max_length, group_nums=chunk_nums)
|
64
64
|
else:
|
65
65
|
split_len_ls = [chunk_size] * (max_length // chunk_size)
|
66
66
|
if max_length % chunk_size != 0:
|
@@ -2,4 +2,4 @@ from .get_function_table_for_array_and_tensor import get_function_table_for_arra
|
|
2
2
|
from .convert_dtype import convert_dtype
|
3
3
|
from .get_crop_by_box import get_crop_by_box
|
4
4
|
from .set_crop_by_box import set_crop_by_box
|
5
|
-
from .
|
5
|
+
from .split_integer_most_evenly import split_integer_most_evenly
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
3
|
|
4
|
-
def
|
4
|
+
def split_integer_most_evenly(x, group_nums):
|
5
5
|
assert isinstance(x, (int, np.integer,)) and x >= 0 and group_nums > 0
|
6
6
|
|
7
7
|
res = np.ones(group_nums, dtype=int) * (x // group_nums)
|
@@ -10,4 +10,4 @@ def spilt_integer_most_evenly(x, group_nums):
|
|
10
10
|
|
11
11
|
|
12
12
|
if __name__ == '__main__':
|
13
|
-
print(
|
13
|
+
print(split_integer_most_evenly(x=100, group_nums=7))
|
@@ -25,21 +25,26 @@ def get_nodes(var, level=-1, b_strict=True):
|
|
25
25
|
assert isinstance(level, (int,))
|
26
26
|
if level == 0:
|
27
27
|
return [("", var)]
|
28
|
+
|
29
|
+
# 首先找出所有叶节点 level=-1,以及空的 level=-2 的节点
|
28
30
|
res = []
|
31
|
+
res_empty = set()
|
29
32
|
|
30
|
-
# 首先找出所有叶节点 level=-1
|
31
33
|
def func(_, idx, v):
|
32
34
|
nonlocal res
|
33
35
|
if not isinstance(v, (list, dict,)):
|
34
36
|
res.append((idx, v))
|
37
|
+
elif len(v) == 0:
|
38
|
+
res_empty.add(idx + "@None") # 添加哨兵,表示空节点,并不会被实际解释
|
35
39
|
return False
|
36
40
|
|
37
41
|
traverse(var=var, match_cond=func, action_mode="skip", b_use_name_as_idx=True)
|
38
42
|
|
39
43
|
if level != -1:
|
40
44
|
names = set()
|
45
|
+
leaf_node_names = res_empty.union(set(i for i, _ in res))
|
41
46
|
if level < -1:
|
42
|
-
for name
|
47
|
+
for name in leaf_node_names:
|
43
48
|
root_node, _, node_ls = parse_name(name=name, b_de_escape_node=False)
|
44
49
|
node_ls.insert(0, root_node)
|
45
50
|
temp = [len(i) for i in node_ls[level + 1:]]
|
@@ -50,7 +55,7 @@ def get_nodes(var, level=-1, b_strict=True):
|
|
50
55
|
temp = max(temp, 0)
|
51
56
|
names.add(name[:temp])
|
52
57
|
elif level > 0:
|
53
|
-
for name
|
58
|
+
for name in leaf_node_names:
|
54
59
|
root_node, _, node_ls = parse_name(name=name, b_de_escape_node=False)
|
55
60
|
node_ls.insert(0, root_node)
|
56
61
|
temp = [len(i) for i in node_ls[:level + 1]]
|
@@ -61,7 +66,7 @@ def get_nodes(var, level=-1, b_strict=True):
|
|
61
66
|
else:
|
62
67
|
raise ValueError
|
63
68
|
res.clear()
|
64
|
-
for name in names:
|
69
|
+
for name in names.difference(res_empty):
|
65
70
|
res.append((name, get_value(var=var, name=name)))
|
66
71
|
|
67
72
|
return res
|
@@ -1,40 +1,56 @@
|
|
1
1
|
from kevin_toolbox.nested_dict_list import get_value, set_value
|
2
2
|
|
3
3
|
|
4
|
-
def set_default(var, name, default, b_force=False, cache_for_verified_names=None):
|
4
|
+
def set_default(var, name, default, b_force=False, cache_for_verified_names=None, b_return_var=False):
|
5
5
|
"""
|
6
6
|
当 name 指向的位置在 var 中不存在时,将会把 default 插入到对应的位置。
|
7
7
|
(类似于 dict.setdefault() 的行为)
|
8
8
|
|
9
9
|
参数:
|
10
|
-
var:
|
11
|
-
name:
|
12
|
-
|
13
|
-
default:
|
14
|
-
b_force:
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
10
|
+
var: 任意支持索引赋值的变量
|
11
|
+
name: <string> 名字(位置)
|
12
|
+
名字 name 的具体介绍参见函数 name_handler.parse_name()
|
13
|
+
default: 默认值
|
14
|
+
b_force: <boolean> 当无法将 default 设置到 name 指定的位置时,是否尝试创建或者修改 var 中的节点
|
15
|
+
默认为 False,此时若无法设置,则报错
|
16
|
+
当设置为 True,可能会对 var 的结构产生不可逆的改变,请谨慎使用。
|
17
|
+
注意:
|
18
|
+
若 b_force 为 True 有可能不会在 var 的基础上进行改变,而是返回一个新的ndl结构,
|
19
|
+
因此建议使用赋值 var = ndl.set_default(var) 来避免可能的错误。
|
20
20
|
|
21
21
|
cache_for_verified_names: <set> 用于缓存已检验的 name
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
22
|
+
默认为 None,不使用缓存。
|
23
|
+
当设为某个集合时,将开启缓存。此时将首先判断 name 是否在缓存中,若在,则视为之前已经对该 name 成功进行过 set_default()
|
24
|
+
操作,没有必要再重复执行,因此直接跳过后续流程;若不在,则执行 set_default() 操作,并在成功执行之后将该 name
|
25
|
+
补充到缓存中。
|
26
|
+
合理利用该缓存机制将可以避免对同一个 name 反复进行 set_default() 操作,从而提高效率。
|
27
|
+
b_return_var: <boolean> 是否将 var 也添加到返回值中
|
28
|
+
默认为 False
|
29
|
+
由于本函数内调用的 set_value() 函数有可能不会在 var 的基础上进行改变,而是返回一个新的ndl结构,因此可以通过该参数,
|
30
|
+
将 var 通过返回以获取新的ndl。
|
31
|
+
|
32
|
+
返回:
|
33
|
+
res: 返回指定位置的值,或者默认值(与 dict.setdefault() 的行为一致)
|
32
34
|
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
35
|
+
特别地,当 b_return_var=True 时,返回的是一个元组 res, var。
|
36
|
+
"""
|
37
|
+
assert isinstance(cache_for_verified_names, (set, type(None)))
|
38
|
+
if cache_for_verified_names is not None and name in cache_for_verified_names:
|
39
|
+
try:
|
40
|
+
res = get_value(var=var, name=name)
|
41
|
+
except :
|
42
|
+
raise Exception(f'name {name} is in the cache, but it cannot be found in var')
|
43
|
+
else:
|
44
|
+
temp = list()
|
45
|
+
res = get_value(var=var, name=name, default=temp)
|
46
|
+
if id(temp) == id(res):
|
47
|
+
# 获取失败,说明 name 指向的位置不存在,则尝试创建
|
48
|
+
var = set_value(var=var, name=name, value=default, b_force=b_force)
|
49
|
+
res = default
|
50
|
+
if cache_for_verified_names is not None:
|
51
|
+
cache_for_verified_names.add(name)
|
37
52
|
|
38
|
-
if
|
39
|
-
|
40
|
-
|
53
|
+
if b_return_var:
|
54
|
+
return res, var
|
55
|
+
else:
|
56
|
+
return res
|
@@ -0,0 +1,33 @@
|
|
1
|
+
import random
|
2
|
+
|
3
|
+
# Named colors that are handed out first, in this order.
PREDEFINED = ['blue', 'red', 'green', 'orange', 'purple', 'yellow', "brown", "pink", "gray", "olive", "cyan"]


def generate_color_list(nums, seed=None, exclude_ls=None):
    """
    Build a list of ``nums`` distinct colors.

    Colors from PREDEFINED are used first (in order, skipping excluded ones);
    any remaining slots are filled with random "#RRGGBB" hex strings.

    Parameters:
        nums:       <int> number of colors to generate.
        seed:       random seed. When given, the random part of the result is
                    reproducible. A private generator is used, so the global
                    state of the ``random`` module is left untouched.
        exclude_ls: <list/tuple of str> colors that must not appear in the result.

    Returns:
        <list of str> the generated colors.
    """
    if exclude_ls is None:
        exclude_ls = []
    assert isinstance(exclude_ls, (list, tuple))

    # Prefer the predefined colors.
    colors = [i for i in PREDEFINED if i not in exclude_ls][:nums]

    # Randomly generate whatever is still missing.
    # random.Random(seed) yields the same sequence as random.seed(seed) followed by
    # module-level calls, but does not reseed the shared generator.
    rng = random if seed is None else random.Random(seed)
    while len(colors) < nums:
        c = "#" + ''.join(rng.choices('0123456789ABCDEF', k=6))
        if c not in colors and c not in exclude_ls:
            colors.append(c)

    return colors
|
29
|
+
|
30
|
+
|
31
|
+
if __name__ == '__main__':
|
32
|
+
color_list = generate_color_list(20)
|
33
|
+
print(color_list)
|
@@ -0,0 +1,26 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
|
4
|
+
def entropy(pdf, b_need_normalize=False):
    """
    Compute the base-2 Shannon entropy of a discrete probability distribution.

    Parameters:
        pdf:              <list/array> probability distribution. It is flattened,
                          clamped from below and re-normalized to sum to 1.
        b_need_normalize: <boolean> whether to scale the entropy into the 0~1 range.
                          Defaults to False: the raw entropy is returned, whose range
                          is 0 ~ log2(n) for an n-dimensional distribution.
                          When True, the result is divided by log2(n).
                          For n <= 1 the normalized entropy is defined as 0.

    Returns:
        <float> the (optionally normalized) entropy.
    """
    pdf = np.asarray(pdf, dtype=float).reshape(-1)
    # Clamp entries to a tiny positive value so that log(0) never occurs.
    pdf = np.maximum(pdf, 1e-15)
    # Re-normalize so that the entries sum to 1.
    pdf /= np.sum(pdf)
    # Shannon entropy.
    res = -np.sum(pdf * np.log2(pdf))
    if b_need_normalize:
        if len(pdf) > 1:
            res /= np.log2(len(pdf))
        else:
            # log2(1) == 0 would produce 0/0 = NaN; a distribution with a single
            # outcome carries no uncertainty.
            res = np.float64(0.0)
    return res
|
22
|
+
|
23
|
+
|
24
|
+
if __name__ == '__main__':
|
25
|
+
print(entropy(pdf=[0.1, 0.1, 0.7, 0.1], b_need_normalize=True)) # 0.6783898247235198
|
26
|
+
print(entropy(pdf=[0.05, 0.05, 0.05, 0.05, 0.8], b_need_normalize=True)) # 0.4831881303119284
|
@@ -0,0 +1,64 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
|
4
|
+
class DEFAULT_RNG:
    """Namespace that exposes every public name of ``numpy.random`` as a class attribute."""
    pass


# Mirror the module-level numpy.random API onto DEFAULT_RNG.
for name in np.random.__all__:
    setattr(DEFAULT_RNG, name, getattr(np.random, name))


def _legacy_random_integers(rng, low, high=None, size=None):
    """Emulate numpy.random.random_integers (closed interval [low, high]) on a Generator."""
    if high is None:
        # Legacy convention: a single bound means the range [1, low].
        low, high = 1, low
    return rng.integers(low, high, size=size, endpoint=True)


# Legacy numpy.random names that a Generator lacks.
# A string value is the name of the Generator method to use instead;
# a callable value receives the Generator as first argument followed by the caller's arguments.
func_map_s = {
    # rand/randn take the shape as positional args; no args means "return a scalar".
    'rand': lambda rng, *arg, **kwargs: rng.random(size=arg if arg else None, **kwargs),
    'randint': 'integers',
    'randn': lambda rng, *arg, **kwargs: rng.normal(size=arg if arg else None, **kwargs),
    # random_integers includes `high`, unlike Generator.integers' default.
    'random_integers': _legacy_random_integers,
    'random_sample': 'random',
    'ranf': 'random',
    'sample': 'random'
}


class My_RNG:
    """
    Thin wrapper around a random generator that additionally answers to the
    legacy ``numpy.random`` names listed in ``func_map_s``.
    """

    def __init__(self, rng):
        self._rng = rng

    # self.key
    def __getattr__(self, key):
        # __getattr__ is only invoked when normal lookup failed.
        if "_rng" not in self.__dict__:
            # _rng has not been set yet, i.e. initialization is not finished.
            raise AttributeError(f"attribute '{key}' not found in {type(self)}")

        res = getattr(self._rng, key, None)
        if res is None and key in func_map_s:
            if callable(func_map_s[key]):
                res = lambda *arg, **kwargs: func_map_s[key](self._rng, *arg, **kwargs)
            else:
                res = getattr(self._rng, func_map_s[key], None)
        #
        if res is None:
            raise AttributeError(f"attribute '{key}' not found in {type(self)}")
        return res


def get_rng(seed=None, rng=None, **kwargs):
    """
    Return a random generator.

    Parameters:
        seed:   random seed. When given, a new ``My_RNG`` wrapping
                ``np.random.default_rng(seed)`` is returned (this takes precedence over ``rng``).
        rng:    an existing generator; returned as-is when ``seed`` is None.
        kwargs: accepted and ignored, so callers can forward their own **kwargs.

    Returns:
        The selected generator, or ``DEFAULT_RNG`` when neither ``seed`` nor ``rng`` is given.
    """
    if seed is not None:
        # Note: compared with numpy.random, a Generator lacks some attributes:
        #   ['get_state', 'rand', 'randint', 'randn', 'random_integers', 'random_sample', 'ranf', 'sample', 'seed',
        #    'set_state', 'Generator', 'RandomState', 'SeedSequence', 'MT19937', 'Philox', 'PCG64', 'PCG64DXSM',
        #    'SFC64', 'default_rng', 'BitGenerator']
        # My_RNG maps the sampling-related ones via func_map_s.
        rng = My_RNG(rng=np.random.default_rng(seed=seed))
    if rng is not None:
        return rng
    else:
        return DEFAULT_RNG
|
58
|
+
|
59
|
+
|
60
|
+
if __name__ == '__main__':
|
61
|
+
a = get_rng(seed=2)
|
62
|
+
|
63
|
+
# 尝试访问随机生成器中缺失的部分方法
|
64
|
+
print(a.randn(2, 3))
|
@@ -0,0 +1,129 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from scipy.stats import chi2
|
3
|
+
from kevin_toolbox.patches.for_numpy.linalg import normalize
|
4
|
+
from kevin_toolbox.patches.for_numpy.random import get_rng
|
5
|
+
from kevin_toolbox.patches.for_numpy.random.variable import DEFAULT_SETTINGS
|
6
|
+
from kevin_toolbox.computer_science.algorithm.cache_manager import Cache_Manager
|
7
|
+
|
8
|
+
cache_for_cov = Cache_Manager(upper_bound=20, refactor_size=0.5)
|
9
|
+
|
10
|
+
|
11
|
+
def truncated_multivariate_normal(
        mean, cov=None, low_radius=None, high_radius=None, size=None, b_check_cov=False,
        hit_ratio_threshold=DEFAULT_SETTINGS["truncated_multivariate_normal"]["hit_ratio_threshold"],
        expand_ratio=DEFAULT_SETTINGS["truncated_multivariate_normal"]["expand_ratio"],
        **kwargs):
    """
    Draw random samples from a truncated multivariate Gaussian distribution.

    Parameters:
        mean:                   <list of float> mean vector (at least 2 entries).
        cov:                    <matrix> covariance matrix. None means no scaling (identity).
        low_radius,high_radius: <float> truncation boundaries.
                        Notes:
                            - expressed in multiples of sigma.
                            - low_radius excludes the points closer than that radius.
                            - high_radius keeps the points within that radius.
                            - the interval is half-open: [low_radius, high_radius)
        size:                   <tuple/list/int/None> shape of the output.
        b_check_cov:            <boolean> whether to check that cov is symmetric with positive eigenvalues.
                                    Defaults to False (no check).
                                    When True, a failed check raises an AssertionError.

    Hyper-parameters for tuning the sampling efficiency (device dependent):
        hit_ratio_threshold:    <float> threshold that selects the sampling method.
                                    When the hit probability is below it, method 2 (inverse-CDF sampling) is used;
                                    otherwise method 1 draws about expand_ratio * size candidates and keeps
                                    the ones that fall inside the truncation interval.
        expand_ratio:           <float> over-sampling factor of method 1. Expected to be greater than 1.

    Other parameters (forwarded to get_rng; choose one of them):
        seed:                   <int> random seed
        rng:                    <Random Generator> an existing generator

    Returns:
        res:    when size is None, a vector with the same length as mean;
                otherwise a tensor of shape size+[len(mean)].
    """
    # Validate arguments.
    assert len(mean) > 1
    assert high_radius is None or 0 < high_radius
    assert low_radius is None or 0 <= low_radius
    if high_radius is not None and low_radius is not None:
        assert low_radius < high_radius
    if cov is not None:
        # Always work on an ndarray: the code below needs cov.tobytes(), which a list does not have.
        cov = np.asarray(cov, dtype=float)
        if b_check_cov:
            assert np.allclose(cov, cov.T) and np.all(np.linalg.eig(cov)[0] > 0)
    #
    rng = get_rng(**kwargs)
    # The radial part is sampled as a squared distance, so square the boundaries.
    low = None if low_radius is None else low_radius ** 2
    high = None if high_radius is None else high_radius ** 2

    # Quick return: no truncation at all.
    if (low is None or low == 0) and high is None:
        return rng.multivariate_normal(mean, np.eye(len(mean)) if cov is None else cov, size=size,
                                       check_valid="warn" if b_check_cov else "ignore")

    # For a standard multivariate Gaussian the direction of a sample is uniformly distributed,
    # while its squared distance follows a chi-square distribution with k degrees of freedom,
    # so direction and distance can be sampled separately.
    raw_size = 1 if size is None else np.prod([size])

    # Sample the directions.
    theta = rng.normal(0, 1, size=[raw_size, len(mean)])
    theta = normalize(v=theta, ord=2, axis=-1)

    # Sample the (squared) distances.
    #   hit probability of the truncation interval
    cdf_high = chi2.cdf(high, df=len(mean)) if high is not None else 1
    cdf_low = chi2.cdf(low, df=len(mean)) if low is not None else 0
    hit_prob = cdf_high - cdf_low
    if hit_prob >= hit_ratio_threshold:
        # Method 1: rejection sampling with over-sampled batches.
        delta = np.empty(raw_size)
        count = 0
        while count < raw_size:
            temp = rng.chisquare(len(mean), int((raw_size - count) / hit_prob * expand_ratio) + 1)
            if low is not None:
                temp = temp[temp >= low]
            if high is not None:
                temp = temp[temp < high]
            # The left-hand slice is clipped at raw_size, matching the truncated right-hand side.
            delta[count:count + len(temp)] = temp[:raw_size - count]
            count += len(temp)
    else:
        # Method 2: sample uniformly between the two CDF values ...
        delta = rng.uniform(cdf_low, cdf_high, raw_size)
        # ... and map back through the inverse CDF of the chi-square distribution.
        delta = chi2.ppf(delta, df=len(mean))

    # Combine direction and distance.
    res = theta * delta[:, None] ** 0.5

    # Scale according to the covariance matrix (Cholesky factor is cached per covariance).
    if cov is not None:
        A = cache_for_cov.get(cov.tobytes(), default_factory=lambda: np.linalg.cholesky(cov), b_add_if_not_found=True)
        res = (A @ res.T).T
    res += mean

    if size is None:
        res = res[0]
    else:
        size = [size] if isinstance(size, (int, np.integer)) else list(size)
        res = res.reshape(size + [len(mean)])

    return res
|
115
|
+
|
116
|
+
|
117
|
+
if __name__ == '__main__':
|
118
|
+
print(truncated_multivariate_normal(mean=[0, 0], cov=np.array([[1, 0.5], [0.5, 1]]),
|
119
|
+
high_radius=2, seed=114, size=[2, 5]))
|
120
|
+
|
121
|
+
# 可视化
|
122
|
+
import matplotlib.pyplot as plt
|
123
|
+
|
124
|
+
res_ = truncated_multivariate_normal(mean=[0, 0], cov=np.array([[1, 0.5], [0.5, 1]]), low_radius=0.5,
|
125
|
+
high_radius=1.5, seed=114, b_check_cov=True,
|
126
|
+
size=10000)
|
127
|
+
plt.scatter(res_[:, 0], res_[:, 1], c='r', s=1)
|
128
|
+
|
129
|
+
plt.show()
|
@@ -0,0 +1,89 @@
|
|
1
|
+
import math
|
2
|
+
import torch
|
3
|
+
import numpy as np
|
4
|
+
from kevin_toolbox.patches.for_numpy.random import get_rng
|
5
|
+
from kevin_toolbox.patches.for_numpy.random.variable import DEFAULT_SETTINGS
|
6
|
+
|
7
|
+
|
8
|
+
def cdf(x, mean=0, sigma=1):
    """Cumulative distribution function of the normal distribution N(mean, sigma^2), evaluated at scalar x."""
    scaled = (x - mean) / (sigma * 2 ** 0.5)
    return (1 + math.erf(scaled)) * 0.5
|
10
|
+
|
11
|
+
|
12
|
+
def truncated_normal(mean=0, sigma=1, low=None, high=None, size=None,
                     hit_ratio_threshold=DEFAULT_SETTINGS["truncated_normal"]["hit_ratio_threshold"],
                     expand_ratio=DEFAULT_SETTINGS["truncated_normal"]["expand_ratio"],
                     **kwargs):
    """
    Draw random samples from a truncated (1-D) Gaussian distribution.

    Parameters:
        mean,sigma              <float> mean and standard deviation
        low,high                <float> truncation boundaries; samples satisfy low <= x < high
        size:                   <tuple/list/int/None> shape of the output

    Hyper-parameters for tuning the sampling efficiency (device dependent):
        hit_ratio_threshold:    <float> threshold that selects the sampling method.
                                    The hit probability is the probability mass inside the truncation interval,
                                        i.e. the chance that one unrestricted draw lands inside it.
                                    When it is below the threshold, method 2 (inverse-CDF sampling) is used;
                                    otherwise method 1 draws about expand_ratio * size candidates and keeps
                                        the ones that fall inside the interval.
                                    Tune it according to the measured cost difference between the two methods.
        expand_ratio:           <float> over-sampling factor of method 1. Must be greater than 1.

    Other parameters (forwarded to get_rng; choose one of them):
        seed:                   <int> random seed
        rng:                    <Random Generator> an existing generator
    """
    assert low is None or high is None or high > low
    assert expand_ratio > 1 and 0 <= hit_ratio_threshold <= 1
    rng = get_rng(**kwargs)

    raw_size = np.prod([size]) if size is not None else 1
    # Probability mass inside the truncation interval.
    cdf_high = 1 if high is None else cdf(x=high, mean=mean, sigma=sigma)
    cdf_low = 0 if low is None else cdf(x=low, mean=mean, sigma=sigma)
    hit_prob = cdf_high - cdf_low

    if hit_prob < hit_ratio_threshold:
        # Method 2: sample uniformly between the two CDF values ...
        uniform_samples = rng.uniform(cdf_low, cdf_high, raw_size)
        # ... and map back through the inverse CDF of the normal distribution.
        mapped = mean + sigma * (2 ** 0.5) * torch.erfinv(torch.from_numpy(2 * uniform_samples - 1))
        res = mapped.numpy()
    else:
        # Method 1: rejection sampling with over-sampled batches.
        res = np.empty(raw_size)
        filled = 0
        while filled < raw_size:
            remaining = raw_size - filled
            batch = rng.normal(mean, sigma, int(remaining / hit_prob * expand_ratio) + 1)
            if low is not None:
                batch = batch[batch >= low]
            if high is not None:
                batch = batch[batch < high]
            # The left-hand slice is clipped at raw_size, matching the truncated right-hand side.
            res[filled:filled + len(batch)] = batch[:remaining]
            filled += len(batch)

    if size is None:
        return res[0]
    return res.reshape(size)
|
75
|
+
|
76
|
+
|
77
|
+
if __name__ == '__main__':
|
78
|
+
points = truncated_normal(mean=0, sigma=2, low=-3, high=None, size=10000, hit_ratio_threshold=0.01,
|
79
|
+
expand_ratio=1.5)
|
80
|
+
|
81
|
+
import numpy as np
|
82
|
+
import matplotlib.pyplot as plt
|
83
|
+
|
84
|
+
# 绘制概率分布图
|
85
|
+
plt.hist(points, bins=20, density=True, alpha=0.7)
|
86
|
+
plt.xlabel('Value')
|
87
|
+
plt.ylabel('Probability Density')
|
88
|
+
plt.title('Uniform Distribution')
|
89
|
+
plt.show()
|
@@ -23,7 +23,13 @@ def dump(study: optuna.study.Study):
|
|
23
23
|
keys = ['best_params', 'best_trial', 'best_trials', 'best_value', 'direction', 'directions', 'trials', 'user_attrs']
|
24
24
|
if version.compare(optuna.__version__, "<", "3.1.0"):
|
25
25
|
keys.append('system_attrs')
|
26
|
-
res_s = {k: getattr(study, k, None) for k in keys}
|
26
|
+
# 【bug fix】不能直接用 res_s = {k: getattr(study, k, None) for k in keys} 以免在获取某些属性时,比如 best_trials,产生错误
|
27
|
+
res_s = dict()
|
28
|
+
for k in keys:
|
29
|
+
try:
|
30
|
+
res_s[k] = getattr(study, k)
|
31
|
+
except:
|
32
|
+
res_s[k] = None
|
27
33
|
|
28
34
|
# 其他信息
|
29
35
|
res_s["__dict__"] = dict()
|
@@ -35,7 +41,9 @@ def dump(study: optuna.study.Study):
|
|
35
41
|
res_s["__dict__"][k] = v
|
36
42
|
|
37
43
|
# 序列化
|
38
|
-
|
44
|
+
# 【bug fix】不能对原始的 res_s 进行遍历和替换,以免意外修改 study 中的属性。
|
45
|
+
res_s = ndl.traverse(var=ndl.copy_(var=res_s, b_deepcopy=True),
|
46
|
+
match_cond=lambda _, __, v: not isinstance(v, (list, dict,)),
|
39
47
|
action_mode="replace", converter=__converter, b_use_name_as_idx=False,
|
40
48
|
b_traverse_matched_element=False)
|
41
49
|
|