kevin-toolbox-dev 1.4.11__py3-none-any.whl → 1.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/computer_science/algorithm/pareto_front/get_pareto_points_idx.py +2 -0
  3. kevin_toolbox/computer_science/algorithm/redirector/redirectable_sequence_fetcher.py +3 -3
  4. kevin_toolbox/computer_science/algorithm/sampler/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/sampler/recent_sampler.py +128 -0
  6. kevin_toolbox/computer_science/algorithm/sampler/reservoir_sampler.py +2 -2
  7. kevin_toolbox/computer_science/algorithm/statistician/__init__.py +2 -0
  8. kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +1 -1
  9. kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +1 -1
  10. kevin_toolbox/computer_science/algorithm/statistician/maximum_accumulator.py +80 -0
  11. kevin_toolbox/computer_science/algorithm/statistician/minimum_accumulator.py +34 -0
  12. kevin_toolbox/data_flow/file/markdown/table/find_tables.py +38 -12
  13. kevin_toolbox/developing/file_management/__init__.py +1 -0
  14. kevin_toolbox/developing/file_management/file_feature_extractor.py +263 -0
  15. kevin_toolbox/nested_dict_list/serializer/read.py +4 -1
  16. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +5 -0
  17. kevin_toolbox/patches/for_matplotlib/common_charts/plot_2d_matrix.py +134 -0
  18. kevin_toolbox/patches/for_matplotlib/common_charts/plot_3d.py +198 -0
  19. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +7 -4
  20. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +11 -4
  21. kevin_toolbox/patches/for_matplotlib/common_charts/plot_contour.py +157 -0
  22. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +19 -8
  23. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +72 -21
  24. kevin_toolbox/patches/for_matplotlib/common_charts/plot_mean_std_lines.py +135 -0
  25. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +9 -3
  26. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +9 -3
  27. kevin_toolbox/patches/for_matplotlib/common_charts/utils/__init__.py +1 -0
  28. kevin_toolbox/patches/for_matplotlib/common_charts/utils/log_scaling.py +69 -0
  29. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_plot.py +19 -3
  30. kevin_toolbox/patches/for_matplotlib/common_charts/utils/save_record.py +1 -1
  31. kevin_toolbox/patches/for_numpy/__init__.py +1 -0
  32. kevin_toolbox/patches/for_numpy/linalg/softmax.py +4 -1
  33. kevin_toolbox_dev-1.4.13.dist-info/METADATA +77 -0
  34. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/RECORD +36 -26
  35. kevin_toolbox_dev-1.4.11.dist-info/METADATA +0 -67
  36. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/WHEEL +0 -0
  37. {kevin_toolbox_dev-1.4.11.dist-info → kevin_toolbox_dev-1.4.13.dist-info}/top_level.txt +0 -0
kevin_toolbox/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "1.4.11"
+__version__ = "1.4.13"
 
 
 import os
@@ -12,5 +12,5 @@ os.system(
 os.system(
     f'python {os.path.split(__file__)[0]}/env_info/check_validity_and_uninstall.py '
     f'--package_name kevin-toolbox-dev '
-    f'--expiration_timestamp 1760340881 --verbose 0'
+    f'--expiration_timestamp 1768647143 --verbose 0'
 )
kevin_toolbox/computer_science/algorithm/pareto_front/get_pareto_points_idx.py CHANGED
@@ -31,6 +31,8 @@ def get_pareto_points_idx(points, directions=None):
     """
     points = np.asarray(points)
    assert points.ndim == 2 and len(points) > 0
+    if directions is not None and not isinstance(directions, (list, tuple,)):
+        directions = [directions] * points.shape[-1]
     assert directions is None or isinstance(directions, (list, tuple,)) and len(directions) == points.shape[-1]
 
     # compute the sorting weights
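
The two added lines let a caller pass a single direction value instead of a per-column list; it is broadcast to every column before the existing assert runs. A minimal sketch of the new broadcast rule ("maximize" here is an illustrative value, not confirmed by this hunk):

    import numpy as np

    points = np.asarray([[1, 2], [3, 1], [2, 3]])
    directions = "maximize"  # hypothetical single value; any non-list/tuple is broadcast
    if directions is not None and not isinstance(directions, (list, tuple)):
        directions = [directions] * points.shape[-1]
    print(directions)  # ['maximize', 'maximize'] -- one entry per column, as the assert expects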
kevin_toolbox/computer_science/algorithm/redirector/redirectable_sequence_fetcher.py CHANGED
@@ -6,11 +6,11 @@ from kevin_toolbox.computer_science.algorithm.cache_manager import Cache_Manager
 
 def _randomly_idx_redirector(idx, seq_len, attempts, rng, *args):
     if idx == 0:
-        return rng.randint(1, seq_len - 1)
+        return rng.randint(1, seq_len)
     elif idx == seq_len - 1:
-        return rng.randint(0, seq_len - 2)
+        return rng.randint(0, seq_len - 1)
     else:
-        return rng.choice([rng.randint(0, idx - 1), rng.randint(idx + 1, seq_len - 1)], size=1,
+        return rng.choice([rng.randint(0, idx), rng.randint(idx + 1, seq_len)], size=1,
                           p=[idx / (seq_len - 1), (seq_len - idx - 1) / (seq_len - 1)])[0]
 
 
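These bound changes fix an off-by-one: rng here appears to be a numpy-style generator (note the choice(..., size=..., p=...) call), and numpy's randint excludes its upper bound, so the old rng.randint(1, seq_len - 1) could never redirect to the last index. A minimal check of the corrected bounds, assuming rng is a np.random.RandomState:

    import numpy as np

    rng = np.random.RandomState(0)
    seq_len = 5
    # old bound: randint(1, seq_len - 1) draws from {1, ..., 3}, leaving index 4 unreachable
    # new bound: randint(1, seq_len) draws from {1, ..., 4}, covering every index != 0
    hits = {int(rng.randint(1, seq_len)) for _ in range(1000)}
    print(sorted(hits))  # [1, 2, 3, 4]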
kevin_toolbox/computer_science/algorithm/sampler/__init__.py CHANGED
@@ -1 +1,2 @@
 from .reservoir_sampler import Reservoir_Sampler
+from .recent_sampler import Recent_Sampler
kevin_toolbox/computer_science/algorithm/sampler/recent_sampler.py ADDED
@@ -0,0 +1,128 @@
+class Recent_Sampler:
+    """
+    Recency sampler: always keeps the most recently added `capacity` samples.
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Parameters:
+            capacity:   <int> capacity of the buffer / window.
+        """
+        # default parameters
+        paras = {
+            "capacity": 1,
+        }
+
+        # fetch and update parameters
+        paras.update(kwargs)
+
+        # validate capacity
+        assert paras["capacity"] >= 1
+
+        self.paras = paras
+        self.cache = []  # a list holding the most recent samples
+        self.state = self._init_state()  # the state only records total_nums
+
+    @staticmethod
+    def _init_state():
+        """
+        Initialize the state; only the total number of added samples is recorded.
+        """
+        return dict(
+            total_nums=0,
+        )
+
+    def add(self, item, **kwargs):
+        """
+        Add a single item to the sampler.
+            - increment the total_nums counter
+            - append the item to the end of the cache
+            - if capacity is exceeded, drop the oldest item (the head of the list)
+        """
+        self.state["total_nums"] += 1
+        self.cache.append(item)
+        if len(self.cache) > self.paras["capacity"]:
+            self.cache.pop(0)
+
+    def add_sequence(self, item_ls, **kwargs):
+        """
+        Add items in bulk by calling add() on each element of the list.
+        """
+        for item in item_ls:
+            self.add(item, **kwargs)
+
+    def get(self, **kwargs):
+        """
+        Return the data currently in the buffer (a shallow copy).
+        """
+        return self.cache.copy()
+
+    def clear(self):
+        """
+        Clear all data and state, resetting the sampler.
+        """
+        self.cache.clear()
+        self.state = self._init_state()
+
+    def __len__(self):
+        """
+        Return the total number of samples added (state["total_nums"]),
+        not the current buffer length.
+        """
+        return self.state["total_nums"]
+
+    # ---------------------- saving and loading state ---------------------- #
+
+    def load_state_dict(self, state_dict):
+        """
+        Load a state:
+            - clear the current buffer and state
+            - restore state["total_nums"]
+            - restore the cache contents
+        """
+        self.clear()
+        self.state.update(state_dict["state"])
+        self.cache.extend(state_dict["cache"])
+
+    def state_dict(self, b_deepcopy=True):
+        """
+        Get the current state, containing:
+            - state: {"total_nums": ...}
+            - cache: the current buffer list
+        """
+        temp = {
+            "state": self.state,
+            "cache": self.cache
+        }
+        if b_deepcopy:
+            import kevin_toolbox.nested_dict_list as ndl
+            temp = ndl.copy_(var=temp, b_deepcopy=True, b_keep_internal_references=True)
+        return temp
+
+
+# test example
+if __name__ == "__main__":
+    # create a Recent_Sampler with capacity 5
+    sampler = Recent_Sampler(capacity=5)
+
+    # add the numbers 1 to 10 one by one
+    for i in range(1, 11):
+        sampler.add(i)
+        print(f"buffer after adding {i}: {sampler.get()}")
+
+    # by now the buffer should hold only the 5 most recent samples: 6,7,8,9,10
+    print("final buffer:", sampler.get())  # expected: [6, 7, 8, 9, 10]
+    print("total added:", len(sampler))  # expected: 10
+
+    # save the current state
+    state = sampler.state_dict()
+    print("state dict:", state)
+
+    # clear, then restore from the saved state
+    sampler.clear()
+    print("buffer after clear:", sampler.get())  # expected: []
+
+    sampler.load_state_dict(state)
+    print("buffer after restore:", sampler.get())  # expected: [6, 7, 8, 9, 10]
+    print("total added after restore:", len(sampler))  # expected: 10
kevin_toolbox/computer_science/algorithm/sampler/reservoir_sampler.py CHANGED
@@ -61,13 +61,13 @@ class Reservoir_Sampler:
 
     def get(self, **kwargs):
         """
-        Return the data currently in the reservoir (a copy).
+        Return the data currently in the reservoir (a shallow copy).
         """
         return self.reservoir.copy()
 
     def clear(self):
         """
-        Clear all data and state, resetting the sampler.
+        Clear all data and state, resetting the sampler.
         """
         self.reservoir.clear()
         self.state = self._init_state()
kevin_toolbox/computer_science/algorithm/statistician/__init__.py CHANGED
@@ -2,3 +2,5 @@ from .accumulator_base import Accumulator_Base
 from .exponential_moving_average import Exponential_Moving_Average
 from .average_accumulator import Average_Accumulator
 from .accumulator_for_ndl import Accumulator_for_Ndl
+from .maximum_accumulator import Maximum_Accumulator
+from .minimum_accumulator import Minimum_Accumulator
kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py CHANGED
@@ -27,7 +27,7 @@ class Average_Accumulator(Accumulator_Base):
         Of the three ways above, the last is used by default.
         If all three are specified at once, their priority follows the order given above.
         """
-        super(Average_Accumulator, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     def add_sequence(self, var_ls, **kwargs):
         for var in var_ls:
kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py CHANGED
@@ -56,7 +56,7 @@ class Exponential_Moving_Average(Accumulator_Base):
         # validate parameters
         assert isinstance(paras["keep_ratio"], (int, float,)) and 0 <= paras["keep_ratio"] <= 1
         #
-        super(Exponential_Moving_Average, self).__init__(**paras)
+        super().__init__(**paras)
 
     def add_sequence(self, var_ls, **kwargs):
         for var in var_ls:
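
This hunk and the previous one are the same modernization: in Python 3, zero-argument super() resolves the class and instance from the enclosing scope, so naming them explicitly is redundant. A minimal illustration:

    class Base:
        def __init__(self, **kwargs):
            self.paras = kwargs

    class Child(Base):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)  # equivalent to super(Child, self).__init__(**kwargs)

    print(Child(keep_ratio=0.9).paras)  # {'keep_ratio': 0.9}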
kevin_toolbox/computer_science/algorithm/statistician/maximum_accumulator.py ADDED
@@ -0,0 +1,80 @@
+import numpy as np
+import torch
+from kevin_toolbox.computer_science.algorithm.statistician import Accumulator_Base
+
+
+class Maximum_Accumulator(Accumulator_Base):
+    """
+    An accumulator for computing the maximum value.
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Parameters:
+            data_format:    specifies the data format.
+            like:           specifies the data format.
+                There are three ways to specify the format of the input data:
+                    1. explicitly give the shape, device, etc.
+                        data_format: <dict of paras>
+                            which should contain the following parameters:
+                                type_: <str>
+                                    "numpy": np.ndarray
+                                    "torch": torch.tensor
+                                shape: <list of integers>
+                                device: <torch.device>
+                                dtype: <torch.dtype>
+                    2. infer the shape, device, etc. from a given example.
+                        like: <torch.tensor / np.ndarray / int / float>
+                    3. specify neither data_format nor like, deferring inference until
+                        the first call of add()/add_sequence().
+                Of the three ways above, the last is used by default.
+                If all three are specified at once, their priority follows the order given above.
+        """
+        super().__init__(**kwargs)
+
+    def add_sequence(self, var_ls, **kwargs):
+        for var in var_ls:
+            self.add(var, **kwargs)
+
+    def add(self, var, **kwargs):
+        """
+        Add a single value.
+
+        Parameters:
+            var:    the value.
+        """
+        if self.var is None:
+            self.var = var
+        else:
+            # accumulate
+            if torch.is_tensor(var):
+                self.var = torch.maximum(self.var, var)
+            else:
+                self.var = np.maximum(self.var, var)
+        self.state["total_nums"] += 1
+
+    def get(self, **kwargs):
+        """
+        Get the currently accumulated maximum.
+        Returns None when nothing has been accumulated yet.
+        """
+        if len(self) == 0:
+            return None
+        return self.var
+
+    @staticmethod
+    def _init_state():
+        """
+        Initialize the state.
+        """
+        return dict(
+            total_nums=0
+        )
+
+
+if __name__ == '__main__':
+
+    seq = list(torch.tensor(range(1, 10)) - 5)
+    avg = Maximum_Accumulator()
+    for i, v in enumerate(seq):
+        avg.add(var=v)
+        print(i, v, avg.get())
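
Besides the torch demo above, the np.maximum branch makes this an elementwise running maximum over arrays. A minimal numpy sketch, assuming kevin-toolbox-dev 1.4.13 is installed:

    import numpy as np
    from kevin_toolbox.computer_science.algorithm.statistician import Maximum_Accumulator

    acc = Maximum_Accumulator()
    for arr in [np.array([1, 5, 2]), np.array([4, 3, 3])]:
        acc.add(var=arr)
    print(acc.get())  # [4 5 3] -- the elementwise running maximum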
kevin_toolbox/computer_science/algorithm/statistician/minimum_accumulator.py ADDED
@@ -0,0 +1,34 @@
+import numpy as np
+import torch
+from kevin_toolbox.computer_science.algorithm.statistician import Maximum_Accumulator
+
+
+class Minimum_Accumulator(Maximum_Accumulator):
+    """
+    An accumulator for computing the minimum value.
+    """
+
+    def add(self, var, **kwargs):
+        """
+        Add a single value.
+
+        Parameters:
+            var:    the value.
+        """
+        if self.var is None:
+            self.var = var
+        else:
+            # accumulate
+            if torch.is_tensor(var):
+                self.var = torch.minimum(self.var, var)
+            else:
+                self.var = np.minimum(self.var, var)
+        self.state["total_nums"] += 1
+
+
+if __name__ == '__main__':
+    seq = list(torch.tensor(range(1, 10)) + 5)
+    avg = Minimum_Accumulator()
+    for i, v in enumerate(seq):
+        avg.add(var=v)
+        print(i, v, avg.get())
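
Only add() is overridden here; get(), _init_state() and the constructor are inherited from Maximum_Accumulator. Plain Python scalars also work, since torch.is_tensor() is False for them and np.minimum() accepts numbers. A minimal sketch:

    from kevin_toolbox.computer_science.algorithm.statistician import Minimum_Accumulator

    acc = Minimum_Accumulator()
    for v in [3.0, 7.0, 1.5, 2.0]:
        acc.add(var=v)
    print(acc.get())  # 1.5 -- the running minimum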
kevin_toolbox/data_flow/file/markdown/table/find_tables.py CHANGED
@@ -52,23 +52,49 @@ def find_tables(text, b_compact_format=True):
     return table_ls, part_slices_ls, table_idx_ls
 
 
+# def _find_table(text):
+#     # match markdown tables with a regular expression
+#     table_pattern = re.compile(r'\|([^\n]+)\|', re.DOTALL)
+#     table_matches = table_pattern.findall(text)
+#     if len(table_matches) < 2:
+#         # a valid markdown table must contain the header separator line, so there should be at least 2 rows
+#         return None
+#
+#     # drop the header separator line
+#     table_matches.pop(1)
+#     #
+#     tables = []  # one element per row
+#     for match in table_matches:
+#         # split each row
+#         tables.append([i.strip() for i in match.split('|', -1)])
+#
+#     return {"matrix": tables, "orientation": None}
+
 def _find_table(text):
-    # match markdown tables with a regular expression
-    table_pattern = re.compile(r'\|([^\n]+)\|', re.DOTALL)
-    table_matches = table_pattern.findall(text)
-    if len(table_matches) < 2:
+    # split the text into lines
+    lines = text.splitlines()
+    table_rows = []
+    for line in lines:
+        # strip leading and trailing whitespace
+        stripped_line = line.strip()
+        if not stripped_line:
+            continue  # skip empty lines
+        # strip the optional leading/trailing pipes, if present
+        if stripped_line.startswith('|'):
+            stripped_line = stripped_line[1:]
+        if stripped_line.endswith('|'):
+            stripped_line = stripped_line[:-1]
+        # split into cells and trim each cell
+        row_cells = [cell.strip() for cell in stripped_line.split('|')]
+        table_rows.append(row_cells)
+
+    if len(table_rows) < 2:
         # a valid markdown table must contain the header separator line, so there should be at least 2 rows
         return None
-
     # drop the header separator line
-    table_matches.pop(1)
-    #
-    tables = []  # one element per row
-    for match in table_matches:
-        # split each row
-        tables.append([i.strip() for i in match.split('|', -1)])
+    table_rows.pop(1)
 
-    return {"matrix": tables, "orientation": None}
+    return {"matrix": table_rows, "orientation": None}
 
 
 if __name__ == '__main__':
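
The practical effect of the rewrite is that rows no longer need leading and trailing pipes: the parser strips them when present and splits on the remaining separators, where the old regex required a pipe at both ends of every row. The row handling in miniature (a sketch of the logic above, not the public API):

    for line in ["| a | b |", "a | b"]:
        stripped = line.strip()
        if stripped.startswith('|'):
            stripped = stripped[1:]
        if stripped.endswith('|'):
            stripped = stripped[:-1]
        print([cell.strip() for cell in stripped.split('|')])  # ['a', 'b'] in both cases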
kevin_toolbox/developing/file_management/__init__.py ADDED
@@ -0,0 +1 @@
+from .file_feature_extractor import File_Feature_Extractor
kevin_toolbox/developing/file_management/file_feature_extractor.py ADDED
@@ -0,0 +1,263 @@
+import os
+import time
+import json
+import hashlib
+from enum import Enum
+from kevin_toolbox.patches import for_os
+from kevin_toolbox.data_flow.file import json_
+
+
+class F_Type(Enum):
+    file = 0
+    symlink = 1
+    dir = 2
+    unknown = -1
+
+
+class File_Feature_Extractor:
+    """
+    A file-feature extractor that scans every entry (files, directories and symlinks)
+    under a given directory and extracts:
+        - file metadata
+        - shallow hashes (files only)
+        - full hashes and other features (files only)
+    with support for caching, updating and persistence.
+
+    Parameters:
+        input_dir:              <str> root directory path.
+        metadata_cfg:           <dict> how to extract metadata.
+                                    Accepts a dict of the form
+                                    {"attribute": ["size", ...], "include": ["file", ...], ...},
+                                    where the "attribute" field lists the metadata to collect; currently supported:
+                                        - size                                          file size
+                                        - created_time, modified_time, accessed_time    timestamps
+                                        - mode                                          permissions
+                                        - is_symlink, is_dir, is_file                   entry kind
+                                        - is_symlink_valid                              whether the link target exists
+                                    The "include" field specifies which entry types to traverse.
+                                    By default, "attribute" and "include" contain all supported options.
+                                    When set to None, no metadata is extracted.
+        hash_cfg:               <dict> how to compute the hashes.
+                                    Accepts a dict of the form
+                                    {"algorithm": ["md5", ...], "read_size": [<int>, None, ...], ...},
+                                    where "algorithm" names the hash algorithms to use; supported:
+                                        - 'md5', 'sha1', 'sha256'
+                                    By default "algorithm" contains "md5".
+                                    "read_size" gives the maximum number of leading bytes read for hashing; supported:
+                                        - <int>     read the first N bytes
+                                        - None      read the whole file
+                                    By default "read_size" contains [1024, None, ...].
+        b_read_dst_of_symlink:  <boolean> whether to also read the target a symlink points to.
+                                    Defaults to False.
+        include:                <list> which entry types to traverse.
+                                    Used as the default "include" field for the xxx_cfg parameters above
+                                    when they do not specify one themselves.
+                                    When an element is a str, supported values are:
+                                        - "symlink", "dir", "file"
+                                    When an element is a dict, supported forms are:
+                                        - {"filter_type": "suffix", "option_ls": [...]}     select by suffix.
+                                        - {"filter_type": "small_than", "size": <int>, "b_include_equal": <boolean>}
+                                          select by file size (in bytes).
+                                    When an element is a function, it should have the form:
+                                        - func(file_path) ==> <boolean>     a return value of True means a match.
+                                    The special value None matches everything.
+        exclude:                <list> which entry types to exclude.
+                                    Configured like include.
+                                    Defaults to None, meaning nothing is excluded.
+        walk_paras:             <dict> parameters passed to for_os.walk() when traversing the directory.
+                                    These allow more advanced control of traversal order and exclusions.
+
+    Shape of the result:
+        {
+            <folder_A>: {
+                <folder_B>: {
+                    (<base_name>, <type>):
+                        {
+                            "metadata": {"size": ..., ...},
+                            "hash": {
+                                <size>: {"md5": ...., "sha": ...}
+                            },
+                            "dst_of_symlink": {"metadata": ...., "hash": ...}
+                        }
+                }
+            }
+        }
+        where <type> takes one of the values "symlink", "dir", "file" and None.
+
+    Methods:
+        scan_path() / scan_recursively():   scan files, extract their features and write them to the cache
+        update():                           incremental update; only re-extract when a file's modification time changed
+        save_cache(file_path):              save the current cache as a JSON file
+        load_cache(file_path):              load the cache from a JSON file
+    """
+
+    def __init__(self, **kwargs):
+        # default parameters
+        paras = {
+            "input_dir": None,
+            "metadata_cfg": {"attribute": {"size", "created_time", "modified_time", "accessed_time",
+                                           "mode", "is_symlink", "is_dir", "is_file", "is_symlink_valid"}, },
+            "hash_cfg": {"algorithm": {"md5", }, "read_size": {1024, None}, },
+            "b_read_dst_of_symlink": False,
+            "include": None,
+            "exclude": None,
+            "walk_paras": dict(topdown=True, onerror=None, followlinks=False, ignore_s=None)
+        }
+
+        # fetch parameters
+        paras.update(kwargs)
+
+        # validate parameters
+        if not paras["input_dir"] or not os.path.isdir(paras["input_dir"]):
+            raise ValueError(f'invalid input_dir {paras["input_dir"]}')
+        #
+        for k in ["metadata_cfg", "hash_cfg"]:
+            paras[k].setdefault('include', paras['include'])
+            paras[k].setdefault('exclude', paras['exclude'])
+        self.cache = {}
+        self.paras = paras
+
+    @staticmethod
+    def _matches(path, rule_ls):
+        """
+        Check whether the path matches the rules.
+        """
+        path = os.path.realpath(path)
+        stat = os.lstat(path)
+        for rule in rule_ls:
+            # match by type string
+            if isinstance(rule, str):
+                if rule == 'file' and os.path.isfile(path): return True
+                if rule == 'dir' and os.path.isdir(path): return True
+                if rule == 'symlink' and os.path.islink(path): return True
+                return False
+            # filter by suffix or size
+            if isinstance(rule, dict):
+                ft = rule.get('filter_type')
+                if ft == 'suffix':
+                    return any(path.endswith(suf) for suf in rule.get('option_ls', []))
+                elif ft == 'small_than':
+                    size = stat.st_size
+                    limit = rule.get('size', 0)
+                    eq = rule.get('b_include_equal', False)
+                    return size < limit or (eq and size == limit)
+            # match by function
+            if callable(rule):
+                return rule(path)
+            return False
+        return False
+
+    @staticmethod
+    def _get_metadata(path, attribute):
+        """
+        Get the file's metadata.
+        """
+        path = os.path.realpath(path)
+        stat = os.lstat(path)
+        res_s = dict()
+        for attr in attribute:
+            if attr == 'size': res_s['size'] = stat.st_size
+            if attr == 'created_time': res_s['created_time'] = stat.st_ctime
+            if attr == 'modified_time': res_s['modified_time'] = stat.st_mtime
+            if attr == 'accessed_time': res_s['accessed_time'] = stat.st_atime
+            if attr == 'mode': res_s['mode'] = stat.st_mode
+            if attr == 'is_symlink': res_s['is_symlink'] = os.path.islink(path)
+            if attr == 'is_dir': res_s['is_dir'] = os.path.isdir(path)
+            if attr == 'is_file': res_s['is_file'] = os.path.isfile(path)
+            if attr == 'is_symlink_valid':
+                res_s['is_symlink_valid'] = os.path.islink(path) and os.path.exists(os.readlink(path))
+        return res_s
+
+    @staticmethod
+    def _get_hash(path, read_size_ls, algorithm_ls):
+        """
+        Hash the file; read_size=None means a full hash, otherwise a shallow hash.
+        """
+        res_s = dict()
+        for size in read_size_ls:
+            for algo in algorithm_ls:
+                h = hashlib.new(algo)
+                with open(path, 'rb') as f:
+                    if size is not None:
+                        data = f.read(size)
+                        h.update(data)
+                    else:
+                        for chunk in iter(lambda: f.read(8192), b''):
+                            h.update(chunk)
+                res_s[size] = res_s.get(size, dict())
+                res_s[size][algo] = h.hexdigest()
+        return res_s
+
+    def extract_feature(self, path, metadata_cfg=None, hash_cfg=None):
+        metadata_cfg = metadata_cfg or self.paras['metadata_cfg']
+        hash_cfg = hash_cfg or self.paras['hash_cfg']
+        path = os.path.realpath(path)
+        res_s = dict()
+        base_ = os.path.basename(path)
+        if os.path.islink(path):
+            f_type = F_Type.symlink
+        elif os.path.isfile(path):
+            f_type = F_Type.file
+        elif os.path.isdir(path):
+            f_type = F_Type.dir
+        else:
+            f_type = F_Type.unknown
+        try:
+            if metadata_cfg is not None:
+                res_s["metadata"] = self._get_metadata(path, attribute=metadata_cfg['attribute'])
+            if hash_cfg is not None and f_type == F_Type.file:
+                res_s["hash"] = self._get_hash(path, read_size_ls=hash_cfg['read_size'],
+                                               algorithm_ls=hash_cfg['algorithm'])
+            if os.path.islink(path) and self.paras['b_read_dst_of_symlink']:
+                dst = os.readlink(path)
+                res_s['dst_of_symlink'] = self.extract_feature(dst)
+        except Exception as e:
+            res_s = {'error': str(e)}
+        return base_, f_type.value, res_s
+
+    def scan_path(self, path, metadata_cfg=None, hash_cfg=None):
+        """
+        Scan a single path, extract its features and write them to the cache.
+        """
+        path = os.path.realpath(path)
+        rel = os.path.relpath(path, self.paras["input_dir"])
+        parts = rel.split(os.sep)
+        node = self.cache
+        for p in parts[:-1]:
+            node = node.setdefault(p, {})
+        base_, f_type, res_s = self.extract_feature(path=path, metadata_cfg=metadata_cfg, hash_cfg=hash_cfg)
+        node[(base_, f_type)] = res_s
+
+    def scan_recursively(self, path=None, metadata_cfg=None, hash_cfg=None):
+        """
+        Recursively scan a directory, extract features and write them to the cache.
+        """
+        path = path or self.paras["input_dir"]
+        for root, dirs, files in for_os.walk(top=path, **self.paras["walk_paras"]):
+            for name in files + dirs:
+                full_path = os.path.join(root, name)
+                if self.paras["include"] is not None:
+                    if not self._matches(full_path, rule_ls=self.paras["include"]):
+                        continue
+                if self.paras["exclude"] is not None:
+                    if self._matches(full_path, rule_ls=self.paras["exclude"]):
+                        continue
+                self.scan_path(full_path, metadata_cfg=metadata_cfg, hash_cfg=hash_cfg)
+
+    def update(self):
+        """
+        Incremental update; rescan files that have changed.
+        """
+        # simplified: do a full rescan that overwrites the old cache; optimize as needed
+        self.cache.clear()
+        self.scan_recursively()
+
+    def save_cache(self, file_path):
+        json_.write(content=self.cache, file_path=file_path, b_use_suggested_converter=True)
+
+    def load_cache(self, file_path):
+        self.cache = json_.read(file_path=file_path, b_use_suggested_converter=True)
+
+
+if __name__ == '__main__':
+    from kevin_toolbox.data_flow.file import markdown
+    file_feature_extractor = File_Feature_Extractor(
+        input_dir=os.path.join(os.path.dirname(__file__), "test/test_data")
+    )
+    file_feature_extractor.scan_recursively()
+    print(markdown.generate_list(file_feature_extractor.cache))
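
Beyond the __main__ demo above, a typical round trip is scan, persist, then reload. A hedged usage sketch (the directory and output path are hypothetical; only methods defined above are used):

    from kevin_toolbox.developing.file_management import File_Feature_Extractor

    extractor = File_Feature_Extractor(input_dir="/some/dataset/dir")  # hypothetical path
    extractor.scan_recursively()  # fills extractor.cache with metadata and hashes
    extractor.save_cache(file_path="/tmp/features.json")
    extractor.load_cache(file_path="/tmp/features.json")  # round-trips via json_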