kele 0.0.1a1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. kele/__init__.py +38 -0
  2. kele/_version.py +1 -0
  3. kele/config.py +243 -0
  4. kele/control/README_metrics.md +102 -0
  5. kele/control/__init__.py +20 -0
  6. kele/control/callback.py +255 -0
  7. kele/control/grounding_selector/__init__.py +5 -0
  8. kele/control/grounding_selector/_rule_strategies/README.md +13 -0
  9. kele/control/grounding_selector/_rule_strategies/__init__.py +24 -0
  10. kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py +42 -0
  11. kele/control/grounding_selector/_rule_strategies/strategy_protocol.py +51 -0
  12. kele/control/grounding_selector/_selector_utils.py +123 -0
  13. kele/control/grounding_selector/_term_strategies/__init__.py +24 -0
  14. kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py +34 -0
  15. kele/control/grounding_selector/_term_strategies/strategy_protocol.py +50 -0
  16. kele/control/grounding_selector/rule_selector.py +98 -0
  17. kele/control/grounding_selector/term_selector.py +89 -0
  18. kele/control/infer_path.py +306 -0
  19. kele/control/metrics.py +357 -0
  20. kele/control/status.py +286 -0
  21. kele/egg_equiv.pyd +0 -0
  22. kele/egg_equiv.pyi +11 -0
  23. kele/equality/README.md +8 -0
  24. kele/equality/__init__.py +4 -0
  25. kele/equality/_egg_equiv/src/lib.rs +267 -0
  26. kele/equality/_equiv_elem.py +67 -0
  27. kele/equality/_utils.py +36 -0
  28. kele/equality/equivalence.py +141 -0
  29. kele/executer/__init__.py +4 -0
  30. kele/executer/executing.py +139 -0
  31. kele/grounder/README.md +83 -0
  32. kele/grounder/__init__.py +17 -0
  33. kele/grounder/grounded_rule_ds/__init__.py +6 -0
  34. kele/grounder/grounded_rule_ds/_nodes/__init__.py +24 -0
  35. kele/grounder/grounded_rule_ds/_nodes/_assertion.py +353 -0
  36. kele/grounder/grounded_rule_ds/_nodes/_conn.py +116 -0
  37. kele/grounder/grounded_rule_ds/_nodes/_op.py +57 -0
  38. kele/grounder/grounded_rule_ds/_nodes/_root.py +71 -0
  39. kele/grounder/grounded_rule_ds/_nodes/_rule.py +119 -0
  40. kele/grounder/grounded_rule_ds/_nodes/_term.py +390 -0
  41. kele/grounder/grounded_rule_ds/_nodes/_tftable.py +15 -0
  42. kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +444 -0
  43. kele/grounder/grounded_rule_ds/_nodes/_typing_polars.py +26 -0
  44. kele/grounder/grounded_rule_ds/grounded_class.py +461 -0
  45. kele/grounder/grounded_rule_ds/grounded_ds_utils.py +91 -0
  46. kele/grounder/grounded_rule_ds/rule_check.py +373 -0
  47. kele/grounder/grounding.py +118 -0
  48. kele/knowledge_bases/README.md +112 -0
  49. kele/knowledge_bases/__init__.py +6 -0
  50. kele/knowledge_bases/builtin_base/__init__.py +1 -0
  51. kele/knowledge_bases/builtin_base/builtin_concepts.py +13 -0
  52. kele/knowledge_bases/builtin_base/builtin_facts.py +43 -0
  53. kele/knowledge_bases/builtin_base/builtin_operators.py +105 -0
  54. kele/knowledge_bases/builtin_base/builtin_rules.py +14 -0
  55. kele/knowledge_bases/fact_base.py +158 -0
  56. kele/knowledge_bases/ontology_base.py +67 -0
  57. kele/knowledge_bases/rule_base.py +194 -0
  58. kele/main.py +464 -0
  59. kele/py.typed +0 -0
  60. kele/syntax/CONCEPT_README.md +117 -0
  61. kele/syntax/__init__.py +40 -0
  62. kele/syntax/_cnf_converter.py +161 -0
  63. kele/syntax/_sat_solver.py +116 -0
  64. kele/syntax/base_classes.py +1482 -0
  65. kele/syntax/connectives.py +20 -0
  66. kele/syntax/dnf_converter.py +145 -0
  67. kele/syntax/external.py +17 -0
  68. kele/syntax/sub_concept.py +87 -0
  69. kele/syntax/syntacticsugar.py +201 -0
  70. kele-0.0.1a1.dist-info/METADATA +166 -0
  71. kele-0.0.1a1.dist-info/RECORD +74 -0
  72. kele-0.0.1a1.dist-info/WHEEL +4 -0
  73. kele-0.0.1a1.dist-info/licenses/LICENSE +28 -0
  74. kele-0.0.1a1.dist-info/licenses/licensecheck.json +20 -0
kele/control/status.py ADDED
@@ -0,0 +1,286 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum, auto
4
+
5
+ from typing import TYPE_CHECKING, Literal
6
+ import logging
7
+
8
+ if TYPE_CHECKING:
9
+ from kele.syntax import FACT_TYPE, Question
10
+ from kele.equality import Equivalence
11
+ from collections.abc import Sequence, Mapping
12
+ from kele.syntax import Variable, Rule, CompoundTerm, Constant, _QuestionRule
13
+ from kele.syntax import SankuManagementSystem
14
+
15
# Module logger, named after this module; QuerySolutionManager logs query solutions through it.
logger = logging.getLogger(__name__)
16
+
17
+
18
+ class InferenceStatus(Enum):
19
+ """推理过程的终止状态类型"""
20
+
21
+ SUCCESS = auto() # 成功推理出答案
22
+ MAX_STEPS_REACHED = auto() # 达到最大执行轮次限制
23
+ MAX_ITERATIONS_REACHED = auto() # 达到最大迭代次数限制(从main循环)
24
+ FIXPOINT_REACHED = auto() # 达到不动点
25
+ CONFLICT_DETECTED = auto() # 推理中发现矛盾
26
+ EXTERNALLY_INTERRUPTED = auto() # 被外部中断(如人为终止)
27
+ NO_MORE_RULES = auto() # 没有更多实例化规则可以推理
28
+ CONTINUE = auto() # 当前轮未结束,继续执行
29
+
30
+ def log_message(self) -> str:
31
+ """每个状态对应的日志输出信息"""
32
+ return {
33
+ InferenceStatus.SUCCESS: "Successfully inferred the answer.",
34
+ InferenceStatus.MAX_STEPS_REACHED: "Max execution steps reached, terminating.",
35
+ InferenceStatus.MAX_ITERATIONS_REACHED: "Max inference iterations reached, terminating.",
36
+ InferenceStatus.FIXPOINT_REACHED: "Fixpoint reached, terminating.",
37
+ InferenceStatus.CONFLICT_DETECTED: "Conflict detected during execution.",
38
+ InferenceStatus.EXTERNALLY_INTERRUPTED: "Execution was externally interrupted.",
39
+ InferenceStatus.NO_MORE_RULES: "No more grounded rules available for execution.",
40
+ InferenceStatus.CONTINUE: "Execution continues to next round."
41
+ }[self]
42
+
43
+ def is_terminal_for_executor(self) -> bool:
44
+ """判断执行器是否应该终止"""
45
+ return self != InferenceStatus.CONTINUE
46
+
47
+ def is_terminal_for_main_loop(self) -> bool:
48
+ """判断主循环是否应该终止"""
49
+ return self in {
50
+ InferenceStatus.SUCCESS,
51
+ InferenceStatus.MAX_STEPS_REACHED,
52
+ InferenceStatus.MAX_ITERATIONS_REACHED,
53
+ InferenceStatus.FIXPOINT_REACHED,
54
+ InferenceStatus.EXTERNALLY_INTERRUPTED
55
+ }
56
+
57
+
58
class QuerySolutionManager:
    """Log, collect and count the query solutions found during inference.

    ``interactive_query_mode`` controls what happens after each solution:
    ``"first"`` stops at the first one, ``"all"`` never stops on its own, and
    ``"interactive"`` asks on stdin whether to continue.
    """

    def __init__(self, interactive_query_mode: Literal["interactive", "first", "all"] = "first",
                 *, store_solutions: bool = False) -> None:
        # Filled only when ``store_solutions`` is True.
        self.collected_solutions: list[Mapping[Variable, Constant | CompoundTerm]] = []
        self.interactive_query_mode = interactive_query_mode
        self.store_solutions = store_solutions
        # Counts every solution seen, whether or not it was stored.
        self._solution_count: int = 0

    def reset(self) -> None:
        """Forget all collected solutions and reset the solution counter."""
        self.collected_solutions.clear()
        self._solution_count = 0

    def print_and_collect(
        self,
        question: Question,
        solutions: list[Mapping[Variable, Constant | CompoundTerm]],
        question_rule: _QuestionRule
    ) -> bool:
        """Log each solution, optionally store it, and decide whether to stop.

        :param question: the query being answered; its ``question`` items are used in the log line.
        :param solutions: variable-binding mappings found in this step; an empty mapping means
            the query holds without any binding.
        :param question_rule: internal rule of the query; if it has no ``free_variables``,
            a binding-free solution ends the search immediately.
        :return: True if inference should terminate, False to keep going.
        """
        question_str = ", ".join(str(q) for q in question.question)

        for combination in solutions:
            self._solution_count += 1

            if combination:
                binding_str = self._display_binding(combination)
                logger.info("Query solution: %s variable bindings: %s", question_str, binding_str)
            else:
                logger.info("Query solution: %s", question_str)

            if self.store_solutions:
                self.collected_solutions.append(combination)

            # A query with no free variables is fully answered by a single proof.
            if not combination and not question_rule.free_variables:
                return True

            # Let the configured mode decide whether to stop after this solution.
            if self._should_terminate(self.interactive_query_mode):
                return True

        return False  # keep inferring

    @staticmethod
    def _display_binding(combination: Mapping[Variable, Constant | CompoundTerm]) -> str:
        """Format bindings as ``name=value`` pairs, using each variable's ``display_name``
        (the variable name as written in the original query) rather than its internal name."""
        return ", ".join(f"{var.display_name}={term}" for var, term in combination.items())

    @staticmethod
    def _should_terminate(mode: Literal["interactive", "first", "all"]) -> bool:
        """Decide, after one solution, whether inference should stop.

        :param mode: ``"first"`` always stops and ``"all"`` never stops. Any other value
            (i.e. ``"interactive"``) prompts on stdin; entering ``;`` continues.
        :return: True to stop inference.
        """
        if mode == "first":
            logger.info("Configured to output only the first query solution; stopping inference.")
            return True

        if mode == "all":
            return False

        try:
            # TODO: a blocking input() prompt does not extend well to non-console front ends.
            user_input = input("发现解,输入 ';' 并回车继续本次推理;输入其他任意键并回车将终止本次推理: ").strip()
        except EOFError:
            # stdin is closed or exhausted (non-interactive run): stop instead of crashing.
            logger.info("No interactive input available (EOF); stopping inference.")
            return True
        if user_input != ';':
            logger.info("用户选择终止推理。")
            return True
        return False

    def get_all_solutions(self) -> list[Mapping[Variable, Constant | CompoundTerm]]:
        """Return a snapshot of the collected solutions.

        A copy is returned because :meth:`reset` clears the internal list in place;
        handing out the internal list would let a later reset silently empty a
        result the caller is still holding. Empty when ``store_solutions`` is False.
        """
        return list(self.collected_solutions)

    def get_solution_count(self) -> int:
        """Return how many solutions have been seen since the last reset (stored or not)."""
        return self._solution_count
137
+
138
+
139
class StatusChecker:
    """Status checks shared by the main-loop and executor managers."""

    def __init__(self, equivalence: Equivalence, sk_system_handler: SankuManagementSystem):
        # Stored for future checks; no method of this class reads them yet.
        self.equivalence = equivalence
        self.sk_system_handler = sk_system_handler

    def check_conflict(self, new_facts: list[FACT_TYPE]) -> bool:
        """Return whether a contradiction occurred.

        Note: the argument is the list of newly derived facts, not the whole fact base.
        """
        conflict_found = self._has_conflict_occurred(new_facts)
        return conflict_found

    @staticmethod
    def _has_conflict_occurred(new_facts: list[FACT_TYPE]) -> bool:
        """Conflict-detection hook; currently a stub that never reports a conflict."""
        # TODO: implement real contradiction detection.
        del new_facts  # unused until the TODO above is implemented
        return False
155
+
156
+
157
class MainLoopManager:
    """Tracks main-loop iterations and per-rule activity to decide when inference stops."""

    def __init__(self, status_checker: StatusChecker, max_iterations: int = 300):
        self.status_checker = status_checker
        self.max_iterations = max_iterations
        self._current_iteration = 0

        # rule -> True if that rule derived new facts in its latest run, else False.
        self.normal_rule_activated: dict[Rule, bool] = {}
        # Number of rules currently marked True above; -1 until initial_manager() runs.
        self._true_count = -1

    def check_status(self,
                     current_facts: list[FACT_TYPE],
                     question: Question) -> InferenceStatus:
        """Decide whether the main loop should keep going.

        Checks, in order: the iteration limit, the fixpoint condition, and (only on
        the first iteration) whether every queried fact is already present.
        """
        if self._current_iteration >= self.max_iterations:
            return InferenceStatus.MAX_ITERATIONS_REACHED

        if self.is_at_fixpoint():
            return InferenceStatus.FIXPOINT_REACHED

        on_first_iteration = self._current_iteration == 0
        # FIXME: this success test may need to be more fine-grained.
        if on_first_iteration and question.question and all(
                goal in current_facts for goal in question.question):
            return InferenceStatus.SUCCESS

        return InferenceStatus.CONTINUE

    def next_iteration(self) -> None:
        """Advance the iteration counter by one."""
        self._current_iteration += 1

    def reset(self) -> None:
        """Reset only the iteration counter; rule-activity tracking is left as is."""
        self._current_iteration = 0

    def initial_manager(self, normal_rules: Sequence[Rule] | None = None, *, resume: bool = False) -> None:
        """Prepare rule-activity tracking for inference on the current question.

        :param normal_rules: rules to track; each starts out marked active. Pass None together
            with ``resume=True`` to keep the previously tracked rules (e.g. when the question is
            unchanged and the engine is only restarted). They are then all marked active again,
            because facts may have changed in between.
        :param resume: reuse the existing rule set when ``normal_rules`` is None.
        :raise ValueError: if ``normal_rules`` is None and ``resume`` is False.
        """  # noqa: DOC501
        if normal_rules is not None:
            self.normal_rule_activated = dict.fromkeys(normal_rules, True)
        elif resume:
            self.normal_rule_activated = dict.fromkeys(self.normal_rule_activated, True)
        else:
            raise ValueError("normal_rules is None")
        self._true_count = len(self.normal_rule_activated)

    @property
    def iteration(self) -> int:
        """Current iteration number."""
        return self._current_iteration

    def update_normal_rule_activation(self, new_facts: list[FACT_TYPE], used_rule: Rule) -> None:
        """Record whether ``used_rule`` produced new facts and keep the active-rule count in sync.

        Called after each rule finishes; the count is what detects a simultaneous fixpoint.
        """
        was_active = bool(self.normal_rule_activated.get(used_rule))
        is_active = bool(new_facts)
        self.normal_rule_activated[used_rule] = is_active

        if was_active != is_active:
            self._true_count += 1 if is_active else -1

    def is_at_fixpoint(self) -> bool:
        """Return True when no normal rule produced new facts in its latest run."""
        return self._true_count == 0
228
+
229
+
230
class ExecutorManager:
    """Counts executor steps and decides whether the executor should stop."""

    def __init__(self, status_checker: StatusChecker, solution_manager: QuerySolutionManager,
                 max_steps: int = 1000):
        self.status_checker = status_checker
        self.solution_manager = solution_manager
        self.max_steps = max_steps  # -1 disables the step limit
        self._current_step = 1  # steps are counted from 1

    def check_status(self, new_facts: list[FACT_TYPE], question: Question,
                     solutions: list[Mapping[Variable, Constant | CompoundTerm]],
                     question_rule: _QuestionRule | None) -> InferenceStatus:
        """Decide the executor status after a step.

        Checks, in order: the step limit, conflicts among ``new_facts``, and finally
        hands any solutions to the solution manager, which may request termination.
        """
        step_limited = self.max_steps != -1
        if step_limited and self._current_step >= self.max_steps:
            # NOTE(review): solutions found in this final step are neither logged nor
            # collected on this path, although finding one arguably counts as success.
            return InferenceStatus.MAX_STEPS_REACHED

        if self.status_checker.check_conflict(new_facts):
            # TODO: also discard solutions when a conflict is detected.
            return InferenceStatus.CONFLICT_DETECTED

        if question_rule is None or not solutions:
            return InferenceStatus.CONTINUE

        should_stop = self.solution_manager.print_and_collect(question, solutions, question_rule)
        return InferenceStatus.SUCCESS if should_stop else InferenceStatus.CONTINUE

    def next_step(self) -> None:
        """Advance the step counter by one."""
        self._current_step += 1

    def reset_for_new_inference(self) -> None:
        """Restart step counting at 1; call only when a brand-new inference starts."""
        self._current_step = 1

    @property
    def step_num(self) -> int:
        """Current step number."""
        return self._current_step
270
+
271
+
272
def create_main_loop_manager(equivalence: Equivalence,
                             sk_system_handler: SankuManagementSystem,
                             max_iterations: int = 300) -> MainLoopManager:
    """Build a MainLoopManager backed by a fresh StatusChecker."""
    return MainLoopManager(
        StatusChecker(equivalence, sk_system_handler),
        max_iterations=max_iterations,
    )
278
+
279
+
280
def create_executor_manager(equivalence: Equivalence,
                            sk_system_handler: SankuManagementSystem,
                            solution_manager: QuerySolutionManager,
                            max_steps: int = 1000) -> ExecutorManager:
    """Build an ExecutorManager backed by a fresh StatusChecker and the given solution manager."""
    return ExecutorManager(
        StatusChecker(equivalence, sk_system_handler),
        solution_manager,
        max_steps=max_steps,
    )
kele/egg_equiv.pyd ADDED
Binary file
kele/egg_equiv.pyi ADDED
@@ -0,0 +1,11 @@
1
+ from kele.syntax import TERM_TYPE
2
+ from kele.syntax import Constant, CompoundTerm
3
+
4
class EggEquivalence:
    # Stub for the Rust (pyo3 + egg) class compiled into kele.egg_equiv.
    # Terms are converted by type name: Constant, Operator, CompoundTerm/FlatCompoundTerm.
    def __init__(self, trace: bool) -> None: ...
    def add_to_equiv(self, lhs: TERM_TYPE, rhs: TERM_TYPE) -> None: ...
    def query_equivalence(self, term_l: TERM_TYPE, term_r: TERM_TYPE) -> bool: ...
    # Canonical e-class id of ``term`` as a string (exported by the Rust module but previously missing here).
    def get_represent_id(self, term: Constant | CompoundTerm) -> str: ...
    def get_represent_elem(self, term: Constant | CompoundTerm) -> Constant | CompoundTerm[Constant | CompoundTerm]: ...
    def get_equiv_elem(self, term: TERM_TYPE) -> list[TERM_TYPE]: ...
    def rebuild_egraph(self) -> None: ...
    def clear(self) -> None: ...
@@ -0,0 +1,8 @@
1
+ ## 实现功能:等价类的查询
2
+ 在文件夹 equality 中实现了等价类的查询功能:
3
+ 1. kele.equality.equivalence.Equivalence 类,这个类共有两个对外的方法,分别为:
4
+ 1. update_equiv_class:这个方法用于更新等价类,调用时请传入类型为:list[Assertion | Formula]的列表。此方法没有返回值
5
+
6
+ 2. query_equivalence:这个方法用于查询一系列事实,调用时请传入Assertion或者list[Assertion]。此方法将返回一个列表,列表中为每个Assertion返回一个对应的list[bool]
7
+
8
+ 2. 在 test 文件夹中添加了少量单元测试,每次修改代码后运行 pytest 可以初步确认是否引入错误
@@ -0,0 +1,4 @@
1
+ """用于快速判断两个term是否等价的模块"""
2
+ from .equivalence import Equivalence
3
+
4
+ __all__ = ["Equivalence"]
@@ -0,0 +1,267 @@
1
+ use std::{collections::HashSet};
2
+
3
+ use egg::{define_language, EGraph, Id, Symbol};
4
+ use pyo3::{exceptions::*, prelude::*};
5
+ use std::collections::HashMap;
6
+ use std::hash::{Hash, Hasher};
7
+
8
// Node language of the e-graph.
// Operators and constants are leaves identified by a symbol; a CompoundTerm's
// children are e-class ids whose first entry must be the Operator node,
// followed by the argument ids (this is how add_term_to_egraph builds them).
define_language! {
    pub enum EngineElem {
        Operator(Symbol),
        Constant(Symbol),
        "CompoundTerm" = CompoundTerm(Box<[Id]>),
        // The first child of a CompoundTerm is required to be the Operator.
    }
}
16
+
17
+
18
// Wrapper that lets a Python object be stored in Rust hash collections.
// Equality delegates to Python `__eq__`; hashing uses the Python `__hash__`
// value, which is computed once in `PyObjKey::new` and cached.
struct PyObjKey {
    obj: Py<PyAny>, // GIL-independent handle; bind it to a `Python` token before use
    py_hash: isize, // cached result of Python `hash(obj)`
}
23
+
24
+ impl PyObjKey {
25
+ fn new(py: Python<'_>, obj: Py<PyAny>) -> PyResult<Self> {
26
+ let bound = obj.into_bound(py);
27
+ let h = bound.hash()?; // 调用 Python __hash__
28
+ Ok(Self {
29
+ obj: bound.unbind(), // 拿到可存储的 Py<PyAny>
30
+ py_hash: h,
31
+ })
32
+ }
33
+ }
34
+
35
+ impl Clone for PyObjKey {
36
+ fn clone(&self) -> Self {
37
+ Python::attach(|py| Self {
38
+ obj: self.obj.clone_ref(py),
39
+ py_hash: self.py_hash,
40
+ })
41
+ }
42
+ }
43
+
44
+ impl PartialEq for PyObjKey {
45
+ fn eq(&self, other: &Self) -> bool {
46
+ Python::attach(|py| {
47
+ let lhs = self.obj.bind(py);
48
+ let rhs = other.obj.bind(py);
49
+ match lhs.eq(&rhs) { // 调用 Python __eq__
50
+ Ok(b) => b,
51
+ Err(_) => false,
52
+ }
53
+ })
54
+ }
55
+ }
56
+ impl Eq for PyObjKey {}
57
+
58
+ impl Hash for PyObjKey {
59
+ fn hash<H: Hasher>(&self, state: &mut H) {
60
+ self.py_hash.hash(state); // 使用缓存的 Python hash
61
+ }
62
+ }
63
+
64
+
65
/// Python-visible equivalence store backed by an `egg` e-graph of kele terms.
#[pyclass]
pub struct EggEquivalence {
    egraph: EGraph<EngineElem, ()>,

    /// 1. Every Python term object (deduplicated by Python equality) recorded for each e-node.
    node_terms: HashMap<EngineElem, HashSet<PyObjKey>>,

    /// 2. The representative term of each e-node: the first Python object recorded for it.
    node_rep: HashMap<EngineElem, Py<PyAny>>,

    /// Whether e-graph explanations are enabled (applied when the e-graph is built or reset).
    trace: bool,
}
77
+
78
+ impl EggEquivalence {
79
+ pub fn new(trace: bool) -> Self {
80
+ let base: EGraph<EngineElem, ()> = EGraph::default();
81
+
82
+ // 按照 trace 决定是否启用 explanations
83
+ let egraph = if trace {
84
+ base.with_explanations_enabled()
85
+ } else {
86
+ base
87
+ };
88
+
89
+ Self {
90
+ egraph,
91
+ node_terms: HashMap::new(),
92
+ node_rep: HashMap::new(),
93
+ trace: trace,
94
+ }
95
+ }
96
+ /// 把某个 SyntaxType 的写法记到对应的 EngineElem 底下
97
+ /// 如果engineElem已经存在,那么就加入到对应的hashmap,否则创建新的hashmap并注册engineElem
98
+ fn record_term(&mut self, elem: &EngineElem, term: Bound<'_, PyAny>) {
99
+ // 所有写法都收集起来,给 get_equiv_elem 用
100
+ Python::attach(|py| {
101
+ self.node_terms
102
+ .entry(elem.clone())
103
+ .or_insert_with(HashSet::new)
104
+ .insert(PyObjKey::new(py, term.clone().unbind()).unwrap());
105
+ });
106
+
107
+ // 代表元:第一次见到这个 EngineElem 时绑定,不再修改。
108
+ // 之所以有这个问题是因为不同syntaxtype可能对应同一个engineelem,因此需要这样记录
109
+ self.node_rep
110
+ .entry(elem.clone())
111
+ .or_insert_with(|| term.unbind());
112
+ }
113
+ fn add_term_to_egraph(&mut self, py: Python<'_>, obj: Bound<'_, PyAny>) -> PyResult<Id> {
114
+ // 将一个用dict形式表示的term添加到egraph中,并获取它的Id。如果它已经存在,那么直接获取Id
115
+ let type_name = obj.get_type().name()?;
116
+
117
+ if type_name == "Constant"{
118
+ let symbol_obj = obj.getattr("symbol")?;
119
+ let value = symbol_obj.str()?.to_str()?.to_owned();
120
+ let belong_obj = obj.getattr("belong_concepts")?;
121
+ let belong_concepts = belong_obj.str()?.to_str()?.to_owned();
122
+ let constant_name = value + "::" + &belong_concepts; // TODO: 一些地方有内存优化的余地
123
+ let engine_elem = EngineElem::Constant(Symbol::from(constant_name.clone()));
124
+ if let Some(id) = self.egraph.lookup(engine_elem.clone()) {
125
+ self.record_term(&engine_elem, obj);
126
+ return Ok(id);
127
+ }
128
+ else{
129
+ self.record_term(&engine_elem, obj);
130
+ let id = self.egraph.add(engine_elem);
131
+ return Ok(id)
132
+ }
133
+ }
134
+ if type_name == "Operator"{
135
+ let name_obj = obj.getattr("name")?;
136
+ let value = name_obj.str()?.to_str()?.to_owned();
137
+ let operator_name = value.clone();
138
+ let engine_elem = EngineElem::Operator(Symbol::from(operator_name.clone()));
139
+ if let Some(id) = self.egraph.lookup(engine_elem.clone()) {
140
+ self.record_term(&engine_elem, obj);
141
+ return Ok(id);
142
+ }
143
+ else{
144
+ self.record_term(&engine_elem, obj);
145
+ let id = self.egraph.add(engine_elem);
146
+ return Ok(id)
147
+ }
148
+ }
149
+ if type_name == "CompoundTerm" || type_name == "FlatCompoundTerm"{
150
+ let operator_obj = obj.getattr("operator").unwrap();
151
+ let arguments_obj = obj.getattr("arguments").unwrap();
152
+ let operator_id = self.add_term_to_egraph(py, operator_obj)?;
153
+ let arguments_len = arguments_obj.len().unwrap();
154
+ let mut children_ids: Vec<Id> = Vec::new();
155
+ children_ids.push(operator_id); // CompoundTerm的第一个元素是Operator
156
+ for argument in 0..arguments_len{
157
+ let arg_obj = arguments_obj.get_item(argument).unwrap();
158
+ let arg_id = self.add_term_to_egraph(py, arg_obj)?;
159
+ children_ids.push(arg_id);
160
+ }
161
+ let engine_elem = EngineElem::CompoundTerm(children_ids.into_boxed_slice());
162
+ if let Some(id) = self.egraph.lookup(engine_elem.clone()) {
163
+ self.record_term(&engine_elem, obj);
164
+ return Ok(id);
165
+ }
166
+ else{
167
+ self.record_term(&engine_elem, obj);
168
+ let id = self.egraph.add(engine_elem);
169
+ return Ok(id)
170
+ }
171
+ }
172
+ return Err(PyErr::new::<PyTypeError, _>(format!("Unsupported type: {}", type_name)));
173
+ }
174
+ fn reset(&mut self) {
175
+ // 重新创建一个空的 egraph,并判断是否启用explanations
176
+ let base: EGraph<EngineElem, ()> = EGraph::default();
177
+
178
+ // 按照 trace 决定是否启用 explanations
179
+ self.egraph = if self.trace {
180
+ base.with_explanations_enabled()
181
+ } else {
182
+ base
183
+ };
184
+
185
+ // 清空辅助映射
186
+ self.node_terms.clear();
187
+ self.node_rep.clear();
188
+ }
189
+ }
190
+
191
+
192
+ #[pymethods]
193
+ impl EggEquivalence {
194
+ #[new]
195
+ fn py_new(trace: bool) -> Self {
196
+ EggEquivalence::new(trace)
197
+ }
198
+ pub fn add_to_equiv(&mut self, py: Python<'_>, lhs: Bound<'_, PyAny>, rhs: Bound<'_, PyAny>){
199
+ // 将lhs==rhs这个事实添加到egraph中
200
+ let lhs_id = self.add_term_to_egraph(py, lhs).unwrap();
201
+ let rhs_id = self.add_term_to_egraph(py, rhs).unwrap();
202
+
203
+ // 如果启用了 trace,则使用带标签的 union 以便后续生成解释
204
+ if self.trace {
205
+ self.egraph.union_trusted(lhs_id, rhs_id, "input");
206
+ } else {
207
+ self.egraph.union(lhs_id, rhs_id);
208
+ }
209
+ }
210
+ pub fn query_equivalence(&mut self, py: Python<'_>, term_l: Bound<'_, PyAny>, term_r: Bound<'_, PyAny>) -> bool {
211
+ // 查询term_l是否等于term_r
212
+ let lhs_id = self.add_term_to_egraph(py, term_l).unwrap();
213
+ let rhs_id = self.add_term_to_egraph(py, term_r).unwrap();
214
+
215
+ self.rebuild_egraph(); // 可能发生了add,需要重建
216
+
217
+ return self.egraph.find(lhs_id) == self.egraph.find(rhs_id);
218
+ }
219
+ pub fn get_represent_id(&mut self, py: Python<'_>, term: Py<PyAny>) -> String {
220
+ // 获取等价类的Id
221
+ let term_id = self.add_term_to_egraph(py, term.bind(py).clone()).unwrap(); // 将term加入图中从而获得Id
222
+ self.rebuild_egraph(); // 可能发生了add,需要重建
223
+ let class_id = self.egraph.find(term_id); // 获取e-class的ID
224
+
225
+ // 返回 e-class ID 的 index
226
+ class_id.to_string() // 返回 e-class ID 作为 usize
227
+ }
228
+ pub fn get_represent_elem(&mut self, py: Python<'_>, term: Py<PyAny>) -> Py<PyAny> {
229
+ // 获取等价类的代表元的ID,注意这是一个元素的Id
230
+ let term_id = self.add_term_to_egraph(py, term.bind(py).clone()).unwrap(); // 将term加入图中从而获得Id
231
+ self.rebuild_egraph(); // 可能发生了add,需要重建
232
+ let class_id = self.egraph.find(term_id);
233
+ let result_elem = self.egraph.id_to_node(class_id);
234
+ self.node_rep.get(result_elem).unwrap().clone_ref(py)
235
+ }
236
+ pub fn get_equiv_elem(&mut self, py: Python<'_>, term: Py<PyAny>) -> Vec<Py<PyAny>> {
237
+ // 获得与一个元素等价的所有元素,返回的是元素的ID列表
238
+ let term_id = self.add_term_to_egraph(py, term.bind(py).clone()).unwrap(); // 将term加入图中从而获得Id
239
+ self.rebuild_egraph(); // 可能发生了add,需要重建
240
+ let class_id = self.egraph.find(term_id); // 获得标准Id
241
+ let eclass = &self.egraph[class_id]; // 获得对应的eclass
242
+ let mut elem_list: Vec<Py<PyAny>> = Vec::new();
243
+ for node in &eclass.nodes {
244
+ if let Some(set) = self.node_terms.get(node) {
245
+ for term_key in set {
246
+ elem_list.push(term_key.obj.clone_ref(py));
247
+ }
248
+ }
249
+ }
250
+ elem_list
251
+ }
252
+ pub fn rebuild_egraph(&mut self) {
253
+ // 重建egraph,这个操作通常在进行查询之前调用,维护graph的同余闭包
254
+ if !self.egraph.clean {
255
+ self.egraph.rebuild();
256
+ }
257
+ }
258
+ pub fn clear(&mut self) {
259
+ self.reset();
260
+ }
261
+ }
262
+
263
/// Python extension module `egg_equiv`; registers the `EggEquivalence` class.
#[pymodule]
fn egg_equiv(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
    m.add_class::<EggEquivalence>()?;
    Ok(())
}
+ }
@@ -0,0 +1,67 @@
1
+ from kele.syntax.base_classes import Constant, CompoundTerm, Concept, Operator, TERM_TYPE, Variable
2
+ from functools import singledispatch
3
+
4
+
5
class EquivElem:
    """One member of an equivalence class: a wrapped term plus its precomputed hash."""

    def __init__(self, content: TERM_TYPE) -> None:
        self.content = content
        # Computed once with get_hash; __hash__ just returns this value.
        self._hash = get_hash(content)

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: object) -> bool:
        # Only another EquivElem can be equal; equality is that of the wrapped terms.
        return isinstance(other, EquivElem) and self.content == other.content
19
+
20
+
21
@singledispatch
def get_hash(content: Constant | CompoundTerm | Concept | Operator | Variable) -> int:
    """
    Return the hash of a syntax element; EquivElem uses it to cache its hash.

    Built on ``functools.singledispatch`` (a plain function, not
    ``singledispatchmethod``), so supporting a new element type only needs another
    ``@get_hash.register(...)`` overload. This fallback, used for unregistered
    types, is simply ``hash(content)``.
    """
    return hash(content)
27
+
28
+
29
@get_hash.register(Constant)
def _(content: Constant) -> int:
    """
    Hash a Constant by its ``symbol`` and ``belong_concepts`` attributes.

    Uses the built-in ``hash`` of those attributes directly; get_hash is not applied recursively.
    """
    # NOTE(review): requires belong_concepts to be hashable (e.g. tuple/frozenset) — confirm.
    return hash((content.symbol, content.belong_concepts))
36
+
37
+
38
@get_hash.register(CompoundTerm)
def _(content: CompoundTerm) -> int:
    """
    Hash a CompoundTerm by its ``operator`` together with its ``arguments``.

    Uses the built-in ``hash`` of those attributes, not get_hash recursively.
    """
    # NOTE(review): requires ``arguments`` to be hashable (presumably a tuple) — confirm.
    return hash((content.operator, content.arguments))
44
+
45
+
46
@get_hash.register(Concept)
def _(content: Concept) -> int:
    """Hash a Concept by its ``name`` alone."""
    concept_name = content.name
    return hash(concept_name)
52
+
53
+
54
@get_hash.register(Operator)
def _(content: Operator) -> int:
    """
    Hash an Operator by its ``name`` alone.

    Two different operators sharing a name are not supported for now, so the name is enough.
    """
    operator_name = content.name
    return hash(operator_name)
60
+
61
+
62
@get_hash.register(Variable)
def _(content: Variable) -> int:
    """
    Hash a Variable by its ``symbol`` alone.
    """
    return hash(content.symbol)
@@ -0,0 +1,36 @@
1
+ from kele.syntax.base_classes import Assertion
2
+ from ._equiv_elem import EquivElem
3
+ import warnings
4
+ from functools import singledispatch
5
+
6
+
7
@singledispatch
def fact_validator(item: Assertion | tuple[EquivElem, EquivElem]) -> bool:
    """
    Check that a fact used to update the equivalence classes is well formed.

    This fallback handles unsupported types: it emits a warning and rejects the item.
    The overloads registered for ``Assertion`` and ``tuple`` handle the supported
    forms, and further validation rules belong in those overloads.

    :param item: a fact, either an Assertion or a pair of EquivElem
    :return: True if the fact is valid, otherwise False
    """
    message = f"Should not update facts using {type(item)}."
    warnings.warn(message, stacklevel=5)
    return False
18
+
19
+
20
@fact_validator.register(Assertion)
def _(item: Assertion) -> bool:
    """
    Accept every Assertion for now.

    TODO: reject assertions with True/False on either side. Those values are illegal
    there and will appear as Constant/Concept, so they need a dedicated check.
    """
    return True
25
+
26
+
27
@fact_validator.register(tuple)
def _(item: tuple[EquivElem, EquivElem]) -> bool:
    """
    Accept a tuple fact only if it is a pair ``(lhs, rhs)``.

    singledispatch routes only tuples (and tuple subclasses) here, so only the
    length needs checking; the former ``isinstance(item, tuple)`` test was dead code.

    :param item: the candidate equivalence pair
    :return: True for a 2-tuple; otherwise warns and returns False
    """
    length_of_legal_assertion = 2
    if len(item) != length_of_legal_assertion:
        warnings.warn(f"Invalid equivalence relation format: {item}; expected a 2-tuple.", stacklevel=2)
        return False
    # TODO: also reject pairs with True/False on either side. Those values will
    # appear as Constant/Concept and need a dedicated check.
    return True
+ return True