kele-0.0.1a1-cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kele/__init__.py +38 -0
- kele/_version.py +1 -0
- kele/config.py +243 -0
- kele/control/README_metrics.md +102 -0
- kele/control/__init__.py +20 -0
- kele/control/callback.py +255 -0
- kele/control/grounding_selector/__init__.py +5 -0
- kele/control/grounding_selector/_rule_strategies/README.md +13 -0
- kele/control/grounding_selector/_rule_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py +42 -0
- kele/control/grounding_selector/_rule_strategies/strategy_protocol.py +51 -0
- kele/control/grounding_selector/_selector_utils.py +123 -0
- kele/control/grounding_selector/_term_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py +34 -0
- kele/control/grounding_selector/_term_strategies/strategy_protocol.py +50 -0
- kele/control/grounding_selector/rule_selector.py +98 -0
- kele/control/grounding_selector/term_selector.py +89 -0
- kele/control/infer_path.py +306 -0
- kele/control/metrics.py +357 -0
- kele/control/status.py +286 -0
- kele/egg_equiv.pyi +11 -0
- kele/egg_equiv.so +0 -0
- kele/equality/README.md +8 -0
- kele/equality/__init__.py +4 -0
- kele/equality/_egg_equiv/src/lib.rs +267 -0
- kele/equality/_equiv_elem.py +67 -0
- kele/equality/_utils.py +36 -0
- kele/equality/equivalence.py +141 -0
- kele/executer/__init__.py +4 -0
- kele/executer/executing.py +139 -0
- kele/grounder/README.md +83 -0
- kele/grounder/__init__.py +17 -0
- kele/grounder/grounded_rule_ds/__init__.py +6 -0
- kele/grounder/grounded_rule_ds/_nodes/__init__.py +24 -0
- kele/grounder/grounded_rule_ds/_nodes/_assertion.py +353 -0
- kele/grounder/grounded_rule_ds/_nodes/_conn.py +116 -0
- kele/grounder/grounded_rule_ds/_nodes/_op.py +57 -0
- kele/grounder/grounded_rule_ds/_nodes/_root.py +71 -0
- kele/grounder/grounded_rule_ds/_nodes/_rule.py +119 -0
- kele/grounder/grounded_rule_ds/_nodes/_term.py +390 -0
- kele/grounder/grounded_rule_ds/_nodes/_tftable.py +15 -0
- kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +444 -0
- kele/grounder/grounded_rule_ds/_nodes/_typing_polars.py +26 -0
- kele/grounder/grounded_rule_ds/grounded_class.py +461 -0
- kele/grounder/grounded_rule_ds/grounded_ds_utils.py +91 -0
- kele/grounder/grounded_rule_ds/rule_check.py +373 -0
- kele/grounder/grounding.py +118 -0
- kele/knowledge_bases/README.md +112 -0
- kele/knowledge_bases/__init__.py +6 -0
- kele/knowledge_bases/builtin_base/__init__.py +1 -0
- kele/knowledge_bases/builtin_base/builtin_concepts.py +13 -0
- kele/knowledge_bases/builtin_base/builtin_facts.py +43 -0
- kele/knowledge_bases/builtin_base/builtin_operators.py +105 -0
- kele/knowledge_bases/builtin_base/builtin_rules.py +14 -0
- kele/knowledge_bases/fact_base.py +158 -0
- kele/knowledge_bases/ontology_base.py +67 -0
- kele/knowledge_bases/rule_base.py +194 -0
- kele/main.py +464 -0
- kele/py.typed +0 -0
- kele/syntax/CONCEPT_README.md +117 -0
- kele/syntax/__init__.py +40 -0
- kele/syntax/_cnf_converter.py +161 -0
- kele/syntax/_sat_solver.py +116 -0
- kele/syntax/base_classes.py +1482 -0
- kele/syntax/connectives.py +20 -0
- kele/syntax/dnf_converter.py +145 -0
- kele/syntax/external.py +17 -0
- kele/syntax/sub_concept.py +87 -0
- kele/syntax/syntacticsugar.py +201 -0
- kele-0.0.1a1.dist-info/METADATA +165 -0
- kele-0.0.1a1.dist-info/RECORD +73 -0
- kele-0.0.1a1.dist-info/WHEEL +6 -0
- kele-0.0.1a1.dist-info/licenses/LICENSE +28 -0
kele/grounder/grounded_rule_ds/grounded_class.py
@@ -0,0 +1,461 @@
from __future__ import annotations

from functools import partial
import logging
from typing import TYPE_CHECKING
import warnings


from .rule_check import RuleCheckGraph
from kele.grounder.grounded_rule_ds.grounded_ds_utils import unify_all_terms
from ._nodes import _FlatCompoundTermNode, _OperatorNode, _RootNode, _AssertionNode, _ConnectiveNode, _RuleNode, _QuestionRuleNode
from ._nodes._tupletable import _TupleTable
from collections import deque
from kele.syntax import Formula, Assertion

if TYPE_CHECKING:
    from kele.config import Config
    from kele.control import InferencePath
    from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION
    from kele.syntax import Constant, Rule, FACT_TYPE, SankuManagementSystem, CompoundTerm
    from kele.syntax.base_classes import _QuestionRule
    from collections.abc import Sequence, Mapping
    from kele.equality import Equivalence
    from kele.syntax import Variable

logger = logging.getLogger(__name__)


class GroundedRule:
    """
    Manages the grounding state of a single rule: term-level unification, merging of
    variable-candidate tables, and execution of the check phase.

    GroundedRule does not store fully expanded grounded rules; it keeps the variable-candidate
    tables and execution state, and only expands them on demand during the check phase to
    produce new facts.
    """

    def __init__(self, rule: Rule, equivalence: Equivalence, sk_system_handler: SankuManagementSystem,
                 args: Config, inference_path: InferencePath) -> None:
        if rule.unsafe_variables:
            raise TypeError(f"""Rule {rule!s} is unsafe because it contains unsafe variables {[str(u) for u in rule.unsafe_variables]}.\n
            This error likely appears because the rule was not added to RuleBase or did not go through preprocessing.
            """)
        if not self._is_conjunctive_body(rule):
            warnings.warn(f"""
            Rule {rule!s} body must be a conjunction of positive and negative assertions; this rule does not meet the requirement.\n
            This warning likely appears because the rule was not added to RuleBase or did not go through preprocessing.\n
            For more information, see: #TODO
            """, stacklevel=5)  # TODO: add engine tutorial URL
        if not self._is_conjunctive_head(rule):
            warnings.warn(
                f"Rule {rule!s} head contains connectives other than AND, which may prevent correct fact generation.",
                stacklevel=5,
            )

        self.args = args
        self.rule = rule
        self.rule_checker = RuleCheckGraph(self, self.args)  # risk: facts containing forall are hard to handle; in practice, facts must be instantiated before inference.
        self.is_concept_compatible_binding = partial(self.rule.is_concept_compatible_binding,
                                                     fuzzy_match=self.args.grounder.conceptual_fuzzy_unification)

        self.inference_path: InferencePath = inference_path

        self.equivalence = equivalence
        self.sk_system_handler = sk_system_handler

        self.all_freevar_table: list[_TupleTable] = []
        self._action_new_fact_list: list[FACT_TYPE] = []
        self._past_all_freevar_table: list[_TupleTable] = []
        self._past_df_prefix_sum: list[_TupleTable] = []  # prefix sums waste quite a bit of space: the second df contains the first, and so on

        columns = set()
        # The engine feeds only some nodes into the grounding process; the rest are used for substitution only. Action-op nodes among the
        # substitution nodes also affect grounded-rule generation, but they do not increase the number of grounded rules produced by the
        # semi-naive strategy and need no unification, so no past_df is recorded for them.
        for assertion_node in self.rule_checker.grounding_nodes:
            columns |= assertion_node.grounding_arguments
        self._past_df_prefix_sum.append(_TupleTable(column_name=tuple(columns)))

        self.total_table: _TupleTable
        self._back_up_total_table: _TupleTable
        self._back_up_true_table: _TupleTable

    def unify(self, terms: list[CompoundTerm[Constant | CompoundTerm] | Constant]) -> None:
        """
        Unify the given terms, producing grounding candidates only; the facts themselves are not checked for correctness.

        - FlatCompoundTermNode first replaces constants with their equivalence-class representatives;
        - unification happens at term level only; joins and checks are performed later by AssertionNode.

        :param terms: list of fact terms used for grounding
        """
        node_queue: deque[_RootNode | _OperatorNode | _FlatCompoundTermNode] = deque()
        root: _RootNode = self.rule_checker.graph_root
        node_queue.append(root)

        while node_queue:
            cur_node = node_queue.popleft()
            if isinstance(cur_node, _FlatCompoundTermNode):
                cur_node.process_equiv_represent_elem()
            elif isinstance(cur_node, (_OperatorNode, _RootNode)):
                node_queue.extend(cur_node.query_for_children())

        for t in terms:
            self._unify_single(t)
        self._start_passing_process()

    def check_grounding(self) -> list[FACT_TYPE]:
        """
        Run the check phase and return the new facts.

        AssertionNode performs the computation for action assertions and produces extra facts; RuleNode aggregates the results.

        :returns list[FACT_TYPE]: the new facts obtained by the check
        """
        execute_queue: deque[_AssertionNode | _ConnectiveNode | _RuleNode] = deque(self.rule_checker.execute_nodes)
        new_facts: list[FACT_TYPE] = []

        self._start_join_process()

        # Because AssertionNodes are added directly when execute_queue is created, they are guaranteed to run first. Whether a
        # ConnectiveNode may run is controlled by its ready_to_execute attribute; being runnable requires that all of its parents
        # have passed their TfIndexs to it.
        while execute_queue:
            cur_node = execute_queue.popleft()
            if isinstance(cur_node, (_ConnectiveNode, _AssertionNode)) and cur_node.ready_to_execute:
                cur_node.exec_check()
                cur_node.pass_tf_index()
                execute_queue.extend(cur_node.query_for_children())
            elif isinstance(cur_node, _RuleNode):
                # For a single rule there is normally only one RuleNode, so plain assignment suffices  # FIXME: there may be more later; keep as is for now
                new_facts = cur_node.exec_check()
        if self._action_new_fact_list:
            new_facts.extend(set(self._action_new_fact_list))
            self._action_new_fact_list.clear()
        return new_facts

    def _unify_single(self, term: CompoundTerm[Constant | CompoundTerm] | Constant) -> None:
        node_queue: deque[_RootNode | _OperatorNode | _FlatCompoundTermNode] = deque()
        root: _RootNode = self.rule_checker.graph_root
        node_queue.append(root)

        while node_queue:
            cur_node = node_queue.popleft()
            if isinstance(cur_node, _FlatCompoundTermNode) and not cur_node.only_substitution:  # FIXME: carrying a term down the graph seems to repeat this check many times
                # the only_substitution check itself may also need optimisation
                cur_node.exec_unify(term, allow_unify_with_nested_term=self.args.grounder.allow_unify_with_nested_term)
                # perform the unification on the _FlatCompoundTermNode
            elif isinstance(cur_node, (_OperatorNode, _RootNode)):
                node_queue.extend(cur_node.query_for_children(term))

    def receive_true_table(self, true_table: _TupleTable) -> None:
        """
        Receive the true indexes passed in from an AssertionNode.
        """
        if not hasattr(self, '_back_up_true_table') or self._back_up_true_table.height == 0:
            self._back_up_true_table = true_table.copy()
        else:
            self._back_up_true_table = self._back_up_true_table.concat_table(true_table)

    def _start_passing_process(self) -> None:
        """
        Start the passing process: propagate the freevar_table down to the _AssertionNode.
        """
        node_queue: list[_RootNode | _OperatorNode | _FlatCompoundTermNode | _AssertionNode] = []
        root: _RootNode = self.rule_checker.graph_root
        node_queue.append(root)

        while node_queue:
            cur_node = node_queue.pop()
            if isinstance(cur_node, _FlatCompoundTermNode):
                cur_node.pass_freevar_to_child()
                node_queue.extend(cur_node.query_for_child())
            elif isinstance(cur_node, (_OperatorNode, _RootNode)):
                node_queue.extend(cur_node.query_for_children())

    def _start_join_process(self) -> None:
        """
        Merge the tables of all AssertionNodes into the rule-level total table and broadcast it to the execute nodes.
        """
        for assertion_node in self.rule_checker.grounding_nodes:
            self._past_all_freevar_table.append(assertion_node.past_freevar_table)
            self.all_freevar_table.append(assertion_node.exec_join())

        if self.rule_checker.grounding_nodes:
            total_table = self._calc_total_table()
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "Grounded rule total table before anti-join: rule=%s summary=%s",
                    self.rule,
                    total_table.debug_summary(),
                )

            if self.args.executor.anti_join_used_facts and hasattr(self, "_back_up_true_table"):
                # When anti_join_used_facts is enabled in the config and a backup already exists, perform the anti-join.
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(
                        "Grounded rule anti-join: rule=%s base=%s anti=%s",
                        self.rule,
                        total_table.debug_summary(),
                        self._back_up_true_table.debug_summary(),
                    )
                total_table = total_table.anti_join(self._back_up_true_table)
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(
                        "Grounded rule total table after anti-join: rule=%s summary=%s",
                        self.rule,
                        total_table.debug_summary(),
                    )

            self.total_table = total_table
        else:  # when the rule premise contains no free vars, the total table has no columns.  TODO: consider handling such special rules separately, outside the unify flow
            self.total_table = _TupleTable(column_name=())

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Grounded rule final total table broadcast: rule=%s summary=%s",
                self.rule,
                self.total_table.debug_summary(),
            )

        self._broadcast_total_table(self.total_table)
        self._back_up_total_table = self.__dict__.pop('total_table')

    def _calc_total_table(self) -> _TupleTable:
        """Align equivalence-class representatives and run the semi-naive merge."""
        # Align all past tables.
        self._past_all_freevar_table = [table.update_equiv_element(self.equivalence) for table in self._past_all_freevar_table]
        self._past_df_prefix_sum = [table.update_equiv_element(self.equivalence) for table in self._past_df_prefix_sum]
        # Align the current tables.
        self.all_freevar_table = [table.update_equiv_element(self.equivalence) for table in self.all_freevar_table]
        # FIXME: new_table is effectively aligned twice: once during unify and once more here. Consider removing one of the alignments.
        total_table, mid_table = self._semi_naive()  # FIXME: split variable-binding consistency checking, exec_check and the final semi step; for memory, switch the first to Yannakakis
        # HACK: mid_table is recorded during the semi-naive pass for action_node.exec_action; it should be removed once the flow changes.
        # HACK: mid_table itself is a hack. The correct algorithm distinguishes action/non-action assertions as well as NOT, written
        # orthogonally to the "semi join -- check -- semi join -- semi naive" flow. To be fixed in the next version.
        for action_node in self.rule_checker.action_nodes:
            self._action_new_fact_list.extend(action_node.exec_action(mid_table))
        return total_table

    def _semi_naive(self) -> tuple[_TupleTable, _TupleTable]:
        """TODO: the current version expands (A1+A1')(A2+A2')...; the other designed formulas are not implemented yet."""
        t1_last = self._past_df_prefix_sum[0]  # the accumulated "last" of the first j-1 items; here it is item 0
        t1_new = self.all_freevar_table[0]  # the accumulated "new" of the first j-1 items; here it is item 0
        mid_table = _TupleTable(())  # HACK: mid_table defaults to an empty table

        self._past_df_prefix_sum[0] = self._past_df_prefix_sum[0].concat_table(t1_new)  # update to the prefix sum of round T

        for i in range(1, len(self.all_freevar_table)):
            # Any pair expands as A1A2 + A1A2' + A1'A2 + A1'A2'; t1_last corresponds to A1, t1_new to A1', and likewise for t2.
            if i == len(self.rule_checker.grounding_nodes) - len(self.rule_checker.action_nodes):
                # HACK: if this is the first action item, mid_table is the result accumulated so far.
                mid_table = t1_new  # HACK: record mid_table once all non-action items have been processed
                # HACK: recording it directly like this works because non-action items are guaranteed to precede action items.

            t2_last = self._past_all_freevar_table[i]
            t2_new = self.all_freevar_table[i]

            new_tables = []
            if t1_new.height > 0 and t2_last.height > 0:
                new_tables.append(t1_new.union_table(t2_last))  # A1'A2

            if t1_last.height > 0 and t2_new.height > 0:
                new_tables.append(t1_last.union_table(t2_new))  # A1A2'

            if t1_new.height > 0 and t2_new.height > 0:
                new_tables.append(t1_new.union_table(t2_new))  # A1'A2'

            nxt_prefix_sum = self._past_df_prefix_sum[i]

            if len(new_tables) > 0:
                new_table = new_tables[0].concat_table(*new_tables[1:])
                self._past_df_prefix_sum[i] = self._past_df_prefix_sum[i].concat_table(new_table)  # update the prefix sum for the next round
            else:
                new_table = _TupleTable(self._past_df_prefix_sum[i].raw_column_name)

            t1_last = nxt_prefix_sum  # iterate to prepare for item i+1
            t1_new = new_table
        return t1_new, mid_table

    @staticmethod
    def _is_conjunctive_body(rule: Rule) -> bool:
        """Check that the rule body uses only the connectives AND and NOT, and that NOT is applied only to Assertions (i.e. DNF without OR)."""
        is_standard = True
        cur_term_queue: deque[Formula | Assertion | None] = deque([rule.body])
        while cur_term_queue:
            cur_term = cur_term_queue.popleft()
            if isinstance(cur_term, Formula):
                if cur_term.connective not in {'AND', 'NOT'}:
                    is_standard = False
                    break
                if cur_term.connective == 'NOT' and not isinstance(cur_term.formula_left, Assertion):
                    # NOT must be applied to an Assertion, not to another formula.
                    is_standard = False
                    break
                cur_term_queue.append(cur_term.formula_left)
                cur_term_queue.append(cur_term.formula_right)
        return is_standard

    @staticmethod
    def _is_conjunctive_head(rule: Rule) -> bool:
        """Check that the rule head uses only the AND connective."""
        is_standard = True
        cur_term_queue: deque[Formula | Assertion | None] = deque([rule.head])
        while cur_term_queue:
            cur_term = cur_term_queue.popleft()
            if isinstance(cur_term, Formula):
                if cur_term.connective != 'AND':
                    is_standard = False
                    break
                cur_term_queue.append(cur_term.formula_left)
                cur_term_queue.append(cur_term.formula_right)
        return is_standard

    def _broadcast_total_table(self, total_table: _TupleTable) -> None:
        """
        Broadcast the rule-level total table to the AssertionNodes / RuleNode.

        :param total_table: (_TupleTable) the total table to be stored
        """
        for cur_assertion in self.rule_checker.execute_nodes:
            cur_assertion.broadcast_total_table(total_table)

    def print_all_grounded_rules(self) -> list[str]:
        """Print all grounded rules, for debugging convenience (mind the memory cost)."""
        grounded_rule_text = []
        for combination in self._back_up_total_table.iter_rows():
            rule_text = self.rule.replace_variable(combination)
            grounded_rule_text.append(str(rule_text))

        return grounded_rule_text

    def total_table_unique_height(self) -> int:
        """
        Number of deduplicated grounded-rule rows, for logging/debug output.
        """
        if hasattr(self, "_back_up_total_table"):
            return self._back_up_total_table.unique_height()
        if hasattr(self, "total_table"):
            return self.total_table.unique_height()
        return 0

    def get_question_solutions(self) -> tuple[list[Mapping[Variable, Constant | CompoundTerm]], _QuestionRule | None]:
        """Return the question rule node's solutions, if such a node exists."""
        if isinstance(self.rule_checker.rule_node, _QuestionRuleNode):
            question_node = self.rule_checker.rule_node
            return question_node.solutions, question_node.question_rule
        return [], None

    def reset(self) -> None:  # fixme: this implementation is only transitional, keeping the current commit consistent with the previous code; it will gradually be replaced by finer-grained resets
        """Reset the GroundedRule state for a new inference iteration. The _past_prefix_sum and _backup variables must not be deleted."""
        self.rule_checker.reset()
        self._past_all_freevar_table.clear()
        self.all_freevar_table.clear()
        self._action_new_fact_list.clear()
        if hasattr(self, 'total_table'):
            del self.total_table


class GroundedRuleDS:
    """
    Maintains the lifecycle of all GroundedRules and the grounding inputs of the current round.

    GroundedRuleDS does not try to store the "final grounded rules"; it is the management entry point for grounding, merging and reuse.
    """

    def __init__(self, equivalence: Equivalence, sk_system_handler: SankuManagementSystem, args: Config, inference_path: InferencePath) -> None:
        # FIXME: equivalence is passed in because it is needed for alignment, but whether it should be passed at init time, or accessed some other way, needs further discussion  # noqa: TD004
        self.grounded_rule_pool: dict[Rule, GroundedRule] = {}
        self.inference_path = inference_path
        self.equivalence = equivalence
        self.sk_system_handler = sk_system_handler
        self.args = args
        self.current_grounded_rule_terms: Sequence[tuple[GroundedRule, Sequence[GROUNDED_TYPE_FOR_UNIFICATION]]] | None = None

    def _add_rule(self, rule: Rule) -> GroundedRule:
        """
        Register a rule with GroundedRuleDS and return the corresponding GroundedRule.
        """
        if rule not in self.grounded_rule_pool:
            grounded_rule = GroundedRule(rule, self.equivalence, self.sk_system_handler, self.args, self.inference_path)
            self.grounded_rule_pool[rule] = grounded_rule

        return self.grounded_rule_pool[rule]
        # When the rule is converted to a GroundedRule here, its graph structure has already been built. There is no explicit handling of
        # category constraints yet, so it is left empty and the rule is simply wrapped. Later we may keep every rule ever selected and
        # choose suitable facts for additional grounding.

    def start(self, cur_rules_facts: Sequence[tuple[Rule, Sequence[GROUNDED_TYPE_FOR_UNIFICATION]]]) -> None:
        """
        Set the cur_rules and facts to run in this round, plus any initialization that may be needed.

        :raises RuntimeError: start is called again while a grounding process has not finished
        """  # noqa: DOC501
        if self.current_grounded_rule_terms is not None:
            raise RuntimeError("Grounding process is not ended")

        self.current_grounded_rule_terms = []
        for rule, facts in cur_rules_facts:
            cur_rule = self._add_rule(rule)
            cur_rule.reset()
            self.current_grounded_rule_terms.append((cur_rule, facts))

    def end(self) -> None:
        """End the grounding process and drop cur_rules and facts."""
        self.current_grounded_rule_terms = None

    @staticmethod
    def _unify(cur_rule: GroundedRule, facts: Sequence[GROUNDED_TYPE_FOR_UNIFICATION]) -> None:
        """Ground a single rule with its corresponding facts."""
        useful_terms: list[CompoundTerm[Constant | CompoundTerm] | Constant] = []
        for single_fact in facts:
            useful_terms.extend(unify_all_terms(single_fact))

        cur_rule.unify(useful_terms)

    def _grounding_term_level(self) -> None:
        """Iterate over the rules of the current round and perform term-level grounding."""
        if self.current_grounded_rule_terms is not None:
            for rule_tuple in self.current_grounded_rule_terms:
                single_rule, single_facts = rule_tuple
                self._unify(single_rule, single_facts)

    def exec_grounding(self) -> None:
        """TODO: this function can still be optimized, perhaps by adjusting self.current_grounded_rule_terms, e.g. adding previously selected rules."""
        self._grounding_term_level()

    def get_corresponding_grounded_rules(self, abstract_rules: list[Rule]) -> list[GroundedRule]:
        """
        Fetch the GroundedRules for the given rules, to be executed by the executor in the next step.
        1. If grounding information is stored per rule as check graphs, as it is now, it suffices to record the mapping between the two;
        2. If it is stored as one big graph, return the node at the end of each rule directly.
        Also, if a rule received no actual grounding in this phase, it need not be returned.  risk: there is a trade-off here between "not in this phase" and "never so far".
        TODO: during executor.check, avoid re-checking parts that were already checked (not quite the same as "used": something may be unused simply
        because it never passed the global check). One possible strategy is to move everything already checked to the next node and remove it from the
        current one. Corner cases remain: in op(x)=y, x=1 being used means it needs no second grounding, yet having been checked once does not mean
        x=1 can never be checked again (y may have changed). On the other hand, if the equivalence classes are fast enough, testing whether something
        was already checked may cost about as much as simply checking it.

        risk: from the analysis above, the return-type constraint of this function should not cause major problems, because both a GroundedRule and the
        end node representing a rule can index the rule's own information plus its grounding information, so the executor can always adapt quickly when
        it runs its checks. Still, proceed with care.
        """
        return [self.grounded_rule_pool[r] for r in abstract_rules]

    def reset(self) -> None:
        """
        Clear grounded_rule_pool; this prevents errors caused by repeatedly creating the inference engine.
        """
        self.grounded_rule_pool.clear()


class GroundedProcess:
    """Context manager that runs the grounding process."""
    def __init__(self, grounded_structure: GroundedRuleDS, cur_rules_terms: Sequence[tuple[Rule, Sequence[GROUNDED_TYPE_FOR_UNIFICATION]]]):
        self.grounded_structure = grounded_structure
        self.cur_terms_facts = cur_rules_terms

    def __enter__(self) -> GroundedRuleDS:
        self.grounded_structure.start(cur_rules_facts=self.cur_terms_facts)
        return self.grounded_structure

    def __exit__(self, exc_type, exc_value, traceback) -> bool | None:  # type: ignore[no-untyped-def]  # noqa: ANN001
        if exc_type:  # XXX: exceptions get no special handling yet, and True is always returned; refine based on the exceptions that actually occur
            raise exc_value

        self.grounded_structure.end()
        return True
kele/grounder/grounded_rule_ds/grounded_ds_utils.py
@@ -0,0 +1,91 @@
import warnings
from collections.abc import Sequence
from typing import cast, TYPE_CHECKING

from kele.knowledge_bases.builtin_base.builtin_concepts import FREEVARANY_CONCEPT
from kele.syntax import (FACT_TYPE, TERM_TYPE, Assertion, Formula, CompoundTerm,
                         Constant, Variable, ATOM_TYPE)
from functools import singledispatch


@singledispatch
def unify_all_terms(fact: FACT_TYPE | TERM_TYPE) -> tuple[CompoundTerm[Constant | CompoundTerm] | Constant, ...]:
    # hack: this and functions such as split could be refined later; some of them could e.g. support Variables (which would also require changing the generics)
    """
    Mainly splits a fact that is a Formula into Assertions; a single Assertion is split into TERM_TYPE values and handed to the other functions.
    Splitting directly down to FlatCompoundTerm is the most convenient here.
    """
    return ()


@unify_all_terms.register(Assertion)
def _(fact: Assertion) -> tuple[CompoundTerm[Constant | CompoundTerm] | Constant, ...]:
    return unify_all_terms(fact.lhs) + unify_all_terms(fact.rhs)


@unify_all_terms.register(Formula)
def _(fact: Formula) -> tuple[CompoundTerm[Constant | CompoundTerm] | Constant, ...]:
    tuple_left = unify_all_terms(fact.formula_left)
    tuple_right = unify_all_terms(fact.formula_right) if fact.formula_right is not None else ()
    return tuple_right + tuple_left


@unify_all_terms.register(CompoundTerm)
def _(fact: CompoundTerm[Constant | CompoundTerm]) -> tuple[CompoundTerm[Constant | CompoundTerm] | Constant, ...]:
    return tuple(split_all_terms(fact))


@unify_all_terms.register(Constant)
def _(fact: Constant) -> tuple[CompoundTerm[Constant | CompoundTerm] | Constant, ...]:
    return (fact, )


@unify_all_terms.register(Variable)
def _(fact: Variable) -> tuple[CompoundTerm[Constant | CompoundTerm] | Constant, ...]:
    warnings.warn("Variable should not exist in fact", stacklevel=2)
    return ()


class FREEVARANY(Constant):
    """
    The ANY marker, a special kind of Constant used as a placeholder in free_variables.
    The engine performs grounding at the level of flat terms: facts inside rules and facts are decomposed down to flat terms for matching.
    A nested term therefore has to be decomposed into several flat terms, and any Term-typed value among a nested term's arguments is
    replaced by the wildcard, i.e. the FREEVARANY class.
    """

    def __init__(self, value: str) -> None:
        concept = FREEVARANY_CONCEPT
        super().__init__(value, concept)


FREEANY = FREEVARANY('FREEVARANY')


def split_all_terms(term: CompoundTerm[Constant | CompoundTerm]) -> list[CompoundTerm[Constant | CompoundTerm]]:
    """
    Extract every compound substructure of a Term and return them as a list.
    """
    # A Constant could arguably also count as a flat term; if that makes the type annotations awkward, skip it, or split the two cases and
    # simply be careful in the comments.
    split_terms: list[CompoundTerm[Constant | CompoundTerm]] = []
    split_terms.append(term)  # pull the compound substructures out of a compound term

    if TYPE_CHECKING:
        term.arguments = cast("tuple[Constant | CompoundTerm[Constant | CompoundTerm], ...]", term.arguments)

    for var in term.arguments:
        if isinstance(var, CompoundTerm):
            split_terms.extend(split_all_terms(var))

    return split_terms


def flatten_arguments(arguments: Sequence[TERM_TYPE]) -> tuple[ATOM_TYPE, ...]:  # kept as a public function for now;
    # also, the singledispatch in this module could be rewritten as plain ifs for easier reading (e.g. moving the class to the top)
    """
    Given a term's arguments, replace every Term among them with $F.
    This works for both facts and rules, since the presence of variables does not affect it.
    """
    return tuple(
        FREEANY if isinstance(var, CompoundTerm) else var
        for var in arguments
    )