kele 0.0.1a1__cp313-cp313-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kele/__init__.py +38 -0
- kele/_version.py +1 -0
- kele/config.py +243 -0
- kele/control/README_metrics.md +102 -0
- kele/control/__init__.py +20 -0
- kele/control/callback.py +255 -0
- kele/control/grounding_selector/__init__.py +5 -0
- kele/control/grounding_selector/_rule_strategies/README.md +13 -0
- kele/control/grounding_selector/_rule_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py +42 -0
- kele/control/grounding_selector/_rule_strategies/strategy_protocol.py +51 -0
- kele/control/grounding_selector/_selector_utils.py +123 -0
- kele/control/grounding_selector/_term_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py +34 -0
- kele/control/grounding_selector/_term_strategies/strategy_protocol.py +50 -0
- kele/control/grounding_selector/rule_selector.py +98 -0
- kele/control/grounding_selector/term_selector.py +89 -0
- kele/control/infer_path.py +306 -0
- kele/control/metrics.py +357 -0
- kele/control/status.py +286 -0
- kele/egg_equiv.pyd +0 -0
- kele/egg_equiv.pyi +11 -0
- kele/equality/README.md +8 -0
- kele/equality/__init__.py +4 -0
- kele/equality/_egg_equiv/src/lib.rs +267 -0
- kele/equality/_equiv_elem.py +67 -0
- kele/equality/_utils.py +36 -0
- kele/equality/equivalence.py +141 -0
- kele/executer/__init__.py +4 -0
- kele/executer/executing.py +139 -0
- kele/grounder/README.md +83 -0
- kele/grounder/__init__.py +17 -0
- kele/grounder/grounded_rule_ds/__init__.py +6 -0
- kele/grounder/grounded_rule_ds/_nodes/__init__.py +24 -0
- kele/grounder/grounded_rule_ds/_nodes/_assertion.py +353 -0
- kele/grounder/grounded_rule_ds/_nodes/_conn.py +116 -0
- kele/grounder/grounded_rule_ds/_nodes/_op.py +57 -0
- kele/grounder/grounded_rule_ds/_nodes/_root.py +71 -0
- kele/grounder/grounded_rule_ds/_nodes/_rule.py +119 -0
- kele/grounder/grounded_rule_ds/_nodes/_term.py +390 -0
- kele/grounder/grounded_rule_ds/_nodes/_tftable.py +15 -0
- kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +444 -0
- kele/grounder/grounded_rule_ds/_nodes/_typing_polars.py +26 -0
- kele/grounder/grounded_rule_ds/grounded_class.py +461 -0
- kele/grounder/grounded_rule_ds/grounded_ds_utils.py +91 -0
- kele/grounder/grounded_rule_ds/rule_check.py +373 -0
- kele/grounder/grounding.py +118 -0
- kele/knowledge_bases/README.md +112 -0
- kele/knowledge_bases/__init__.py +6 -0
- kele/knowledge_bases/builtin_base/__init__.py +1 -0
- kele/knowledge_bases/builtin_base/builtin_concepts.py +13 -0
- kele/knowledge_bases/builtin_base/builtin_facts.py +43 -0
- kele/knowledge_bases/builtin_base/builtin_operators.py +105 -0
- kele/knowledge_bases/builtin_base/builtin_rules.py +14 -0
- kele/knowledge_bases/fact_base.py +158 -0
- kele/knowledge_bases/ontology_base.py +67 -0
- kele/knowledge_bases/rule_base.py +194 -0
- kele/main.py +464 -0
- kele/py.typed +0 -0
- kele/syntax/CONCEPT_README.md +117 -0
- kele/syntax/__init__.py +40 -0
- kele/syntax/_cnf_converter.py +161 -0
- kele/syntax/_sat_solver.py +116 -0
- kele/syntax/base_classes.py +1482 -0
- kele/syntax/connectives.py +20 -0
- kele/syntax/dnf_converter.py +145 -0
- kele/syntax/external.py +17 -0
- kele/syntax/sub_concept.py +87 -0
- kele/syntax/syntacticsugar.py +201 -0
- kele-0.0.1a1.dist-info/METADATA +166 -0
- kele-0.0.1a1.dist-info/RECORD +74 -0
- kele-0.0.1a1.dist-info/WHEEL +4 -0
- kele-0.0.1a1.dist-info/licenses/LICENSE +28 -0
- kele-0.0.1a1.dist-info/licenses/licensecheck.json +20 -0
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, cast
|
|
4
|
+
import logging
|
|
5
|
+
from functools import reduce, partial
|
|
6
|
+
from itertools import product
|
|
7
|
+
import polars as pl
|
|
8
|
+
|
|
9
|
+
from ._tftable import TfTables
|
|
10
|
+
from ._conn import _ConnectiveNode
|
|
11
|
+
from ._tupletable import _TupleTable
|
|
12
|
+
|
|
13
|
+
from kele.syntax import Variable, Assertion, CompoundTerm, Constant, FlatCompoundTerm, Concept
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from kele.syntax import TERM_TYPE
|
|
17
|
+
from ._rule import _RuleNode
|
|
18
|
+
from collections.abc import Generator
|
|
19
|
+
from kele.grounder import GroundedRule
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class _AssertionNode:
|
|
25
|
+
"""
|
|
26
|
+
负责断言节点的 join 与 check 执行。
|
|
27
|
+
|
|
28
|
+
- join:合并来自 term 节点的变量候选表;
|
|
29
|
+
- check:基于 total_table 执行真假判断;
|
|
30
|
+
- action assertion:计算 action term 并将结果写回事实。
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(self, content: Assertion,
|
|
34
|
+
grounded_rule: GroundedRule,
|
|
35
|
+
*,
|
|
36
|
+
negated_assertion: bool) -> None:
|
|
37
|
+
|
|
38
|
+
self.content: Assertion = content
|
|
39
|
+
self.rule_or_connective_children: list[_ConnectiveNode | _RuleNode] = [] # 亦或者应该统一叫children,然后不要next_node函数。 # 我们此刻的图
|
|
40
|
+
# 这里不是list,就是一个ConnectiveNode,但考虑到共享不妨留一下。另外init里估计要传这个参数
|
|
41
|
+
# 还需要一个标记是否有匹配结果的标记符
|
|
42
|
+
self.grounded_rule = grounded_rule
|
|
43
|
+
self._is_concept_compatible = partial(
|
|
44
|
+
Concept.is_compatible,
|
|
45
|
+
fuzzy_match=self.grounded_rule.args.grounder.conceptual_fuzzy_unification,
|
|
46
|
+
)
|
|
47
|
+
# TODO: Consider refactoring concept-compatibility configuration for cleaner ownership.
|
|
48
|
+
|
|
49
|
+
self.tf_table: TfTables
|
|
50
|
+
self.all_freevar_table: list[_TupleTable] = []
|
|
51
|
+
self.freevar_table: _TupleTable
|
|
52
|
+
self.grounding_arguments: set[Variable]
|
|
53
|
+
self._action_result: list[Assertion] = [] # FIXME: 这里返回的是含有action_op的Assertion,后续有待进一步讨论这里的格式
|
|
54
|
+
|
|
55
|
+
if negated_assertion:
|
|
56
|
+
# 否定assertion的初始化
|
|
57
|
+
self.grounding_arguments = set()
|
|
58
|
+
self.action_assertion = self.content.is_action_assertion
|
|
59
|
+
elif self.content.is_action_assertion:
|
|
60
|
+
# 非否定但是action_assertion的初始化
|
|
61
|
+
self.action_assertion = True
|
|
62
|
+
self.grounding_arguments = set()
|
|
63
|
+
for term in (content.lhs, content.rhs):
|
|
64
|
+
if not term.is_action_term:
|
|
65
|
+
self.grounding_arguments.update(term.free_variables)
|
|
66
|
+
else:
|
|
67
|
+
# 其他情况的初始化
|
|
68
|
+
self.action_assertion = False
|
|
69
|
+
self.grounding_arguments = set(self.content.free_variables)
|
|
70
|
+
self.past_freevar_table: _TupleTable = _TupleTable(column_name=tuple(self.grounding_arguments))
|
|
71
|
+
# 过往freevar_table,每次都会通过concat将当前freevar_table加入记忆中
|
|
72
|
+
|
|
73
|
+
self.keep_table: bool | None = None # 基于SAT solver的结果,确定需要保留的table
|
|
74
|
+
# true 保留true table, false保留false table,None表示均保留
|
|
75
|
+
self.negated_assertion: bool = negated_assertion # 记录了当前assertion是否被not算子影响。
|
|
76
|
+
# 受到not算子影响的Assertion不会进入grounding, join流程,也不会建立CompoundTerm节点
|
|
77
|
+
# 记录了最后合并得到的大表
|
|
78
|
+
self.total_table: _TupleTable
|
|
79
|
+
|
|
80
|
+
def __str__(self) -> str:
|
|
81
|
+
return str(self.content)
|
|
82
|
+
|
|
83
|
+
def add_child(self, node: _ConnectiveNode | _RuleNode) -> None:
|
|
84
|
+
self.rule_or_connective_children.append(node)
|
|
85
|
+
|
|
86
|
+
def exec_join(self) -> _TupleTable:
|
|
87
|
+
"""
|
|
88
|
+
执行join操作,将所有子节点的结果合并起来。
|
|
89
|
+
|
|
90
|
+
:return: 合并后的freevar_table
|
|
91
|
+
:rtype: _TupleTable
|
|
92
|
+
"""
|
|
93
|
+
self.freevar_table = reduce(lambda x, y: x.union_table(y), self.all_freevar_table)
|
|
94
|
+
self.freevar_table = self._drop_invalid_bindings(self.freevar_table)
|
|
95
|
+
if logger.isEnabledFor(logging.DEBUG):
|
|
96
|
+
input_summaries = [table.debug_summary() for table in self.all_freevar_table]
|
|
97
|
+
logger.debug(
|
|
98
|
+
"Assertion node merged freevar tables: assertion=%s inputs=%s merged=%s",
|
|
99
|
+
self.content,
|
|
100
|
+
input_summaries,
|
|
101
|
+
self.freevar_table.debug_summary(),
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
return self.freevar_table
|
|
105
|
+
|
|
106
|
+
def exec_check(self) -> None:
|
|
107
|
+
"""
|
|
108
|
+
检查每个可能的赋值,并传递 true/false 索引表。
|
|
109
|
+
"""
|
|
110
|
+
empty_table = _TupleTable(column_name=tuple(self.content.free_variables)) # check操作是验证assertion实例化正确性的,
|
|
111
|
+
# 需要和assertion本身相关的变量。而grounding_arguments则用于判断参与unify过程的判断
|
|
112
|
+
|
|
113
|
+
small_table = self.total_table.get_small_table(tuple(self.content.free_variables))
|
|
114
|
+
|
|
115
|
+
# 对于无变量的assertion而言,只需要判断和传递自身的True/False
|
|
116
|
+
if not small_table.column_name:
|
|
117
|
+
check_result = self._check_self() # 判断当前assertion是否为真
|
|
118
|
+
if check_result and (self.keep_table is None or self.keep_table): # 如果为真,且当前Assertion也有为真的解释,则保留
|
|
119
|
+
self.tf_table = TfTables(true=_TupleTable.create_empty_table_with_emptyset(), false=empty_table)
|
|
120
|
+
# 保留用空列的table表示,即不参与table的运算。
|
|
121
|
+
# 因此空列table与其他table做cross时,将直接返回另一个table。不保留则用空行table表示,因为空行意味着某个变量组合没有合法替换方案。
|
|
122
|
+
elif not check_result and (self.keep_table is None or not self.keep_table):
|
|
123
|
+
self.tf_table = TfTables(true=empty_table, false=_TupleTable.create_empty_table_with_emptyset())
|
|
124
|
+
else:
|
|
125
|
+
# 如果keep_table与自身检查结果不匹配(例如自身为True,但是keep_table为False),那么两个表都是空
|
|
126
|
+
self.tf_table = TfTables(true=empty_table, false=empty_table)
|
|
127
|
+
return
|
|
128
|
+
|
|
129
|
+
tf_list: list[bool] = []
|
|
130
|
+
|
|
131
|
+
for combination in small_table.iter_rows():
|
|
132
|
+
if self.action_assertion:
|
|
133
|
+
tf_list.append(self._check_single_action_assertion(combination))
|
|
134
|
+
else:
|
|
135
|
+
tf_list.append(self._check_single(combination))
|
|
136
|
+
|
|
137
|
+
true_table, false_table = small_table.get_true_false_table(tf_list, keep_table=self.keep_table)
|
|
138
|
+
|
|
139
|
+
self.tf_table = TfTables(true=true_table, false=false_table)
|
|
140
|
+
|
|
141
|
+
if not self.only_substitution:
|
|
142
|
+
self.past_freevar_table = self.past_freevar_table.concat_table(self.freevar_table)
|
|
143
|
+
|
|
144
|
+
def exec_action(self, temp_table: _TupleTable) -> list[Assertion]:
|
|
145
|
+
"""
|
|
146
|
+
执行action操作,将当前assertion实例化后的结果返回。
|
|
147
|
+
|
|
148
|
+
:param temp_table: 当前assertion实例化后的表
|
|
149
|
+
:type temp_table: _TupleTable
|
|
150
|
+
:return: 当前assertion实例化后的结果
|
|
151
|
+
:rtype: list[Assertion]
|
|
152
|
+
"""
|
|
153
|
+
result_list: list[Assertion] = []
|
|
154
|
+
for combination in temp_table.iter_rows():
|
|
155
|
+
if isinstance(self.content.lhs, CompoundTerm) and self.content.lhs.is_action_term:
|
|
156
|
+
replaced_lhs = self.content.lhs.replace_variable(combination)
|
|
157
|
+
value = self._exec_implement_func(replaced_lhs)
|
|
158
|
+
if value is not None:
|
|
159
|
+
result_list.append(Assertion.from_parts(replaced_lhs, value)) # 记录计算出来的值,这个值将在后续加入事实库中
|
|
160
|
+
|
|
161
|
+
if isinstance(self.content.rhs, CompoundTerm) and self.content.rhs.is_action_term:
|
|
162
|
+
replaced_rhs = self.content.rhs.replace_variable(combination)
|
|
163
|
+
value = self._exec_implement_func(replaced_rhs)
|
|
164
|
+
if value is not None:
|
|
165
|
+
result_list.append(Assertion.from_parts(replaced_rhs, value)) # 记录计算出来的值,这个值将在后续加入事实库中
|
|
166
|
+
|
|
167
|
+
return result_list
|
|
168
|
+
|
|
169
|
+
def _check_self(self) -> bool:
|
|
170
|
+
"""
|
|
171
|
+
检查自身是否为真。
|
|
172
|
+
|
|
173
|
+
:raise ValueError: grounding_arguments不为空时,combination不能为None
|
|
174
|
+
""" # noqa: DOC501
|
|
175
|
+
if self.grounding_arguments:
|
|
176
|
+
raise ValueError("Assertion with pure arguments must be checked with combination")
|
|
177
|
+
return self._ask_equivalence(self.content) or self._ask_sk_system(self.content)
|
|
178
|
+
|
|
179
|
+
def _exec_implement_func(self, term: CompoundTerm) -> CompoundTerm | Constant | None:
|
|
180
|
+
"""
|
|
181
|
+
给定 action_term 计算结果并返回。
|
|
182
|
+
"""
|
|
183
|
+
implement_func = term.operator.implement_func
|
|
184
|
+
if implement_func is None:
|
|
185
|
+
return None
|
|
186
|
+
|
|
187
|
+
equivalence = self.grounded_rule.equivalence
|
|
188
|
+
|
|
189
|
+
candidate_arguments: list[list[TERM_TYPE]] = []
|
|
190
|
+
for arg in term.arguments:
|
|
191
|
+
if TYPE_CHECKING:
|
|
192
|
+
arg = cast("CompoundTerm | Constant", arg)
|
|
193
|
+
equiv_items = list(equivalence.get_equiv_item(arg))
|
|
194
|
+
candidate_arguments.append(equiv_items) # 这里get_equiv_item已经把元素自身加入进去了,同时去重操作也已经进行过了
|
|
195
|
+
# 获得符合条件的所有参数的组合(写成list[list[TERM_TYPE]],每个list对应每个参数的的equiv_items)
|
|
196
|
+
|
|
197
|
+
for arguments in product(*candidate_arguments):
|
|
198
|
+
candidate_term = FlatCompoundTerm.from_parts(term.operator, arguments)
|
|
199
|
+
try:
|
|
200
|
+
result = implement_func(candidate_term)
|
|
201
|
+
if TYPE_CHECKING:
|
|
202
|
+
result = cast("CompoundTerm | Constant", result)
|
|
203
|
+
except TypeError: # HACK: 这里except TypeError的实质是处理无法计算的情况,将在进一步确定implement_func的格式之后修改这里的代码
|
|
204
|
+
continue # 如果计算失败,尝试下一个参数组合,否则将会直接返回结果
|
|
205
|
+
else:
|
|
206
|
+
return result
|
|
207
|
+
|
|
208
|
+
return None
|
|
209
|
+
|
|
210
|
+
def _check_single_action_assertion(self, combination: dict[Variable, Constant | CompoundTerm]) -> bool:
|
|
211
|
+
"""执行 action assertion:先计算 action term,再进行真假判断。"""
|
|
212
|
+
assertion = self.content.replace_variable(combination)
|
|
213
|
+
if isinstance(assertion.lhs, CompoundTerm) and assertion.lhs.is_action_term:
|
|
214
|
+
value = self._exec_implement_func(assertion.lhs)
|
|
215
|
+
if value is not None:
|
|
216
|
+
assertion = Assertion.from_parts(value, assertion.rhs)
|
|
217
|
+
# value为None时,意味着implement_func无法计算,此时不再替换assertion.lhs
|
|
218
|
+
|
|
219
|
+
if isinstance(assertion.rhs, CompoundTerm) and assertion.rhs.is_action_term:
|
|
220
|
+
value = self._exec_implement_func(assertion.rhs)
|
|
221
|
+
if value is not None:
|
|
222
|
+
assertion = Assertion.from_parts(assertion.lhs, value)
|
|
223
|
+
|
|
224
|
+
return self._ask_equivalence(assertion) or self._ask_sk_system(assertion)
|
|
225
|
+
|
|
226
|
+
def _check_single(self, combination: dict[Variable, Constant | CompoundTerm]) -> bool:
|
|
227
|
+
"""
|
|
228
|
+
对实例化候选进行检查,并返回真假结果。TODO: 暂且没有和GroundedRule交互用到一些缓存,以后可以优化
|
|
229
|
+
|
|
230
|
+
:param combination: 变量替换表
|
|
231
|
+
"""
|
|
232
|
+
assertion = self.content.replace_variable(combination)
|
|
233
|
+
# 变量表,每个Variable只用唯一一个指针的话,这里的时空开销会更低,不过并行难度会加大。但暂时这个级别的优化似乎意义不大
|
|
234
|
+
if self._ask_equivalence(assertion):
|
|
235
|
+
return True
|
|
236
|
+
if self._ask_sk_system(assertion): # noqa: SIM103
|
|
237
|
+
return True
|
|
238
|
+
|
|
239
|
+
return False
|
|
240
|
+
|
|
241
|
+
def _drop_invalid_bindings(self, table: _TupleTable) -> _TupleTable:
|
|
242
|
+
"""
|
|
243
|
+
移除 concept mismatch 的绑定,避免将不合法的替换传播到 join/check。
|
|
244
|
+
极致优化时可考虑在 exec_check 中惰性过滤以减少重复 replace,但易引入错误。
|
|
245
|
+
"""
|
|
246
|
+
if table.height == 0:
|
|
247
|
+
return table
|
|
248
|
+
|
|
249
|
+
lhs = self.content.lhs
|
|
250
|
+
rhs = self.content.rhs
|
|
251
|
+
|
|
252
|
+
if not isinstance(lhs, Variable) or not isinstance(rhs, Variable): # 如果有非变量的,说明它被约束了,不需要单独drop
|
|
253
|
+
return table
|
|
254
|
+
|
|
255
|
+
if self.grounded_rule.rule.get_variable_concept_constraints(lhs) or \
|
|
256
|
+
self.grounded_rule.rule.get_variable_concept_constraints(rhs): # 如果变量本身含约束,也不需要单独drop。且任一含就都含
|
|
257
|
+
return table
|
|
258
|
+
|
|
259
|
+
valid_mask: list[bool] = []
|
|
260
|
+
for combination in table.iter_rows():
|
|
261
|
+
lhs_concepts = self._binding_concepts(lhs, combination)
|
|
262
|
+
rhs_concepts = self._binding_concepts(rhs, combination)
|
|
263
|
+
valid_mask.append(self._is_concept_compatible(lhs_concepts, rhs_concepts))
|
|
264
|
+
|
|
265
|
+
if all(valid_mask):
|
|
266
|
+
return table
|
|
267
|
+
|
|
268
|
+
table.make_table_ready()
|
|
269
|
+
filtered_table = _TupleTable(table.raw_column_name)
|
|
270
|
+
mask_series = pl.Series(valid_mask, dtype=pl.Boolean)
|
|
271
|
+
filtered_table.set_base_df(table.base_df.filter(mask_series))
|
|
272
|
+
return filtered_table
|
|
273
|
+
|
|
274
|
+
@staticmethod
|
|
275
|
+
def _binding_concepts(
|
|
276
|
+
term: Constant | CompoundTerm | Variable,
|
|
277
|
+
combination: dict[Variable, Constant | CompoundTerm],
|
|
278
|
+
) -> set[Concept]:
|
|
279
|
+
if isinstance(term, Variable):
|
|
280
|
+
bound = combination[term]
|
|
281
|
+
return _AssertionNode._binding_concepts(bound, combination)
|
|
282
|
+
if isinstance(term, Constant):
|
|
283
|
+
return term.belong_concepts
|
|
284
|
+
return {term.operator.output_concept}
|
|
285
|
+
|
|
286
|
+
def _ask_equivalence(self, assertion: Assertion) -> bool:
|
|
287
|
+
"""查询等价关系是否能证明该断言成立。"""
|
|
288
|
+
return (self.grounded_rule.equivalence is not None and
|
|
289
|
+
self.grounded_rule.equivalence.query_equivalence(assertion))
|
|
290
|
+
|
|
291
|
+
def _ask_sk_system(self, assertion: Assertion) -> bool:
|
|
292
|
+
return (self.grounded_rule.sk_system_handler is not None and
|
|
293
|
+
len(self.grounded_rule.sk_system_handler.query_assertion(assertion)) > 0)
|
|
294
|
+
|
|
295
|
+
def broadcast_total_table(self, total_table: _TupleTable) -> None:
|
|
296
|
+
"""
|
|
297
|
+
广播 total_table,为规则级别的变量候选总表。
|
|
298
|
+
|
|
299
|
+
:params: total_freevar (_TupleTable)
|
|
300
|
+
"""
|
|
301
|
+
self.total_table = total_table
|
|
302
|
+
|
|
303
|
+
def query_for_children(self, term: TERM_TYPE | CompoundTerm | None = None) -> Generator[_ConnectiveNode | _RuleNode]:
|
|
304
|
+
"""
|
|
305
|
+
Yields:
|
|
306
|
+
ConnectiveNode | RuleNode: 跳转到下一级节点,对应嵌套的ConnectiveNode或最终的RuleNode
|
|
307
|
+
"""
|
|
308
|
+
yield from self.rule_or_connective_children
|
|
309
|
+
|
|
310
|
+
def pass_tf_index(self) -> None:
|
|
311
|
+
"""
|
|
312
|
+
传递自身节点的 true/false table 到子节点。
|
|
313
|
+
"""
|
|
314
|
+
# 在一次传递之后,立即移除本身的tf_indexs,防止反复执行时重复传递
|
|
315
|
+
for child in self.query_for_children():
|
|
316
|
+
if not hasattr(child, "left_table"):
|
|
317
|
+
child.left_table = self.tf_table
|
|
318
|
+
elif isinstance(child, _ConnectiveNode):
|
|
319
|
+
child.right_table = self.tf_table
|
|
320
|
+
del self.tf_table
|
|
321
|
+
|
|
322
|
+
def get_action_result(self) -> Generator[Assertion]:
|
|
323
|
+
"""
|
|
324
|
+
获取 action assertion 计算得到的结果。
|
|
325
|
+
|
|
326
|
+
:yield: 执行的结果
|
|
327
|
+
:rtype: Assertion
|
|
328
|
+
""" # noqa: DOC402
|
|
329
|
+
yield from self._action_result
|
|
330
|
+
self._action_result.clear()
|
|
331
|
+
|
|
332
|
+
@property
|
|
333
|
+
def ready_to_execute(self) -> bool:
|
|
334
|
+
"""
|
|
335
|
+
AssertionNode 总是执行队列中的起点,因此始终处于 ready 状态。
|
|
336
|
+
"""
|
|
337
|
+
return True
|
|
338
|
+
|
|
339
|
+
@property
|
|
340
|
+
def only_substitution(self) -> bool:
|
|
341
|
+
"""
|
|
342
|
+
TODO:safety稳定后更新注释
|
|
343
|
+
"""
|
|
344
|
+
return self.negated_assertion or not self.grounding_arguments
|
|
345
|
+
|
|
346
|
+
def get_all_children(self) -> Generator[_ConnectiveNode | _RuleNode]:
|
|
347
|
+
yield from self.rule_or_connective_children
|
|
348
|
+
|
|
349
|
+
def reset(self) -> None:
|
|
350
|
+
if hasattr(self, 'total_table'):
|
|
351
|
+
del self.total_table
|
|
352
|
+
if hasattr(self, 'tf_table'):
|
|
353
|
+
del self.tf_table
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from functools import reduce
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from ._tftable import TfTables
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from kele.syntax import TERM_TYPE, CompoundTerm, Formula
|
|
10
|
+
from collections.abc import Generator
|
|
11
|
+
from ._rule import _RuleNode
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class _ConnectiveNode:
|
|
15
|
+
"""表达p^q等formula,换句话是若干(正常是1-2个)Assertion及逻辑连接词形成了一个Formula。主要是bool值判断和候选值的join、传递"""
|
|
16
|
+
def __init__(self, formula: Formula) -> None:
|
|
17
|
+
self.rule_or_connective_children: list[_ConnectiveNode | _RuleNode] = []
|
|
18
|
+
self.content = formula
|
|
19
|
+
|
|
20
|
+
self.tf_table: TfTables
|
|
21
|
+
self.left_table: TfTables
|
|
22
|
+
self.right_table: TfTables
|
|
23
|
+
|
|
24
|
+
self.left_or_right: int # 记录其为左父节点还是右父节点,0为左父节点,1为右父节点
|
|
25
|
+
|
|
26
|
+
def __str__(self) -> str:
|
|
27
|
+
return str(self.content)
|
|
28
|
+
|
|
29
|
+
def add_child(self, node: _ConnectiveNode | _RuleNode) -> None:
|
|
30
|
+
self.rule_or_connective_children.append(node)
|
|
31
|
+
|
|
32
|
+
def exec_check(self) -> None: # HACK: 当前计算图方案有大量冗余难以优化,后续可能考虑修改
|
|
33
|
+
"""
|
|
34
|
+
对单一ConnectiveNode进行处理:
|
|
35
|
+
对于AND节点而言,一者为假即为假,所以false_table进行并集。需要左右两侧都为真才能为真,所以会将传来的true_table做交集。
|
|
36
|
+
对于OR节点而言,两者为假才能为假,所以false_table进行交集。需要左右一者为真就能为真,所以会将传来的true_table做并集。
|
|
37
|
+
NOT节点只需要反转true_table和false_table就好
|
|
38
|
+
|
|
39
|
+
:raises TypeError: 未知的connective名称,在真正的图中只会处理AND, OR, NOT三种节点
|
|
40
|
+
""" # noqa: DOC501
|
|
41
|
+
# 首先应该基于left_token和right_token获得token组。组合方式自由组合原则,即左侧可以自由地和右侧的一个token组合
|
|
42
|
+
|
|
43
|
+
if self.content.connective == 'NOT':
|
|
44
|
+
# NOT节点导致True,False的table反转,即原本为True的一组取值,在经过NOT节点之后反转为False。我们只需要交换true,false table即可
|
|
45
|
+
# 同时,注意NOT节点只会有一个父节点,于是true_table和false_table都只有0
|
|
46
|
+
false_table = self.left_table.true
|
|
47
|
+
true_table = self.left_table.false
|
|
48
|
+
elif self.content.connective == 'AND':
|
|
49
|
+
# AND节点,只有同时为真的情况才为真,所以将两个true_table union之后就是此节点的true_table
|
|
50
|
+
# 左右一侧为假即为假,所以剩下情况一一对应union得到三个false table。最后直接concat拼接起来,就是此节点的false_table
|
|
51
|
+
all_false_table = [self.left_table.false.union_table(self.right_table.false), self.left_table.false.union_table(self.right_table.true),
|
|
52
|
+
self.left_table.true.union_table(self.right_table.false)]
|
|
53
|
+
false_table = reduce(lambda x, y: x.concat_table(y), all_false_table)
|
|
54
|
+
true_table = self.left_table.true.union_table(self.right_table.true)
|
|
55
|
+
elif self.content.connective == 'OR':
|
|
56
|
+
# OR节点,类似前面
|
|
57
|
+
all_true_table = [self.left_table.true.union_table(self.right_table.true), self.left_table.false.union_table(self.right_table.true),
|
|
58
|
+
self.left_table.true.union_table(self.right_table.false)]
|
|
59
|
+
true_table = reduce(lambda x, y: x.concat_table(y), all_true_table)
|
|
60
|
+
false_table = self.left_table.false.union_table(self.right_table.false)
|
|
61
|
+
else:
|
|
62
|
+
raise TypeError("Unknown connective node")
|
|
63
|
+
self.tf_table = TfTables(true=true_table, false=false_table)
|
|
64
|
+
self._remove_parent_tables()
|
|
65
|
+
# 执行完成之后就能移除父节点传来的tftables了,这样可以保证ready_for_execute属性控制不会执行第二次exec_execute
|
|
66
|
+
|
|
67
|
+
def pass_tf_index(self) -> None:
|
|
68
|
+
"""
|
|
69
|
+
传递自身节点的tf_index到子节点。
|
|
70
|
+
"""
|
|
71
|
+
# 在一次传递之后,立即移除本身的tf_indexs,防止反复执行时重复传递
|
|
72
|
+
for child in self.query_for_children():
|
|
73
|
+
if not hasattr(child, "left_table"):
|
|
74
|
+
child.left_table = self.tf_table
|
|
75
|
+
elif isinstance(child, _ConnectiveNode):
|
|
76
|
+
child.right_table = self.tf_table
|
|
77
|
+
|
|
78
|
+
del self.tf_table
|
|
79
|
+
|
|
80
|
+
def get_all_children(self) -> Generator[_ConnectiveNode | _RuleNode]:
|
|
81
|
+
"""
|
|
82
|
+
Yields:
|
|
83
|
+
ConnectiveNode | RuleNode: 跳转到下一级节点,对应嵌套的ConnectiveNode或最终的RuleNode
|
|
84
|
+
"""
|
|
85
|
+
yield from self.rule_or_connective_children
|
|
86
|
+
|
|
87
|
+
def query_for_children(self, term: TERM_TYPE | CompoundTerm | None = None) -> Generator[_ConnectiveNode | _RuleNode]:
|
|
88
|
+
"""
|
|
89
|
+
Yields:
|
|
90
|
+
ConnectiveNode | RuleNode: 跳转到下一级节点,对应嵌套的ConnectiveNode或最终的RuleNode
|
|
91
|
+
"""
|
|
92
|
+
yield from self.rule_or_connective_children
|
|
93
|
+
|
|
94
|
+
def _remove_parent_tables(self) -> None:
|
|
95
|
+
if hasattr(self, "left_table"):
|
|
96
|
+
del self.left_table
|
|
97
|
+
if hasattr(self, "right_table"):
|
|
98
|
+
del self.right_table
|
|
99
|
+
|
|
100
|
+
@property
|
|
101
|
+
def ready_to_execute(self) -> bool:
|
|
102
|
+
"""
|
|
103
|
+
这个属性作为判断是否适合执行exec_check的依据
|
|
104
|
+
当它返回false时,表明仅有一个父节点提供了它的tfIndexs信息,仍需要等待第二个父节点提供,外层需要在此时忽此节点
|
|
105
|
+
由于每次父节点执行完成后,都一定会通过query_for_children把此节点加入队列,所以一定会有某次加入时节点ready_to_execute为真
|
|
106
|
+
"""
|
|
107
|
+
if self.content.connective == 'NOT':
|
|
108
|
+
# NOT节点只有一个父节点
|
|
109
|
+
return hasattr(self, "left_table")
|
|
110
|
+
# AND, OR节点都有两个父节点,所以这里要求tf_indexs_list>=2来保证父节点都已经执行。注意:每个父节点只执行一次,这由外层控制
|
|
111
|
+
return hasattr(self, "left_table") and hasattr(self, "right_table")
|
|
112
|
+
|
|
113
|
+
def reset(self) -> None:
|
|
114
|
+
self._remove_parent_tables()
|
|
115
|
+
if hasattr(self, "tf_table"):
|
|
116
|
+
del self.tf_table
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
from kele.syntax import CompoundTerm
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from kele.syntax import Operator, Constant
|
|
8
|
+
from ._term import _FlatCompoundTermNode
|
|
9
|
+
from collections.abc import Generator
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _OperatorNode:
|
|
13
|
+
"""用于维护从fact → list[_FlatCompoundTermNode]的索引,每当一个fact传入时将会进入一个OperatorNode,然后去对其连接的TermNode进行实例化"""
|
|
14
|
+
def __init__(self, operator: Operator) -> None:
|
|
15
|
+
self.term_children: list[_FlatCompoundTermNode] = []
|
|
16
|
+
self.operator = operator
|
|
17
|
+
|
|
18
|
+
def __str__(self) -> str:
|
|
19
|
+
return str(self.operator)
|
|
20
|
+
|
|
21
|
+
def add_child(self, term: _FlatCompoundTermNode) -> None:
|
|
22
|
+
"""建立OperatorNode和CompoundTermNode的连边"""
|
|
23
|
+
self.term_children.append(term)
|
|
24
|
+
|
|
25
|
+
def exec_unify(self,
|
|
26
|
+
term: CompoundTerm[Constant | CompoundTerm] | Constant,
|
|
27
|
+
*,
|
|
28
|
+
allow_unify_with_nested_term: bool = True) -> None:
|
|
29
|
+
"""
|
|
30
|
+
_OperatorNode会对自己的所有child执行unify,而child的代码保证了这个Unify操作是可控的:
|
|
31
|
+
即它只局限在直接和这个_OperatorNode相连的child
|
|
32
|
+
后续的传递过程也不再会是exec_unnify
|
|
33
|
+
|
|
34
|
+
:param term (CompoundTerm | Constant): 待实例化的Term,某种意义上只有可能是FlatCompoundTerm,但是为了避免外层的类型检查,
|
|
35
|
+
我们还是采纳这种写法
|
|
36
|
+
:param allow_unify_with_nested_term: 是否允许与嵌套的Term进行unify
|
|
37
|
+
"""
|
|
38
|
+
if isinstance(term, CompoundTerm):
|
|
39
|
+
for child in self.term_children:
|
|
40
|
+
child.exec_unify(term, allow_unify_with_nested_term=allow_unify_with_nested_term)
|
|
41
|
+
|
|
42
|
+
def query_for_children(self, term: CompoundTerm | Constant | None = None) -> Generator[_FlatCompoundTermNode]:
|
|
43
|
+
"""
|
|
44
|
+
直接返回所有子节点,暂时没有其他额外操作
|
|
45
|
+
term参数纯粹为了格式统一
|
|
46
|
+
""" # noqa: DOC402
|
|
47
|
+
yield from self.term_children
|
|
48
|
+
|
|
49
|
+
def get_all_children(self) -> Generator[_FlatCompoundTermNode]:
|
|
50
|
+
"""
|
|
51
|
+
Yields:
|
|
52
|
+
TermNode: 用于汇总同Operator的多个TermNode,所以跳转到下一级为TermNode
|
|
53
|
+
"""
|
|
54
|
+
yield from self.term_children
|
|
55
|
+
|
|
56
|
+
def reset(self) -> None:
|
|
57
|
+
pass
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from collections.abc import Generator
|
|
2
|
+
from ._op import _OperatorNode
|
|
3
|
+
from kele.syntax import CompoundTerm, Operator, TERM_TYPE
|
|
4
|
+
from ._term import _FlatCompoundTermNode, _VariableNode, _ConstantNode, _TermNode
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class _RootNode:
|
|
8
|
+
"""
|
|
9
|
+
这是图的入口
|
|
10
|
+
"""
|
|
11
|
+
def __init__(self) -> None:
|
|
12
|
+
self._variable_nodes: list[_FlatCompoundTermNode] = [] # 用于添加待匹配值为变量的FlatCompoundTermNode
|
|
13
|
+
self._mapping_to_node: dict[Operator, _OperatorNode] = {} # 用于添加某些CompoundTerm所对应的operator所形成的Node
|
|
14
|
+
self._no_variable_nodes: list[_FlatCompoundTermNode] = [] # 用于添加待匹配值不含变量的FlatCompoundTermNode
|
|
15
|
+
|
|
16
|
+
def __str__(self) -> str:
|
|
17
|
+
return 'RootNode'
|
|
18
|
+
|
|
19
|
+
def add_child(self, node: _OperatorNode | _FlatCompoundTermNode) -> None:
|
|
20
|
+
"""
|
|
21
|
+
建立RootNode和OperatorNode | _FlatCompoundTermNode 的连边
|
|
22
|
+
|
|
23
|
+
:param node: 待添加的节点
|
|
24
|
+
"""
|
|
25
|
+
if isinstance(node, _OperatorNode):
|
|
26
|
+
# 实际上这里一个Operator一定只对应一个OperatorNode,所以这么写应该是没问题的
|
|
27
|
+
self._mapping_to_node[node.operator] = node
|
|
28
|
+
elif isinstance(node, _VariableNode):
|
|
29
|
+
self._variable_nodes.append(node)
|
|
30
|
+
elif isinstance(node, (_ConstantNode, _TermNode)):
|
|
31
|
+
self._no_variable_nodes.append(node)
|
|
32
|
+
|
|
33
|
+
def query_for_children(self, term_or_const: TERM_TYPE | None = None) -> \
|
|
34
|
+
tuple[_OperatorNode | _FlatCompoundTermNode, ...]:
|
|
35
|
+
"""
|
|
36
|
+
对于Operator,我们查询它对应的_OperatorNode
|
|
37
|
+
对于Constant,我们找到通配的VariableNode
|
|
38
|
+
|
|
39
|
+
无任何传入参数,返回所有的OperatorNode和VariableNode
|
|
40
|
+
|
|
41
|
+
:yield: 一个生成器,生成器里面是_OperatorNode或_FlatCompoundTermNode。如果没有任何可能的结果,那么返回None
|
|
42
|
+
"""
|
|
43
|
+
if term_or_const is None: # XXX: 用于洪泛地向下执行整张图,希望有机会移除
|
|
44
|
+
return_list = list(self._mapping_to_node.values()) + self._variable_nodes
|
|
45
|
+
return tuple(return_list)
|
|
46
|
+
|
|
47
|
+
return_list = [*self._variable_nodes]
|
|
48
|
+
if isinstance(term_or_const, CompoundTerm) and term_or_const.operator in self._mapping_to_node:
|
|
49
|
+
return_list.append(self._mapping_to_node[term_or_const.operator])
|
|
50
|
+
|
|
51
|
+
return tuple(return_list)
|
|
52
|
+
|
|
53
|
+
def operator_exist(self, operator: Operator) -> _OperatorNode | None: # HACK: 后续继续保留此函数还是选择别的方式记录有待商榷
|
|
54
|
+
# HACK: 以及返回值是直接取出OperatorNode还是单纯返回是否存在也需要考虑 # noqa: ERA001
|
|
55
|
+
"""
|
|
56
|
+
检查operator是否存在于图中
|
|
57
|
+
|
|
58
|
+
:param operator: 待检查的operator
|
|
59
|
+
:return: 如果存在,返回对应的_OperatorNode,否则返回None
|
|
60
|
+
"""
|
|
61
|
+
if operator in self._mapping_to_node:
|
|
62
|
+
return self._mapping_to_node[operator]
|
|
63
|
+
return None
|
|
64
|
+
|
|
65
|
+
def get_all_children(self) -> Generator[_OperatorNode | _FlatCompoundTermNode]:
|
|
66
|
+
yield from self._mapping_to_node.values()
|
|
67
|
+
yield from self._variable_nodes
|
|
68
|
+
yield from self._no_variable_nodes
|
|
69
|
+
|
|
70
|
+
def reset(self) -> None:
|
|
71
|
+
pass
|