kevin-toolbox-dev 1.3.4__py3-none-any.whl → 1.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. kevin_toolbox/__init__.py +2 -2
  2. kevin_toolbox/computer_science/algorithm/pareto_front/__init__.py +1 -0
  3. kevin_toolbox/computer_science/algorithm/pareto_front/optimum_picker.py +218 -0
  4. kevin_toolbox/computer_science/algorithm/statistician/__init__.py +1 -0
  5. kevin_toolbox/computer_science/algorithm/statistician/accumulator_base.py +2 -2
  6. kevin_toolbox/computer_science/algorithm/statistician/accumulator_for_ndl.py +69 -0
  7. kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py +16 -4
  8. kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py +5 -12
  9. kevin_toolbox/data_flow/file/json_/read_json.py +11 -5
  10. kevin_toolbox/data_flow/file/kevin_notation/kevin_notation_writer.py +4 -1
  11. kevin_toolbox/data_flow/file/kevin_notation/test/test_kevin_notation_debug.py +27 -0
  12. kevin_toolbox/data_flow/file/kevin_notation/write.py +3 -0
  13. kevin_toolbox/data_flow/file/markdown/__init__.py +3 -0
  14. kevin_toolbox/data_flow/file/markdown/find_tables.py +65 -0
  15. kevin_toolbox/data_flow/file/markdown/generate_table.py +19 -5
  16. kevin_toolbox/data_flow/file/markdown/parse_table.py +135 -0
  17. kevin_toolbox/data_flow/file/markdown/save_images_in_ndl.py +81 -0
  18. kevin_toolbox/data_flow/file/markdown/variable.py +17 -0
  19. kevin_toolbox/nested_dict_list/serializer/backends/_ndl.py +4 -1
  20. kevin_toolbox/nested_dict_list/serializer/read.py +18 -14
  21. kevin_toolbox/nested_dict_list/serializer/write.py +23 -7
  22. kevin_toolbox/patches/for_matplotlib/common_charts/__init__.py +6 -0
  23. kevin_toolbox/patches/for_matplotlib/common_charts/plot_bars.py +54 -0
  24. kevin_toolbox/patches/for_matplotlib/common_charts/plot_confusion_matrix.py +60 -0
  25. kevin_toolbox/patches/for_matplotlib/common_charts/plot_distribution.py +65 -0
  26. kevin_toolbox/patches/for_matplotlib/common_charts/plot_lines.py +61 -0
  27. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters.py +53 -0
  28. kevin_toolbox/patches/for_matplotlib/common_charts/plot_scatters_matrix.py +54 -0
  29. kevin_toolbox/patches/for_os/__init__.py +3 -0
  30. kevin_toolbox/patches/for_os/copy.py +33 -0
  31. kevin_toolbox/patches/for_os/find_files_in_dir.py +30 -0
  32. kevin_toolbox/patches/for_os/path/__init__.py +2 -0
  33. kevin_toolbox/patches/for_os/path/find_illegal_chars.py +47 -0
  34. kevin_toolbox/patches/for_os/path/replace_illegal_chars.py +49 -0
  35. kevin_toolbox/patches/for_test/check_consistency.py +104 -33
  36. kevin_toolbox_dev-1.3.6.dist-info/METADATA +95 -0
  37. {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/RECORD +39 -20
  38. kevin_toolbox_dev-1.3.4.dist-info/METADATA +0 -67
  39. {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/WHEEL +0 -0
  40. {kevin_toolbox_dev-1.3.4.dist-info → kevin_toolbox_dev-1.3.6.dist-info}/top_level.txt +0 -0
kevin_toolbox/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "1.3.4"
+ __version__ = "1.3.6"
 
 
  import os
@@ -12,5 +12,5 @@ os.system(
  os.system(
      f'python {os.path.split(__file__)[0]}/env_info/check_validity_and_uninstall.py '
      f'--package_name kevin-toolbox-dev '
-     f'--expiration_timestamp 1727961379 --verbose 0'
+     f'--expiration_timestamp 1735563213 --verbose 0'
  )
kevin_toolbox/computer_science/algorithm/pareto_front/__init__.py CHANGED
@@ -1 +1,2 @@
  from .get_pareto_points_idx import get_pareto_points_idx, Direction
+ from .optimum_picker import Optimum_Picker
kevin_toolbox/computer_science/algorithm/pareto_front/optimum_picker.py ADDED
@@ -0,0 +1,218 @@
+ from kevin_toolbox.computer_science.data_structure import Executor
+ from kevin_toolbox.computer_science.algorithm.pareto_front import get_pareto_points_idx, Direction
+ import kevin_toolbox.nested_dict_list as ndl
+ import numpy as np
+
+
+ class Optimum_Picker:
+     """
+         Record and update the Pareto-optimal values.
+         Also supports monitoring the following events:
+             - a newly added value is a new Pareto optimum
+             - an old optimum is discarded because it is no longer optimal
+         and firing the configured executors; see the parameters trigger_for_new and trigger_for_out.
+     """
+
+     def __init__(self, **kwargs):
+         """
+             Parameters:
+                 directions: <list of Direction> the directions of comparison
+                 trigger_for_new: <Executor> trigger
+                     executed when the value newly added via add() is a new Pareto optimum
+                     before execution, {"metrics": <metrics>, "step": <step>, ...} is automatically added to the trigger's kwargs
+                 trigger_for_out: <Executor> trigger
+                     executed when add() has to discard historical values that are no longer Pareto-optimal
+                     before execution, {"metrics": <metrics>, "step": <step>, ...} is automatically added to the trigger's kwargs
+                 warmup_steps: <int> how many add() calls to wait before starting to compare the monitored values
+                     defaults to 0
+                 pick_per_steps: <int> compare the monitored values once every this many add() calls
+                     defaults to 1
+         """
+         # default parameters
+         paras = {
+             "directions": None,
+             "trigger_for_new": None,
+             "trigger_for_out": None,
+             "warmup_steps": 0,
+             "pick_per_steps": 1,
+         }
+
+         # read parameters
+         paras.update(kwargs)
+
+         # validate parameters
+         assert paras["warmup_steps"] >= 0 and paras["pick_per_steps"] >= 1
+         paras["directions"] = [Direction(i) for i in paras["directions"]]
+         for k in ["trigger_for_out", "trigger_for_new"]:
+             assert isinstance(paras[k], (type(None), Executor))
+
+         self.paras = paras
+         self._state = self._init_state()
+
+     def _init_state(self):
+         return dict(
+             optimal_ls=list(),  # [{"metrics": metrics, "record": record, "step": step}, ...]
+             step=0,
+             b_empty_cache=True,
+             last_optimal_nums=0
+         )
+
+     def add(self, metrics, b_force_clear_cache=False, **kwargs):
+         """
+             Add a metric value.
+
+             Parameters:
+                 metrics: the metric value
+                 b_force_clear_cache: <boolean> whether to forcibly clear the cache
+                     defaults to False, in which case the configured warmup_steps and pick_per_steps decide when the cache is cleared
+                 **kwargs: user-defined records
+                     will be added to the record
+         """
+         assert metrics is not None
+         metrics = np.asarray(metrics).reshape(1, -1)
+         assert metrics.shape[-1] == len(self.paras["directions"])
+
+         optimal_ls, step = self._state["optimal_ls"], self._state["step"]
+         new_record = dict(metrics=metrics, step=step)
+         new_record.update(kwargs)
+         #
+         optimal_ls.append(new_record)
+         self._state["step"] += 1
+
+         # warmup & cache
+         if not b_force_clear_cache and (step < self.paras["warmup_steps"] or
+                                         (step - self.paras["warmup_steps"]) % self.paras["pick_per_steps"] != 0):
+             self._state["b_empty_cache"] = False
+             return
+
+         # find the new Pareto-optimal values
+         points = np.concatenate([i["metrics"] for i in optimal_ls])
+         idx_ls = get_pareto_points_idx(points=points, directions=self.paras["directions"])
+         idx_ls.sort()
+         # fire the triggers
+         if self.paras["trigger_for_new"] is not None:
+             for i in filter(lambda i: i >= self._state["last_optimal_nums"], idx_ls):
+                 self.paras["trigger_for_new"].run(**optimal_ls[i])
+         #
+         if self.paras["trigger_for_out"] is not None:
+             for i in set(range(self._state["last_optimal_nums"])).difference(set(idx_ls)):
+                 self.paras["trigger_for_out"].run(**optimal_ls[i])
+
+         # update
+         self._state["optimal_ls"] = [optimal_ls[i] for i in idx_ls]
+         self._state["b_empty_cache"] = True
+         self._state["last_optimal_nums"] = len(idx_ls)
+
+     def get(self, b_force_clear_cache=False):
+         """
+             Get the current records of the optimal values.
+
+             Parameters:
+                 b_force_clear_cache: <boolean> whether to forcibly clear the cache
+
+             Returns:
+                 record_ls: <list of dict> the optimal records
+                 b_empty_cache: <boolean> whether the cache has been cleared, i.e. whether record_ls is the true set of optimal records
+                     when pick_per_steps > 1, some records remain in the cache without having been compared, so the optimal records are neither complete nor up to date
+         """
+         if b_force_clear_cache and not self._state["b_empty_cache"] and len(self._state["optimal_ls"]) > 0:
+             # to clear the cache, pop the last cached record and add() it again with b_force_clear_cache=True
+             record = self._state["optimal_ls"].pop(-1)
+             metrics, step = record.pop("metrics"), record.pop("step")
+             self._state["step"] -= 1
+             assert self._state["step"] == step
+             self.add(metrics=metrics, b_force_clear_cache=True, **record)
+
+         return self._state["optimal_ls"][:], self._state["b_empty_cache"] or b_force_clear_cache
+
+     def clear(self):
+         self._state = self._init_state()
+
+     def __len__(self):
+         return self._state["step"]
+
+     # ---------------------- for saving and loading state ---------------------- #
+     def load_state_dict(self, state_dict):
+         """
+             Load state.
+         """
+         self.clear()
+         self._state.update(state_dict)
+
+     def state_dict(self):
+         """
+             Get state.
+         """
+         return ndl.copy_(var=self._state, b_deepcopy=True, b_keep_internal_references=True)
+
+
+ if __name__ == '__main__':
+     """
+         Simulated scenario:
+         while training a model, compare val_acc_1 (maximize) and val_error_2 (minimize),
+         and save the model whenever it is Pareto-optimal.
+     """
+     import torch
+     import matplotlib.pyplot as plt
+
+     # a shuffled sequence of points sampled from a circle
+     metrics = torch.tensor([(-4.045084971874739, -2.9389262614623632),
+                             (-3.1871199487434474, -3.852566213878947),
+                             (-2.1288964578253635, 4.524135262330097),
+                             (-4.648882429441257, -1.8406227634233896),
+                             (-4.648882429441256, 1.8406227634233907),
+                             (-0.936906572928623, 4.911436253643443),
+                             (0.31395259764656414, -4.990133642141358),
+                             (-4.960573506572389, 0.6266661678215226),
+                             (-3.1871199487434487, 3.852566213878946),
+                             (4.381533400219316, -2.4087683705085805),
+                             (0.31395259764656763, 4.990133642141358),
+                             (2.6791339748949827, 4.221639627510076),
+                             (4.8429158056431545, -1.2434494358242767),
+                             (1.5450849718747361, -4.755282581475768),
+                             (4.842915805643155, 1.243449435824274),
+                             (3.644843137107056, -3.422735529643445),
+                             (5.0, 0.0),
+                             (-2.128896457825361, -4.524135262330099),
+                             (2.6791339748949836, -4.221639627510075),
+                             (-4.9605735065723895, -0.6266661678215214),
+                             (3.644843137107058, 3.422735529643443),
+                             (4.381533400219318, 2.4087683705085765),
+                             (-0.9369065729286231, -4.911436253643443),
+                             (-4.045084971874736, 2.9389262614623664),
+                             (1.5450849718747373, 4.755282581475767)])
+     # the points in the lower right are the Pareto optima
+     best_idx_ls = [6, 9, 12, 13, 15, 16, 18]
+
+     # store the x and y coordinates in two separate lists
+     x_coords = metrics[:, 0].numpy().tolist()
+     y_coords = metrics[:, 1].numpy().tolist()
+     # plot the points in order
+     plt.plot(x_coords, y_coords, marker='o')
+     # add index labels
+     for i, txt in enumerate(range(len(metrics))):
+         plt.annotate(txt, (x_coords[i], y_coords[i]), textcoords="offset points", xytext=(0, 5), ha='center')
+     plt.show()
+     import os
+     from kevin_toolbox.data_flow.file import json_
+     from kevin_toolbox.patches.for_os import remove
+
+     temp_dir = os.path.join(os.path.dirname(__file__), "temp")
+     remove(temp_dir, ignore_errors=True)
+
+     opt_picker = Optimum_Picker(
+         warmup_steps=9, pick_per_steps=5,
+         trigger_for_new=Executor(
+             func=lambda metrics, step: json_.write(metrics.tolist(), os.path.join(temp_dir, f'{step}.json'))),
+         trigger_for_out=Executor(func=lambda step, **kwargs: remove(os.path.join(temp_dir, f'{step}.json'))),
+         directions=["maximize", "minimize"]
+     )
+     for s, v in enumerate(metrics):
+         opt_picker.add(metrics=v)
+         print()
+         print(s, v)
+         print(opt_picker.get()[1])
+         print([i["step"] for i in opt_picker.get()[0]])
+
+     for i in best_idx_ls:
+         assert os.path.isfile(os.path.join(temp_dir, f'{i}.json'))
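For quick orientation, here is a minimal usage sketch of the new Optimum_Picker, distilled from the docstring and the __main__ demo above. The metric values and the print callbacks are invented for illustration; only directions is mandatory, the remaining parameters have defaults.

from kevin_toolbox.computer_science.data_structure import Executor
from kevin_toolbox.computer_science.algorithm.pareto_front import Optimum_Picker

# two metrics: maximize the first, minimize the second
picker = Optimum_Picker(
    directions=["maximize", "minimize"],
    trigger_for_new=Executor(func=lambda metrics, step, **kw: print("new optimum at step", step)),
    trigger_for_out=Executor(func=lambda step, **kw: print("dropped optimum of step", step)),
)
for acc, err in [(0.7, 0.5), (0.8, 0.6), (0.9, 0.4)]:  # made-up (accuracy, error) pairs
    picker.add(metrics=[acc, err])
record_ls, b_empty_cache = picker.get()
print([r["step"] for r in record_ls])  # steps that are currently Pareto-optimal

With the default warmup_steps=0 and pick_per_steps=1, every add() re-runs the Pareto comparison, so b_empty_cache stays True and the triggers fire immediately.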
kevin_toolbox/computer_science/algorithm/statistician/__init__.py CHANGED
@@ -1,3 +1,4 @@
  from .accumulator_base import Accumulator_Base
  from .exponential_moving_average import Exponential_Moving_Average
  from .average_accumulator import Average_Accumulator
+ from .accumulator_for_ndl import Accumulator_for_Ndl
kevin_toolbox/computer_science/algorithm/statistician/accumulator_base.py CHANGED
@@ -71,7 +71,7 @@ class Accumulator_Base(object):
 
      def add_sequence(self, var_ls, **kwargs):
          # for var in var_ls:
-         #     self.add(var)
+         #     self.add(var, **kwargs)
          raise NotImplementedError
 
      def add(self, var, **kwargs):
@@ -112,7 +112,7 @@ class Accumulator_Base(object):
          if like is not None:
              var = init_var.by_like(var=like)
          elif data_format is not None:
-             var = init_var.by_data_format(**kwargs)
+             var = init_var.by_data_format(**data_format)
          else:
              var = None
          return var
kevin_toolbox/computer_science/algorithm/statistician/accumulator_for_ndl.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import kevin_toolbox.nested_dict_list as ndl
+ from kevin_toolbox.computer_science.algorithm.statistician import Accumulator_Base
+
+
+ class Accumulator_for_Ndl:
+     """
+         An accumulator for ndl (nested dict list) structures.
+     """
+
+     def __init__(self, accumulator_builder):
+         """
+             Parameters:
+                 accumulator_builder: constructor of the accumulator used for each leaf node of the ndl
+         """
+         assert callable(accumulator_builder) or isinstance(accumulator_builder, Accumulator_Base)
+         self.accumulator_builder = accumulator_builder
+
+         self.var = None
+
+     def add(self, var, **kwargs):
+         if self.var is None and isinstance(var, (dict, list)):
+             self.var = type(var)()
+         for name, value in ndl.get_nodes(var=var, level=-1, b_strict=True):
+             accumulator = ndl.get_value(var=self.var, name=name, default=None)
+             if accumulator is None:
+                 accumulator = self.accumulator_builder()
+                 self.var = ndl.set_value(var=self.var, name=name, value=accumulator, b_force=True)
+             value = value.detach().cpu().numpy() if torch.is_tensor(value) else value
+             accumulator.add(value, **kwargs)
+
+     def add_sequence(self, var_ls, **kwargs):
+         for var in var_ls:
+             self.add(var, **kwargs)
+
+     def get(self, **kwargs):
+         return ndl.traverse(
+             var=ndl.copy_(var=self.var, b_deepcopy=False),
+             match_cond=lambda _, __, v: not isinstance(v, (dict, list)) and hasattr(v, "get"), action_mode="replace",
+             converter=lambda _, v: v.get(**kwargs)
+         )
+
+
+ if __name__ == '__main__':
+     from kevin_toolbox.data_flow.file import markdown
+     import numpy as np
+     from kevin_toolbox.computer_science.algorithm.statistician import Average_Accumulator
+
+     worker = Accumulator_for_Ndl(accumulator_builder=Average_Accumulator)
+
+     worker.add({
+         1: 2.1,
+         "233": torch.ones(10),
+         "543": [
+             np.array([1, 2, 3]),
+             np.array([4, 5, 6]),
+         ]
+     }, weight=0.8)
+
+     worker.add({
+         1: 3.1,
+         "233": torch.zeros(10),
+         "543": [
+             np.array([0, 2, 3]),
+             np.array([0, 5, 6]),
+         ]
+     }, weight=1.4)
+
+     print(markdown.generate_list(var=worker.get()))
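In short, Accumulator_for_Ndl lazily creates one leaf accumulator per node of the nested structure and forwards add()/get() to each of them. A minimal sketch along the lines of the __main__ demo above (input values are arbitrary):

import numpy as np
from kevin_toolbox.computer_science.algorithm.statistician import Accumulator_for_Ndl, Average_Accumulator

worker = Accumulator_for_Ndl(accumulator_builder=Average_Accumulator)
# every leaf of the nested structure gets its own Average_Accumulator
worker.add({"a": 1.0, "b": [np.array([1.0, 2.0]), np.array([3.0, 4.0])]})
worker.add({"a": 3.0, "b": [np.array([3.0, 4.0]), np.array([5.0, 6.0])]})
print(worker.get())  # expected: {'a': 2.0, 'b': [array([2., 3.]), array([4., 5.])]}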
kevin_toolbox/computer_science/algorithm/statistician/average_accumulator.py CHANGED
@@ -31,20 +31,22 @@ class Average_Accumulator(Accumulator_Base):
 
      def add_sequence(self, var_ls, **kwargs):
          for var in var_ls:
-             self.add(var)
+             self.add(var, **kwargs)
 
-     def add(self, var, **kwargs):
+     def add(self, var, weight=1, **kwargs):
          """
              Add a single datum.
 
             Parameters:
                  var: the datum
+                 weight: its weight
          """
          if self.var is None:
              self.var = self._init_var(like=var)
          # accumulate
-         self.var += var
+         self.var = self.var + var * weight
          self.state["total_nums"] += 1
+         self.state["total_weights"] += weight
 
      def get(self, **kwargs):
          """
@@ -53,7 +55,17 @@ class Average_Accumulator(Accumulator_Base):
          """
          if len(self) == 0:
              return None
-         return self.var / len(self)
+         return self.var / self.state["total_weights"]
+
+     @staticmethod
+     def _init_state():
+         """
+             Initialize the state.
+         """
+         return dict(
+             total_nums=0,
+             total_weights=0,
+         )
 
 
  if __name__ == '__main__':
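The change above turns Average_Accumulator into a weighted average: add() now accumulates var * weight and tracks the weight sum in state["total_weights"], and get() divides by that sum instead of the sample count. A small sketch of the resulting semantics (numbers invented):

from kevin_toolbox.computer_science.algorithm.statistician import Average_Accumulator

avg = Average_Accumulator()
avg.add(10.0, weight=1)
avg.add(20.0, weight=3)
print(avg.get())  # (10*1 + 20*3) / (1+3) = 17.5, rather than the unweighted mean 15.0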
kevin_toolbox/computer_science/algorithm/statistician/exponential_moving_average.py CHANGED
@@ -58,18 +58,11 @@ class Exponential_Moving_Average(Accumulator_Base):
          #
          super(Exponential_Moving_Average, self).__init__(**paras)
 
-     def add_sequence(self, var_ls, weight_ls=None):
-         if weight_ls is not None:
-             if isinstance(weight_ls, (int, float,)):
-                 weight_ls = [weight_ls] * len(var_ls)
-             assert len(weight_ls) == len(var_ls)
-             for var, weight in enumerate(var_ls, weight_ls):
-                 self.add(var, weight)
-         else:
-             for var in var_ls:
-                 self.add(var)
+     def add_sequence(self, var_ls, **kwargs):
+         for var in var_ls:
+             self.add(var, **kwargs)
 
-     def add(self, var, weight=1):
+     def add(self, var, weight=1, **kwargs):
          """
              Add a single datum.
 
@@ -88,7 +81,7 @@ class Exponential_Moving_Average(Accumulator_Base):
          self.state["total_nums"] += 1
          self.state["bias_fix"] *= keep_ratio
 
-     def get(self, bias_correction=None):
+     def get(self, bias_correction=None, **kwargs):
          """
              Get the current accumulated value.
              Returns None when not yet initialized.
1
1
  import os
2
2
  import json
3
+ from io import BytesIO, StringIO
3
4
  from kevin_toolbox.data_flow.file.json_.converter import integrate, unescape_tuple_and_set, unescape_non_str_dict_key
4
5
  from kevin_toolbox.nested_dict_list import traverse
5
6
 
6
7
 
7
- def read_json(file_path, converters=None, b_use_suggested_converter=False):
8
+ def read_json(file_path=None, file_obj=None, converters=None, b_use_suggested_converter=False):
8
9
  """
9
10
  读取 json file
10
11
 
11
12
  参数:
12
13
  file_path
14
+ file_obj
13
15
  converters: <list of converters> 对读取内容中每个节点的处理方式
14
16
  转换器 converter 应该是一个形如 def(x): ... ; return x 的函数,具体可以参考
15
17
  json_.converter 中已实现的转换器
@@ -19,13 +21,17 @@ def read_json(file_path, converters=None, b_use_suggested_converter=False):
19
21
  默认为 False。
20
22
  注意:当 converters 非 None,此参数失效,以 converters 中的具体设置为准
21
23
  """
22
- assert os.path.isfile(file_path), f'file {file_path} not found'
24
+ assert file_path is not None or file_obj is not None
25
+ if file_path is not None:
26
+ assert os.path.isfile(file_path), f'file {file_path} not found'
27
+ file_obj = open(file_path, 'r')
28
+ elif isinstance(file_obj, (BytesIO,)):
29
+ file_obj = StringIO(file_obj.read().decode('utf-8'))
30
+ content = json.load(file_obj)
31
+
23
32
  if converters is None and b_use_suggested_converter:
24
33
  converters = [unescape_tuple_and_set, unescape_non_str_dict_key]
25
34
 
26
- with open(file_path, 'r') as f:
27
- content = json.load(f)
28
-
29
35
  if converters is not None:
30
36
  converter = integrate(converters)
31
37
  content = traverse(var=[content],
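With the new file_obj parameter, read_json can consume an already-open stream instead of a path; a BytesIO is transparently re-wrapped as UTF-8 text before json.load. A minimal sketch, importing straight from the module shown above and assuming the function still returns the loaded content:

from io import BytesIO
from kevin_toolbox.data_flow.file.json_.read_json import read_json

buf = BytesIO(b'{"a": 1, "b": [2, 3]}')
print(read_json(file_obj=buf))  # expected: {'a': 1, 'b': [2, 3]}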
kevin_toolbox/data_flow/file/kevin_notation/kevin_notation_writer.py CHANGED
@@ -5,7 +5,7 @@ import warnings
  from kevin_toolbox.data_flow.file.kevin_notation.converter import Converter, CONVERTER_FOR_WRITER
  from kevin_toolbox.data_flow.file import kevin_notation
 
- np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
+ warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
 
 
  class Kevin_Notation_Writer:
@@ -313,6 +313,9 @@ class Kevin_Notation_Writer:
          try:
              # interpret as multiple rows
              assert paras.get("b_single_line", None) in (None, False)
+             temp = [len(paras["column_dict"][k]) for k in self.metadata["column_name"]]
+             if temp:
+                 assert max(temp) == min(temp), f"Error: the length of each column is not equal!"
              row_ls = list(zip(*[paras["column_dict"][k] for k in self.metadata["column_name"]]))
          except:
              # interpret as a single row
kevin_toolbox/data_flow/file/kevin_notation/test/test_kevin_notation_debug.py ADDED
@@ -0,0 +1,27 @@
+ import pytest
+ from kevin_toolbox.patches.for_test import check_consistency
+
+ import os
+ import numpy as np
+
+ from kevin_toolbox.data_flow.file import kevin_notation
+ from kevin_toolbox.data_flow.file.kevin_notation.test.test_data.data_all import metadata_ls, content_ls, file_path_ls
+
+
+ @pytest.mark.parametrize("expected_metadata, expected_content, file_path",
+                          zip(metadata_ls, content_ls, file_path_ls))
+ def test_write(expected_metadata, expected_content, file_path):
+     print("test write()")
+
+     """
+         check that an error is raised when the columns to be written have inconsistent lengths
+     """
+
+     # new file
+     file_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "test_data/temp", os.path.basename(file_path))
+
+     # write via a dict
+     if len(expected_content) > 1:
+         with pytest.raises(AssertionError):
+             list(expected_content.values())[0].clear()
+             kevin_notation.write(metadata=expected_metadata, content=expected_content, file_path=file_path)
kevin_toolbox/data_flow/file/kevin_notation/write.py CHANGED
@@ -5,6 +5,9 @@ def write(metadata, content, file_path):
      """
          A shortcut interface for writing an entire file at once.
      """
+     if "column_num" in metadata:
+         metadata = metadata.copy()
+         metadata.pop("column_num")
      with kevin_notation.Writer(file_path=file_path, mode="w", sep=metadata.get("sep", "\t")) as writer:
          writer.write_metadata(metadata=metadata)
          if isinstance(content, (dict,)):
kevin_toolbox/data_flow/file/markdown/__init__.py CHANGED
@@ -1,3 +1,6 @@
  from .generate_link import generate_link
  from .generate_list import generate_list
  from .generate_table import generate_table
+ from .parse_table import parse_table
+ from .find_tables import find_tables
+ from .save_images_in_ndl import save_images_in_ndl
kevin_toolbox/data_flow/file/markdown/find_tables.py ADDED
@@ -0,0 +1,65 @@
+ import re
+
+
+ def find_tables(text):
+     """
+         Find the tables contained in a text.
+         Returns a list whose elements are 2-D arrays of cells, each representing one raw table.
+     """
+     table_ls = []
+     for sub_text in text.split('\n\n', -1):
+         ret = _find_table(text=sub_text)
+         if ret is not None:
+             table_ls.append(ret)
+
+     return table_ls
+
+
+ def _find_table(text):
+     # regular expression for matching a markdown table
+     table_pattern = re.compile(r'\|([^\n]+)\|', re.DOTALL)
+     table_matches = table_pattern.findall(text)
+     if len(table_matches) < 2:
+         # a valid markdown table must contain the header separator line, so there should be at least 2 rows
+         return None
+
+     # drop the header separator line
+     table_matches.pop(1)
+     #
+     tables = []  # one element per row
+     for match in table_matches:
+         # split each row
+         tables.append([i.strip() for i in match.split('|', -1)])
+
+     return tables
+
+
+ if __name__ == '__main__':
+     # # sample markdown table text
+     # file_path = ""
+     # with open(file_path, 'r') as f:
+     #     markdown_text = f.read()
+
+     markdown_text = """
+ | Name | Age | Occupation |
+ |------|-----|------------|
+ | Alice | 28 | Engineer |
+ | Bob | 23 | Teacher |
+ | Name | Age | Occupation |
+ | Carol | 32 | Hacker |
+ | David | 18 | Student |
+
+ 2333
+
+ | | a | b | | a | b | | a | b |
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+ | | 0 | 2 | | 4 | 6 | | 7 | 9 |
+ | | 1 | 3 | | 5 | 7 | | 8 | : |
+ | | 2 | 4 | | 6 | 8 | | 9 | ; |
+ | | 3 | 5 | | | | | | |
+ """
+
+     # call the function and print the results
+     tables = find_tables(text=markdown_text)
+     print(tables[0])
+     print(tables[1])
8
8
 
9
9
  参数:
10
10
  content_s: <dict> 内容
11
- 支持两种输入模式:
11
+ 目前支持 Table_Format 中的两种输入模式:
12
12
  1.简易模式:
13
13
  content_s = {<title>: <list of value>, ...}
14
14
  此时键作为标题,值作为标题下的一系列值。
@@ -24,7 +24,7 @@ def generate_table(content_s, orientation="vertical", chunk_nums=None, chunk_siz
24
24
  chunk_nums: <int> 将表格平均分割为多少份进行并列显示。
25
25
  chunk_size: <int> 将表格按照最大长度进行分割,然后并列显示。
26
26
  注意:以上两个参数只能设置一个,同时设置时将报错
27
- b_allow_misaligned_values <boolean> 允许不对齐的 values
27
+ b_allow_misaligned_values: <boolean> 允许不对齐的 values
28
28
  默认为 False,此时当不同标题下的 values 的长度不相等时,将会直接报错。
29
29
  当设置为 True 时,对于短于最大长度的 values 将直接补充 ""。
30
30
  f_gen_order_of_values: <callable> 生成values排序顺序的函数
@@ -34,6 +34,7 @@ def generate_table(content_s, orientation="vertical", chunk_nums=None, chunk_siz
34
34
  assert chunk_nums is None or 1 <= chunk_nums
35
35
  assert chunk_size is None or 1 <= chunk_size
36
36
  assert orientation in ["vertical", "horizontal", "h", "v"]
37
+ assert isinstance(content_s, (dict,))
37
38
 
38
39
  # 将简易模式转换为完整模式
39
40
  if len(content_s.values()) > 0 and not isinstance(list(content_s.values())[0], (dict,)):
@@ -49,6 +50,10 @@ def generate_table(content_s, orientation="vertical", chunk_nums=None, chunk_siz
49
50
  v["values"].extend([""] * (max_length - len(v["values"])))
50
51
  # 对值进行排序
51
52
  if callable(f_gen_order_of_values):
53
+ # 检查是否有重复的 title
54
+ temp = [v["title"] for v in content_s.values()]
55
+ assert len(set(temp)) == len(temp), \
56
+ f'table has duplicate titles, thus cannot be sorted using f_gen_order_of_values'
52
57
  idx_ls = list(range(max_length))
53
58
  idx_ls.sort(key=lambda x: f_gen_order_of_values({v["title"]: v["values"][x] for v in content_s.values()}))
54
59
  for v in content_s.values():
@@ -108,9 +113,9 @@ def _show_table(content_s, orientation="vertical"):
108
113
 
109
114
 
110
115
  if __name__ == '__main__':
111
- content_s = {0: dict(title="a", values=[1, 2, 3]), 2: dict(title="b", values=[4, 5, 6])}
112
- doc = generate_table(content_s=content_s, orientation="h")
113
- print(doc)
116
+ # content_s = {0: dict(title="a", values=[1, 2, 3]), 2: dict(title="b", values=[4, 5, 6])}
117
+ # doc = generate_table(content_s=content_s, orientation="h")
118
+ # print(doc)
114
119
 
115
120
  # from collections import OrderedDict
116
121
  #
@@ -128,3 +133,12 @@ if __name__ == '__main__':
128
133
  # "/home/SENSETIME/xukaiming/Desktop/my_repos/python_projects/kevin_toolbox/kevin_toolbox/data_flow/file/markdown/test/test_data/for_generate_table",
129
134
  # f"data_5.md"), "w") as f:
130
135
  # f.write(doc)
136
+
137
+ doc = generate_table(
138
+ content_s={'y/n': ['False', 'False', 'False', 'False', 'False', 'True', 'True', 'True', 'True', 'True'],
139
+ 'a': ['5', '8', '7', '6', '9', '2', '1', '4', '0', '3'],
140
+ 'b': ['', '', '', '', '', '6', '4', ':', '2', '8']},
141
+ orientation="v", chunk_size=4, b_allow_misaligned_values=True,
142
+ f_gen_order_of_values=lambda x: (-int(eval(x["y/n"]) is False), -(int(x["a"]) % 3))
143
+ )
144
+ print(doc)
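Finally, the new markdown helpers compose: a table emitted by generate_table can be located again with find_tables (and, presumably, handed on to the new parse_table, whose diff is not shown in this section). A minimal round-trip sketch using only the two APIs shown above, with invented values:

from kevin_toolbox.data_flow.file.markdown import generate_table, find_tables

doc = generate_table(content_s={"a": [1, 2, 3], "b": [4, 5, 6]})  # simple mode
print(doc)                    # the rendered markdown table
print(find_tables(text=doc))  # the raw cell grid recovered from the text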