kele 0.0.1a1__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kele/__init__.py +38 -0
- kele/_version.py +1 -0
- kele/config.py +243 -0
- kele/control/README_metrics.md +102 -0
- kele/control/__init__.py +20 -0
- kele/control/callback.py +255 -0
- kele/control/grounding_selector/__init__.py +5 -0
- kele/control/grounding_selector/_rule_strategies/README.md +13 -0
- kele/control/grounding_selector/_rule_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py +42 -0
- kele/control/grounding_selector/_rule_strategies/strategy_protocol.py +51 -0
- kele/control/grounding_selector/_selector_utils.py +123 -0
- kele/control/grounding_selector/_term_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py +34 -0
- kele/control/grounding_selector/_term_strategies/strategy_protocol.py +50 -0
- kele/control/grounding_selector/rule_selector.py +98 -0
- kele/control/grounding_selector/term_selector.py +89 -0
- kele/control/infer_path.py +306 -0
- kele/control/metrics.py +357 -0
- kele/control/status.py +286 -0
- kele/egg_equiv.pyi +11 -0
- kele/egg_equiv.so +0 -0
- kele/equality/README.md +8 -0
- kele/equality/__init__.py +4 -0
- kele/equality/_egg_equiv/src/lib.rs +267 -0
- kele/equality/_equiv_elem.py +67 -0
- kele/equality/_utils.py +36 -0
- kele/equality/equivalence.py +141 -0
- kele/executer/__init__.py +4 -0
- kele/executer/executing.py +139 -0
- kele/grounder/README.md +83 -0
- kele/grounder/__init__.py +17 -0
- kele/grounder/grounded_rule_ds/__init__.py +6 -0
- kele/grounder/grounded_rule_ds/_nodes/__init__.py +24 -0
- kele/grounder/grounded_rule_ds/_nodes/_assertion.py +353 -0
- kele/grounder/grounded_rule_ds/_nodes/_conn.py +116 -0
- kele/grounder/grounded_rule_ds/_nodes/_op.py +57 -0
- kele/grounder/grounded_rule_ds/_nodes/_root.py +71 -0
- kele/grounder/grounded_rule_ds/_nodes/_rule.py +119 -0
- kele/grounder/grounded_rule_ds/_nodes/_term.py +390 -0
- kele/grounder/grounded_rule_ds/_nodes/_tftable.py +15 -0
- kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +444 -0
- kele/grounder/grounded_rule_ds/_nodes/_typing_polars.py +26 -0
- kele/grounder/grounded_rule_ds/grounded_class.py +461 -0
- kele/grounder/grounded_rule_ds/grounded_ds_utils.py +91 -0
- kele/grounder/grounded_rule_ds/rule_check.py +373 -0
- kele/grounder/grounding.py +118 -0
- kele/knowledge_bases/README.md +112 -0
- kele/knowledge_bases/__init__.py +6 -0
- kele/knowledge_bases/builtin_base/__init__.py +1 -0
- kele/knowledge_bases/builtin_base/builtin_concepts.py +13 -0
- kele/knowledge_bases/builtin_base/builtin_facts.py +43 -0
- kele/knowledge_bases/builtin_base/builtin_operators.py +105 -0
- kele/knowledge_bases/builtin_base/builtin_rules.py +14 -0
- kele/knowledge_bases/fact_base.py +158 -0
- kele/knowledge_bases/ontology_base.py +67 -0
- kele/knowledge_bases/rule_base.py +194 -0
- kele/main.py +464 -0
- kele/py.typed +0 -0
- kele/syntax/CONCEPT_README.md +117 -0
- kele/syntax/__init__.py +40 -0
- kele/syntax/_cnf_converter.py +161 -0
- kele/syntax/_sat_solver.py +116 -0
- kele/syntax/base_classes.py +1482 -0
- kele/syntax/connectives.py +20 -0
- kele/syntax/dnf_converter.py +145 -0
- kele/syntax/external.py +17 -0
- kele/syntax/sub_concept.py +87 -0
- kele/syntax/syntacticsugar.py +201 -0
- kele-0.0.1a1.dist-info/METADATA +165 -0
- kele-0.0.1a1.dist-info/RECORD +73 -0
- kele-0.0.1a1.dist-info/WHEEL +6 -0
- kele-0.0.1a1.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from pyvis.network import Network
|
|
5
|
+
from collections import deque
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING, Literal
|
|
8
|
+
from kele.syntax import Assertion, Formula
|
|
9
|
+
import logging
|
|
10
|
+
from kele.syntax import FACT_TYPE
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from kele.config import RunControlConfig
|
|
14
|
+
from collections.abc import Sequence
|
|
15
|
+
from kele.syntax import Rule
|
|
16
|
+
from kele.equality import Equivalence
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# 单个推理步:记录一个事实由哪条规则得到,以及它与前后事实的连接
|
|
22
|
+
class FactStep:
|
|
23
|
+
"""与上游/下游的事实联系的封装,记录事实来源用的"""
|
|
24
|
+
def __init__(self, content: FACT_TYPE, infer_step: Rule | tuple[Assertion, ...] | None,
|
|
25
|
+
fact_type: Literal['premise', 'equivalence', 'rule_infer']) -> None:
|
|
26
|
+
# 当前仅记录等价类推导的“存在性”,不追溯具体等价链路
|
|
27
|
+
# TODO: 可扩展记录推理深度或来源解释
|
|
28
|
+
self.fact_type: Literal['premise', 'equivalence', 'rule_infer'] = fact_type # 事实的类型
|
|
29
|
+
self.content: FACT_TYPE = content # 实例化后的事实
|
|
30
|
+
self.infer_step: Rule | tuple[Assertion, ...] | None = infer_step # 派生该事实的规则,若由等价关系/同余闭包推导则为tuple,
|
|
31
|
+
# 若为前提事实则为 None
|
|
32
|
+
self._next_facts: list[FactStep] = [] # 由当前事实推演出的下游事实
|
|
33
|
+
self._prev_facts: list[FactStep] = [] # 支撑当前事实的上游事实
|
|
34
|
+
|
|
35
|
+
def add_next(self, fact: FactStep) -> None:
|
|
36
|
+
"""将事实与它帮助推导的下游事实联系起来"""
|
|
37
|
+
self._next_facts.append(fact)
|
|
38
|
+
|
|
39
|
+
def add_prev(self, fact: FactStep) -> None:
|
|
40
|
+
"""将事实连接到支持它的上游事实"""
|
|
41
|
+
self._prev_facts.append(fact)
|
|
42
|
+
|
|
43
|
+
@property
|
|
44
|
+
def next(self) -> tuple[FactStep, ...]:
|
|
45
|
+
"""下游事实"""
|
|
46
|
+
return tuple(self._next_facts)
|
|
47
|
+
|
|
48
|
+
@property
|
|
49
|
+
def prev(self) -> tuple[FactStep, ...]:
|
|
50
|
+
"""上游事实"""
|
|
51
|
+
return tuple(self._prev_facts)
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def step_name(self) -> str:
|
|
55
|
+
"""
|
|
56
|
+
FactStep的名称,用于打印
|
|
57
|
+
"""
|
|
58
|
+
if self.fact_type == 'premise':
|
|
59
|
+
return f"无前提事实:{self.content !s}"
|
|
60
|
+
if self.fact_type == 'equivalence':
|
|
61
|
+
return f"等价推出事实:{self.content !s}"
|
|
62
|
+
return f"规则推导:{self.infer_step !s} 新事实({self.content !s})"
|
|
63
|
+
|
|
64
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
65
|
+
rule_name = getattr(self.infer_step, "name", None)
|
|
66
|
+
return f"FactStep({self.content}, rule={rule_name})"
|
|
67
|
+
|
|
68
|
+
def __hash__(self) -> int:
|
|
69
|
+
return hash((self.content, self.infer_step))
|
|
70
|
+
|
|
71
|
+
def __eq__(self, other: object) -> bool:
|
|
72
|
+
return isinstance(other, FactStep) and self.content == other.content and self.infer_step == other.infer_step
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class InferencePath:
    """
    Record how facts were derived during one inference run.

    Every recorded fact maps to a :class:`FactStep` in ``fact_factstep_pool``.
    A step's ``prev`` links point at the steps of the antecedent facts that
    produced it, and its ``next`` links point at the conclusions it helped
    derive. For an :class:`Assertion` the swapped form (``rhs = lhs``) is
    registered under the same step, so both orientations resolve to one node.
    """

    def __init__(self, args: RunControlConfig, equivalence: Equivalence) -> None:
        """
        :param args: run configuration; ``args.trace`` gates :meth:`get_infer_graph`.
        :param equivalence: handler queried for facts that no recorded rule produced.
        """
        self._args = args

        self.fact_factstep_pool: dict[FACT_TYPE, FactStep] = {}
        self.terminal_step: FactStep | None = None  # final fact, usually the one the question asks about
        self.equivalence: Equivalence = equivalence
        self.initial_facts: set[FACT_TYPE] = set()

        # Node-id bookkeeping for the pyvis graph; reset by gennerate_infer_path_graph.
        self._fact_counter = 1
        self._step_counter = 1
        self.fact_factid_map: dict[FACT_TYPE, str] = {}
        self.step_stepid_map: dict[str, str] = {}

    def _add_initial_facts(self, facts: Sequence[FACT_TYPE] | FACT_TYPE) -> None:
        """
        Register premise facts (a single fact or a sequence of facts).

        Assertions are stored in both orientations so that the swapped assertion
        is also recognised as a premise later on.
        """
        if isinstance(facts, FACT_TYPE):
            facts = [facts]

        for fact in facts:
            if isinstance(fact, Assertion):
                self.initial_facts.add(fact)
                self.initial_facts.add(self._reverse_fact(fact))
            else:
                self.initial_facts.add(fact)

    def _is_validate_none_premise_assertion(self, fact: Assertion) -> bool:
        """
        Return True if ``fact`` legitimately needs no antecedent.

        This holds when any of the following is true:

        1. ``fact`` (or its swapped form) was registered as an initial fact;
        2. both sides are identical (``lhs == rhs``), i.e. it is trivially true;
        3. ``fact.is_action_assertion`` is truthy.
        """
        return fact in self.initial_facts or fact.lhs == fact.rhs or fact.is_action_assertion

    def _query_equiv_step(self, fact: FACT_TYPE) -> FactStep:
        """
        Return the FactStep explaining ``fact``.

        The step comes from one of three sources:

        1. a step already recorded in ``fact_factstep_pool`` (e.g. a rule conclusion);
        2. a new ``'premise'`` step for initial / trivially-true / action assertions,
           and for ``NOT <Assertion>`` formulas (not cached here);
        3. a new ``'equivalence'`` step when the equivalence handler confirms the fact
           (cached for both orientations).

        :param fact: the fact to explain.
        :raises TypeError: if ``fact`` is a Formula other than ``NOT <Assertion>``.
        :raises RuntimeError: if no equivalence handler is set.
        :raises ValueError: if none of the sources above can justify ``fact``.
        :return: the FactStep for ``fact``; never None.
        """
        if fact in self.fact_factstep_pool:
            # Already recorded; reuse the existing step.
            return self.fact_factstep_pool[fact]
        if isinstance(fact, Assertion) and self._is_validate_none_premise_assertion(fact):
            return FactStep(fact, None, 'premise')
        # Only Assertions may come from the equivalence closure. A Formula here can only
        # legitimately be NOT <Assertion>, which must hold as a premise for inference to work.
        if isinstance(fact, Formula) and isinstance(fact.formula_left, Assertion) and fact.connective == 'NOT':
            return FactStep(fact, None, 'premise')
        if isinstance(fact, Formula):
            raise TypeError(
                "Rule premises cannot contain connectives other than AND and NOT. "
                "This error may come from CNF_convert."
            )
        if self.equivalence is None:
            raise RuntimeError(
                "Equivalence handler is not set; cannot properly record inference paths for equivalence facts."
            )
        if self.equivalence.query_equivalence(fact):
            # HACK: the concrete equivalence chain is not recorded yet, only the fact that the
            # equivalence closure implies this fact.
            fact_step = FactStep(fact, None, 'equivalence')
            self.fact_factstep_pool[fact] = fact_step
            self.fact_factstep_pool[self._reverse_fact(fact)] = fact_step  # cache the swapped form too
            return fact_step

        raise ValueError(f"Fact {fact!s} is not true; cannot record inference path.")

    @staticmethod
    def _reverse_fact(fact: FACT_TYPE) -> FACT_TYPE:
        """Return ``rhs = lhs`` for an Assertion; any other fact is returned unchanged."""
        if isinstance(fact, Assertion):
            return Assertion.from_parts(fact.rhs, fact.lhs)
        return fact

    def add_infer_edge(self,
                       consequent: FACT_TYPE,  # FIXME: should be narrowed to Assertion
                       antecedents: list[FACT_TYPE] | None = None,
                       grounded_rule: Rule | None = None,
                       ) -> None:
        """
        Record one inference edge: several antecedents -> one consequent.

        When ``antecedents`` is None, ``consequent`` is instead registered as an
        initial (premise) fact. If ``consequent`` already has a recorded step the
        call is ignored, so only the first derivation of each fact is kept.

        :param consequent: grounded conclusion of the rule.
        :param antecedents: grounded premises of the rule.
        :param grounded_rule: the rule that fired.
        :raises TypeError, RuntimeError, ValueError: propagated from
            :meth:`_query_equiv_step` when an antecedent cannot be explained.
        """
        if antecedents is None:
            self._add_initial_facts(consequent)
            return

        if consequent in self.fact_factstep_pool:
            # Keep a single derivation per fact.
            # TODO: optionally keep several inference paths.
            return

        conse_step = FactStep(consequent, grounded_rule, 'rule_infer')
        self.fact_factstep_pool[consequent] = conse_step
        self.fact_factstep_pool[self._reverse_fact(consequent)] = conse_step  # cache the swapped form too

        for fact in antecedents:
            factstep = self._query_equiv_step(fact)
            self.fact_factstep_pool[fact] = factstep
            self.fact_factstep_pool[self._reverse_fact(fact)] = factstep  # cache the swapped form too

            factstep.add_next(conse_step)
            conse_step.add_prev(factstep)

    def add_terminal_status(self, termnial_fact: FACT_TYPE) -> None:
        """
        Record the terminal fact of the run (usually the question's fact).

        The terminal fact may itself follow from the equivalence closure, so it is
        resolved via :meth:`_query_equiv_step`. If that raises ValueError, a warning
        is issued and ``terminal_step`` is set to None. (The misspelled parameter
        name is kept for backward compatibility with keyword callers.)
        """
        try:
            self.terminal_step = self._query_equiv_step(termnial_fact)
        except ValueError:
            # NOTE(review): ValueError from _query_equiv_step means the fact could not be
            # justified, which the "trivially true" wording may not describe - confirm intent.
            warnings.warn(f"Terminal fact {termnial_fact!s} is trivially true.", stacklevel=1)
            self.terminal_step = None

    @staticmethod
    def _print_log_info(prev_fact_steps: Sequence[FactStep], infer_path: Sequence[FactStep],
                        terminal_fact: FACT_TYPE) -> None:
        """Log the premises used, the ordered derivation steps and the terminal fact."""
        logger.info("================Premise facts:=================")
        for prev_fact_counter, fact_step in enumerate(prev_fact_steps):
            # FIXME: each line is "<number>. <step name>"; the original author found this
            # format odd and deferred it to the upcoming inference rework.
            logger.info("%d. %s", prev_fact_counter + 1, fact_step.step_name)
        logger.info("================Inference path:=================")
        for infer_fact_counter, fact_step in enumerate(infer_path):
            logger.info("step %d: %s", infer_fact_counter + 1, fact_step.step_name)
        logger.info("================Terminal fact:=================")
        logger.info("Terminal fact: %s", terminal_fact)

    @staticmethod
    def _collect_steps(terminal_step: FactStep) -> tuple[list[FactStep], list[FactStep]]:
        """
        Walk upstream from ``terminal_step`` and split the reachable steps.

        Returns ``(premise_steps, derived_steps)``. ``derived_steps`` is produced in
        post-order, so every step comes after all the steps it was derived from and
        the terminal step comes last. A reversed breadth-first order does not give
        this guarantee when one antecedent itself depends on another antecedent.
        Each step is emitted at most once, which also stops the walk on cyclic links.
        """
        def upstream(step: FactStep) -> tuple[FactStep, ...]:
            # Antecedent links are followed only for steps that carry an infer_step.
            return step.prev if step.infer_step is not None else ()

        premise_steps: list[FactStep] = []
        derived_steps: list[FactStep] = []
        seen: set[FactStep] = {terminal_step}
        # Explicit stack of (step, iterator over its parents) instead of recursion, so long
        # derivation chains cannot hit the interpreter's recursion limit.
        stack = [(terminal_step, iter(upstream(terminal_step)))]
        while stack:
            step, parents = stack[-1]
            for parent in parents:
                if parent not in seen:
                    seen.add(parent)
                    stack.append((parent, iter(upstream(parent))))
                    break
            else:
                # Every parent has been emitted: emit this step after its dependencies.
                stack.pop()
                if step.fact_type == 'premise':
                    premise_steps.append(step)
                else:
                    derived_steps.append(step)
        return premise_steps, derived_steps

    def get_infer_graph(self, terminal_fact: FACT_TYPE | None = None) -> tuple[list[FactStep], FACT_TYPE | None]:
        """
        Collect (and log) the inference path leading to a terminal fact.

        :param terminal_fact: fact to explain; defaults to the step recorded by
            :meth:`add_terminal_status` (usually the question's fact).
        :return: ``(steps, terminal)``. ``steps`` lists the premise steps that were
            actually used first, then the derived steps in dependency order (each step
            after the steps it was derived from). ``terminal`` is the terminal fact.
            Returns ``([], None)`` when tracing is disabled or there is no terminal step.
        """
        if not self._args.trace:
            warnings.warn("Inference path tracing is disabled; cannot print inference path.", stacklevel=5)
            return [], None
        terminal_step = self.terminal_step if terminal_fact is None else self._query_equiv_step(terminal_fact)

        if terminal_step is None:
            warnings.warn("Inference engine could not derive a result, or the terminal fact is trivially true.",
                          stacklevel=1)
            return [], None

        prev_fact_steps, infer_path = self._collect_steps(terminal_step)
        self._print_log_info(prev_fact_steps, infer_path, terminal_step.content)
        return prev_fact_steps + infer_path, terminal_step.content

    def _get_fact_id(self, fact: FACT_TYPE, net: Network) -> str:
        """Return the graph node id of ``fact``, adding a ``factN`` node (title = fact text) on first use."""
        if fact not in self.fact_factid_map:
            self.fact_factid_map[fact] = f"fact{self._fact_counter}"
            self._fact_counter += 1
            net.add_node(self.fact_factid_map[fact], label=self.fact_factid_map[fact], title=str(fact))
        return self.fact_factid_map[fact]

    def _get_step_id(self, step_name: str, net: Network) -> str:
        """Return the graph node id of a step, adding a red square ``stepN`` node on first use."""
        if step_name not in self.step_stepid_map:
            self.step_stepid_map[step_name] = f"step{self._step_counter}"
            self._step_counter += 1
            net.add_node(self.step_stepid_map[step_name], label=self.step_stepid_map[step_name],
                         title=str(step_name), shape="square", color="red")
        return self.step_stepid_map[step_name]

    def gennerate_infer_path_graph(self, infer_path: list[FactStep], terminal_fact: FACT_TYPE | None = None,
                                   output_path: str = "infer_path.html") -> None:
        """
        Render an inference path as a pyvis HTML graph.

        Facts become ``factN`` nodes. Each rule application becomes a red square
        ``stepN`` node, with blue "前提" edges from its premises and one red "结论"
        edge to its conclusion. Facts whose step has no ``infer_step`` are restyled
        as green triangles.

        :param infer_path: steps to draw, e.g. the first element returned by :meth:`get_infer_graph`.
        :param terminal_fact: optional terminal fact, drawn as a red star node.
        :param output_path: file the HTML graph is written to.
        """
        net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")
        # Node ids are numbered per graph, so reset the id bookkeeping first.
        self._fact_counter = 1
        self._step_counter = 1
        self.fact_factid_map.clear()
        self.step_stepid_map.clear()
        if terminal_fact is not None:
            self.fact_factid_map[terminal_fact] = "terminal_fact"
            net.add_node("terminal_fact", label="终点事实", shape="star", color="red")

        for fact_step in infer_path:
            cur_fact_id = self._get_fact_id(fact_step.content, net)
            if fact_step.infer_step is None:
                # No rule behind this fact (premise/equivalence step): restyle its node.
                for node in net.nodes:
                    if node["id"] == cur_fact_id:
                        node["label"] = "无前提事实"
                        node["shape"] = "triangle"
                        node["color"] = "green"
                continue

            cur_step_id = self._get_step_id(fact_step.step_name, net)
            for fact in fact_step.prev:
                prev_fact_id = self._get_fact_id(fact.content, net)
                net.add_edge(prev_fact_id, cur_step_id, label="前提", color="blue")
            # One conclusion edge per step, drawn even when the step has no premises.
            net.add_edge(cur_step_id, cur_fact_id, label="结论", color="red", arrows="to")
        net.save_graph(output_path)

    def reset(self) -> None:
        """Forget all recorded steps, the terminal step and the initial facts."""
        self.fact_factstep_pool.clear()
        self.terminal_step = None
        self.initial_facts.clear()
|
kele/control/metrics.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
# metrics_typed.py
|
|
2
|
+
# pip install prometheus-client psutil
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
import time
|
|
9
|
+
import uuid
|
|
10
|
+
import warnings
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from datetime import datetime, UTC
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any, ParamSpec, TypeVar, TYPE_CHECKING, Self
|
|
15
|
+
|
|
16
|
+
import psutil
|
|
17
|
+
from prometheus_client import (
|
|
18
|
+
CollectorRegistry,
|
|
19
|
+
Counter,
|
|
20
|
+
Gauge,
|
|
21
|
+
Histogram,
|
|
22
|
+
push_to_gateway,
|
|
23
|
+
start_http_server,
|
|
24
|
+
)
|
|
25
|
+
import logging
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from collections.abc import Callable
|
|
29
|
+
from collections.abc import Mapping
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
# Main metrics recorded by this module: cpu_percent: float; rss_mib: float; count: int;
# module: str; phase: str; duration_seconds: float


# Type variables used by the ``measure`` decorator to preserve the wrapped function's
# parameters (P) and return type (R).
P = ParamSpec("P")
R = TypeVar("R")

# Recursive alias for any JSON-serialisable value, and the JSON object shape used for
# recorded events.
type JSONValue = bool | int | float | str | list[JSONValue] | dict[str, JSONValue] | None
type JSONObject = dict[str, JSONValue]

# Public API of this module (kept as a literal list so static analysers can read it).
__all__ = [
    "PhaseTimer",
    "RunRecorder",
    "end_run",
    "inc_iter",
    "init_metrics",
    "maybe_push",
    "measure",
    "observe_counts",
    "sample_process_gauges",
    "start_run",
]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _now_iso() -> str:
|
|
57
|
+
"""此刻时间"""
|
|
58
|
+
return datetime.now(UTC).astimezone().isoformat(timespec="seconds")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _bytes_to_mib(n: int) -> float:
    """Convert a byte count to mebibytes (1 MiB = 2**20 bytes)."""
    bytes_per_mib = 1 << 20
    return n / bytes_per_mib
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class RunRecorder:
    """
    Write the key events and resource usage of one inference run to JSON
    (``<log_dir>/<run_id>.json``; ``log_dir`` defaults to ``metrics_logs``).

    - ``event`` appends an arbitrary event stamped with ``timestamp`` and ``event`` keys.
    - ``observe_cpu_mem`` records a CPU / memory sample (``cpu_percent``, ``rss_mib``).
    - ``add_phase_duration`` / ``add_func_duration`` accumulate per-(module, phase/name) totals.
    - ``end`` finalises the run (``started_at`` / ``ended_at``, peaks, means, totals) and writes the file.
    """

    def __init__(self, log_dir: str = "metrics_logs", run_id: str | None = None) -> None:
        """
        Create the log directory if needed and stamp the run's start time.

        :param log_dir: directory for the JSON file; created with parents if missing.
        :param run_id: explicit run id; when falsy, ``YYYYmmdd-HHMMSS-<6 hex chars>`` is generated.
        """
        Path(log_dir).mkdir(parents=True, exist_ok=True)
        self.run_id: str = run_id or (time.strftime("%Y%m%d-%H%M%S-") + uuid.uuid1().hex[:6])
        self.log_dir: str = log_dir
        self.meta: dict[str, Any] = {
            "run_id": self.run_id,
            "started_at": _now_iso(),  # ISO 8601
        }
        self.events: list[JSONObject] = []
        # Every CPU / RSS sample (not only peaks); summarised in ``end``.
        self._cpu_percent_peaks: list[float] = []
        self._rss_mib_peaks: list[float] = []

        # Accumulated seconds per (module, phase).
        self._phase_totals: dict[tuple[str, str], float] = {}

        # Accumulated seconds per (module, function name).
        self._func_totals: dict[tuple[str, str], float] = {}

    def event(self, kind: str, /, **kwargs: Any) -> None:  # noqa: ANN401
        """Append an event ``{"timestamp": <now>, "event": kind, **kwargs}``."""
        self.events.append({"timestamp": _now_iso(), "event": kind, **kwargs})

    def observe_cpu_mem(self, cpu_pct: float, rss_bytes: int) -> None:
        """Record one CPU (%) / RSS (bytes) sample; RSS is stored in MiB, also as an event."""
        cpu_percent = cpu_pct
        rss_mib = _bytes_to_mib(rss_bytes)
        self._cpu_percent_peaks.append(cpu_percent)
        self._rss_mib_peaks.append(rss_mib)
        self.event("process_sample", cpu_percent=cpu_percent, rss_mib=rss_mib)

    def end(self, extra_meta: dict[str, Any] | None = None) -> str:
        """
        Finish the run: fill in the summary fields of ``meta`` and write the JSON file.

        :param extra_meta: extra keys merged into ``meta`` after ``ended_at`` is set and
            before the CPU/RSS/duration summary fields are computed.
        :return: path of the written JSON file.
        """
        self.meta["ended_at"] = _now_iso()
        if extra_meta:
            self.meta.update(extra_meta)
        if self._cpu_percent_peaks:
            self.meta["cpu_percent_max"] = max(self._cpu_percent_peaks)
            self.meta["cpu_percent_mean"] = sum(self._cpu_percent_peaks) / len(self._cpu_percent_peaks)
        if self._rss_mib_peaks:
            self.meta["rss_max_mib"] = max(self._rss_mib_peaks)

        if self._phase_totals:
            self.meta["phase_durations_seconds_total"] = [
                {"module": m, "phase": p, "duration_seconds_total": t}
                for (m, p), t in sorted(self._phase_totals.items())
            ]
            self.meta["all_phases_duration_seconds_total"] = sum(self._phase_totals.values())
        else:
            self.meta["all_phases_duration_seconds_total"] = "no running time"

        if self._func_totals:
            self.meta["function_durations_seconds_total"] = [
                {"module": m, "name": n, "duration_seconds_total": t}
                for (m, n), t in sorted(self._func_totals.items())
            ]
            self.meta["all_functions_duration_seconds_total"] = sum(self._func_totals.values())

        path = str(Path(self.log_dir) / f"{self.run_id}.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"meta": self.meta, "events": self.events}, f, ensure_ascii=False, indent=4)

        # ``result`` is not a standard logging.Logger method (the original call needed an
        # attr-defined ignore); presumably a custom level installed elsewhere in the package.
        # Fall back to ``info`` so a missing level cannot raise AttributeError after the
        # JSON file has already been written.
        log_result = getattr(logger, "result", logger.info)
        log_result("Elapsed time: %ss", self.meta['all_phases_duration_seconds_total'])

        return path

    def add_phase_duration(self, module: str, phase: str, seconds: float) -> None:
        """Add ``seconds`` to the running total of ``(module, phase)``."""
        key = (module, phase)
        self._phase_totals[key] = self._phase_totals.get(key, 0.0) + seconds

    def add_func_duration(self, module: str, name: str, seconds: float) -> None:
        """Add ``seconds`` to the running total of ``(module, name)`` (used in the JSON summary)."""
        key = (module, name)
        self._func_totals[key] = self._func_totals.get(key, 0.0) + seconds
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
# --------- All module-level "globals" live on one State object, so helpers mutate its
# attributes instead of rebinding globals (avoids ruff PLW0603) ---------
def _current_process() -> psutil.Process:
    """Return a psutil handle on the running process (default factory for ``_State.proc``)."""
    return psutil.Process(os.getpid())


@dataclass
class _State:
    """Mutable module-wide state shared by the metric helpers of this module."""

    # Prometheus plumbing: registry plus optional Pushgateway target, job and grouping key.
    registry: CollectorRegistry | None = None
    pushgateway: str | None = None
    job: str = "al_inference"
    grouping: dict[str, str] = field(default_factory=dict)
    # Handle on this process, sampled for RSS / CPU usage.
    proc: psutil.Process = field(default_factory=_current_process)

    # Metric objects; None until init_metrics() creates them.
    h_func_lat: Histogram | None = None
    h_phase_lat: Histogram | None = None
    g_rss: Gauge | None = None
    g_cpu_pct: Gauge | None = None
    c_iter: Counter | None = None
    h_grounded_rules: Histogram | None = None
    h_facts_count: Histogram | None = None

    # JSON run recorder; None until start_run() installs one.
    run: RunRecorder | None = None


STATE = _State()
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _new_hist(
    name: str,
    help_: str,
    buckets: tuple[float, ...] | None = None,
    labels: tuple[str, ...] = (),
) -> Histogram:
    """
    Build a Histogram attached to ``STATE.registry`` (call init_metrics first).

    :param name: metric name.
    :param help_: metric help text.
    :param buckets: bucket upper bounds; a latency-oriented default from 5 ms up to
        10 s (plus +inf) is used when None.
    :param labels: label names of the histogram.
    """
    default_buckets = (0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10, float("inf"))
    chosen = default_buckets if buckets is None else buckets
    return Histogram(name, help_, labels, buckets=chosen, registry=STATE.registry)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# ---- Public entry points ----
def start_run(log_dir: str = "metrics_logs", run_id: str | None = None) -> str:
    """Call before a full inference: install a fresh RunRecorder and return its run_id."""
    recorder = RunRecorder(log_dir=log_dir, run_id=run_id)
    STATE.run = recorder
    return recorder.run_id
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def end_run(extra_meta: Mapping[str, JSONValue] | None = None) -> str | None:
    """
    Call after a full inference: finalise the current run and write its JSON file.

    :param extra_meta: extra metadata merged into the run's ``meta`` section.
    :return: path of the JSON file, or None if start_run() was never called.
    """
    if not STATE.run:
        return None
    meta = None if extra_meta is None else dict(extra_meta)
    return STATE.run.end(meta)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def init_metrics(
    port: int | None = None,
    pushgateway: str | None = None,
    job: str = "al_inference",
    grouping: Mapping[str, str] | None = None,
) -> None:
    """
    Create a fresh Prometheus registry and all metrics used by this module.

    - Batch jobs: leave ``port`` unset and push through a Pushgateway (see maybe_push).
    - Local development: pass ``port`` (e.g. 8000) to expose an endpoint Prometheus can scrape.

    Both are conveniences only; the engine mainly relies on the JSON RunRecorder.

    :param port: if truthy, start_http_server(port) is called with the new registry.
    :param pushgateway: Pushgateway address used by maybe_push.
    :param job: Pushgateway job name.
    :param grouping: Pushgateway grouping key (copied).
    """
    registry = CollectorRegistry()
    STATE.registry = registry  # must be set before _new_hist, which reads STATE.registry
    STATE.pushgateway = pushgateway
    STATE.job = job
    STATE.grouping = dict(grouping) if grouping else {}

    STATE.h_func_lat = _new_hist("func_latency_seconds", "Function/phase latency", labels=("module", "name"))
    STATE.h_phase_lat = _new_hist("phase_latency_seconds", "Inference phase latency", labels=("module", "phase"))
    STATE.g_rss = Gauge("process_rss_bytes", "Process RSS bytes", registry=registry)
    STATE.g_cpu_pct = Gauge("process_cpu_percent", "Process CPU percent", registry=registry)
    STATE.c_iter = Counter("inference_iterations_total", "Total inference iterations", ["module"], registry=registry)
    STATE.h_grounded_rules = _new_hist("grounded_rules_count", "Grounded rules per iteration")
    STATE.h_facts_count = _new_hist("facts_count_snapshot", "Facts count snapshot")

    # tracemalloc.start() was removed because of its high time overhead; reconsider it
    # (behind a flag, or via another tool) if memory profiling becomes important.
    if port:
        start_http_server(port, registry=registry)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def maybe_push() -> None:
    """
    Push the registry's metrics to the Pushgateway if one was configured via init_metrics.

    Not used by the engine itself (JSON recording is preferred) but kept for batch jobs.
    """
    if not (STATE.pushgateway and STATE.registry):
        return
    push_to_gateway(
        STATE.pushgateway,
        job=STATE.job,
        registry=STATE.registry,
        grouping_key=STATE.grouping,
    )
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def sample_process_gauges() -> None:
    """
    Sample the process RSS and CPU usage once, update the gauges and log the sample to
    the current run (if any).

    Requires init_metrics() to have created the gauges; otherwise warns and does nothing.
    """
    rss_gauge = STATE.g_rss
    cpu_gauge = STATE.g_cpu_pct
    if rss_gauge is None or cpu_gauge is None:
        warnings.warn("Gauges not initialized, skipping sample_process_gauges", stacklevel=2)
        return

    rss: int = STATE.proc.memory_info().rss
    rss_gauge.set(rss)
    # NOTE(review): non-blocking cpu_percent compares against the previous call; the very
    # first sample may therefore not be meaningful - confirm against the psutil docs.
    cpu: float = STATE.proc.cpu_percent(interval=None)
    cpu_gauge.set(cpu)

    recorder = STATE.run
    if recorder:
        recorder.observe_cpu_mem(cpu_pct=cpu, rss_bytes=rss)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def measure(name: str, module: str | None = None, *, skip_process_gauges: bool = True,
            skip_envent_record: bool = True) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """
    Decorator factory that times every call of the wrapped function.

    After each call (including one that raises) the elapsed wall time is observed on
    the ``func_latency_seconds`` histogram when init_metrics() has run, and added to
    the current run's per-function totals when start_run() has run.

    :param name: label identifying the measured function.
    :param module: module label; defaults to this metrics module's ``__name__``
        (not the decorated function's module).
    :param skip_process_gauges: when False, call sample_process_gauges() after each call.
    :param skip_envent_record: when False, also append a ``func_timing`` event to the run.

    Examples
    --------
    >>> @measure("step", module="pipeline")
    ... def work(x: int) -> int:
    ...     return x * 2
    """
    label_module = module or __name__

    def decorate(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def timed(*args: P.args, **kwargs: P.kwargs) -> R:
            started = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed = time.perf_counter() - started

                histogram = STATE.h_func_lat
                if histogram is not None:
                    histogram.labels(module=label_module, name=name).observe(elapsed)
                if not skip_process_gauges:
                    sample_process_gauges()
                recorder = STATE.run
                if recorder:
                    recorder.add_func_duration(label_module, name, elapsed)
                    if not skip_envent_record:
                        recorder.event("func_timing", module=label_module, name=name, duration_seconds=elapsed)

        return timed

    return decorate
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class PhaseTimer:
|
|
301
|
+
"""
|
|
302
|
+
上下文管理器:用于手动分段计时并采样进程指标。
|
|
303
|
+
|
|
304
|
+
使用示例::
|
|
305
|
+
|
|
306
|
+
with PhaseTimer("retrieve", module="pipeline"):
|
|
307
|
+
do_retrieve()
|
|
308
|
+
"""
|
|
309
|
+
|
|
310
|
+
def __init__(self, phase: str, module: str | None = None, *, skip_process_gauges: bool = True,
|
|
311
|
+
skip_envent_record: bool = True, skip_count_record: bool = True) -> None:
|
|
312
|
+
self.phase: str = phase
|
|
313
|
+
self.module: str = module or __name__
|
|
314
|
+
self.t0: float | None = None
|
|
315
|
+
self.skip_process_gauges: bool = skip_process_gauges
|
|
316
|
+
self.skip_envent_record: bool = skip_envent_record
|
|
317
|
+
|
|
318
|
+
def __enter__(self) -> Self:
|
|
319
|
+
self.t0 = time.perf_counter()
|
|
320
|
+
return self
|
|
321
|
+
|
|
322
|
+
def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[no-untyped-def] # noqa: ANN001
|
|
323
|
+
if self.t0 is None:
|
|
324
|
+
return
|
|
325
|
+
dt = time.perf_counter() - self.t0
|
|
326
|
+
|
|
327
|
+
if STATE.h_phase_lat is not None:
|
|
328
|
+
STATE.h_phase_lat.labels(module=self.module, phase=self.phase).observe(dt)
|
|
329
|
+
if not self.skip_process_gauges:
|
|
330
|
+
sample_process_gauges()
|
|
331
|
+
if STATE.run:
|
|
332
|
+
STATE.run.add_phase_duration(self.module, self.phase, dt)
|
|
333
|
+
if not self.skip_envent_record:
|
|
334
|
+
STATE.run.event("phase_timing", module=self.module, phase=self.phase, duration_seconds=dt)
|
|
335
|
+
return
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def observe_counts(grounded_rules: int | None = None, facts_count: int | None = None) -> None:
    """
    Record discrete count metrics, e.g. grounded rules per grounding or a fact-base snapshot size.

    Each count is observed on its histogram and, if a run is active, logged as an event
    (``grounded_rules`` / ``facts_count``). A count is ignored when it is None or its
    histogram has not been initialised by init_metrics().
    """
    pending = (
        (grounded_rules, STATE.h_grounded_rules, "grounded_rules"),
        (facts_count, STATE.h_facts_count, "facts_count"),
    )
    for value, histogram, event_name in pending:
        if value is None or histogram is None:
            continue
        histogram.observe(float(value))
        if STATE.run:
            STATE.run.event(event_name, count=value)
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def inc_iter(module: str) -> None:
    """Increment the inference-iteration counter for ``module``; no-op before init_metrics()."""
    counter = STATE.c_iter
    if counter is not None:
        counter.labels(module=module).inc()
|