kele 0.0.1a1__cp313-cp313-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. kele/__init__.py +38 -0
  2. kele/_version.py +1 -0
  3. kele/config.py +243 -0
  4. kele/control/README_metrics.md +102 -0
  5. kele/control/__init__.py +20 -0
  6. kele/control/callback.py +255 -0
  7. kele/control/grounding_selector/__init__.py +5 -0
  8. kele/control/grounding_selector/_rule_strategies/README.md +13 -0
  9. kele/control/grounding_selector/_rule_strategies/__init__.py +24 -0
  10. kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py +42 -0
  11. kele/control/grounding_selector/_rule_strategies/strategy_protocol.py +51 -0
  12. kele/control/grounding_selector/_selector_utils.py +123 -0
  13. kele/control/grounding_selector/_term_strategies/__init__.py +24 -0
  14. kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py +34 -0
  15. kele/control/grounding_selector/_term_strategies/strategy_protocol.py +50 -0
  16. kele/control/grounding_selector/rule_selector.py +98 -0
  17. kele/control/grounding_selector/term_selector.py +89 -0
  18. kele/control/infer_path.py +306 -0
  19. kele/control/metrics.py +357 -0
  20. kele/control/status.py +286 -0
  21. kele/egg_equiv.pyd +0 -0
  22. kele/egg_equiv.pyi +11 -0
  23. kele/equality/README.md +8 -0
  24. kele/equality/__init__.py +4 -0
  25. kele/equality/_egg_equiv/src/lib.rs +267 -0
  26. kele/equality/_equiv_elem.py +67 -0
  27. kele/equality/_utils.py +36 -0
  28. kele/equality/equivalence.py +141 -0
  29. kele/executer/__init__.py +4 -0
  30. kele/executer/executing.py +139 -0
  31. kele/grounder/README.md +83 -0
  32. kele/grounder/__init__.py +17 -0
  33. kele/grounder/grounded_rule_ds/__init__.py +6 -0
  34. kele/grounder/grounded_rule_ds/_nodes/__init__.py +24 -0
  35. kele/grounder/grounded_rule_ds/_nodes/_assertion.py +353 -0
  36. kele/grounder/grounded_rule_ds/_nodes/_conn.py +116 -0
  37. kele/grounder/grounded_rule_ds/_nodes/_op.py +57 -0
  38. kele/grounder/grounded_rule_ds/_nodes/_root.py +71 -0
  39. kele/grounder/grounded_rule_ds/_nodes/_rule.py +119 -0
  40. kele/grounder/grounded_rule_ds/_nodes/_term.py +390 -0
  41. kele/grounder/grounded_rule_ds/_nodes/_tftable.py +15 -0
  42. kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +444 -0
  43. kele/grounder/grounded_rule_ds/_nodes/_typing_polars.py +26 -0
  44. kele/grounder/grounded_rule_ds/grounded_class.py +461 -0
  45. kele/grounder/grounded_rule_ds/grounded_ds_utils.py +91 -0
  46. kele/grounder/grounded_rule_ds/rule_check.py +373 -0
  47. kele/grounder/grounding.py +118 -0
  48. kele/knowledge_bases/README.md +112 -0
  49. kele/knowledge_bases/__init__.py +6 -0
  50. kele/knowledge_bases/builtin_base/__init__.py +1 -0
  51. kele/knowledge_bases/builtin_base/builtin_concepts.py +13 -0
  52. kele/knowledge_bases/builtin_base/builtin_facts.py +43 -0
  53. kele/knowledge_bases/builtin_base/builtin_operators.py +105 -0
  54. kele/knowledge_bases/builtin_base/builtin_rules.py +14 -0
  55. kele/knowledge_bases/fact_base.py +158 -0
  56. kele/knowledge_bases/ontology_base.py +67 -0
  57. kele/knowledge_bases/rule_base.py +194 -0
  58. kele/main.py +464 -0
  59. kele/py.typed +0 -0
  60. kele/syntax/CONCEPT_README.md +117 -0
  61. kele/syntax/__init__.py +40 -0
  62. kele/syntax/_cnf_converter.py +161 -0
  63. kele/syntax/_sat_solver.py +116 -0
  64. kele/syntax/base_classes.py +1482 -0
  65. kele/syntax/connectives.py +20 -0
  66. kele/syntax/dnf_converter.py +145 -0
  67. kele/syntax/external.py +17 -0
  68. kele/syntax/sub_concept.py +87 -0
  69. kele/syntax/syntacticsugar.py +201 -0
  70. kele-0.0.1a1.dist-info/METADATA +166 -0
  71. kele-0.0.1a1.dist-info/RECORD +74 -0
  72. kele-0.0.1a1.dist-info/WHEEL +4 -0
  73. kele-0.0.1a1.dist-info/licenses/LICENSE +28 -0
  74. kele-0.0.1a1.dist-info/licenses/licensecheck.json +20 -0
@@ -0,0 +1,24 @@
1
+ # 导入所有rule strategies
2
+ import importlib
3
+ import logging
4
+ import pathlib
5
+
6
+ from .strategy_protocol import get_strategy_class
7
+
8
current_dir = pathlib.Path(__file__).resolve().parent
package_name = __package__ or current_dir.name

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


def _import_strategy_modules() -> None:
    """Import every private ``_*_strategy.py`` module that lives next to this file.

    Importing a strategy module executes its ``@register_strategy`` decorators, which is what
    makes the strategy reachable through ``get_strategy_class``. Files are visited in sorted
    order so registration order does not depend on the directory listing order.
    """
    for path in sorted(current_dir.iterdir()):
        stem = path.stem
        if path.suffix != '.py' or not (stem.startswith('_') and stem.endswith('_strategy')):
            continue
        try:
            importlib.import_module(f'{package_name}.{stem}')
        except ImportError:
            # A broken strategy module must not prevent the remaining ones from loading.
            logger.exception('Failed to import %s', stem)
        else:
            # Only report success once the import actually succeeded.
            logger.info('successfully imported module: "%s"', stem)


_import_strategy_modules()

__all__ = ["get_strategy_class"]
@@ -0,0 +1,42 @@
1
+ from collections.abc import Sequence
2
+
3
+ from .strategy_protocol import Feedback, register_strategy, RuleSelectionStrategy
4
+ from kele.syntax import Rule
5
+
6
+
7
@register_strategy('SequentialCyclic')
class SequentialCyclicStrategy(RuleSelectionStrategy):
    """
    Round-robin strategy: r0, r1, ..., rN-1, r0, r1, ...

    Each ``select_next`` call returns exactly one rule, and the cursor wraps around to the
    first rule after the last one.
    """

    def __init__(self) -> None:
        # Start with an empty pool so select_next() before set_rules() raises a clear error
        # instead of an AttributeError about a missing attribute.
        self._rules: list[Rule] = []
        self._idx: int = 0

    def set_rules(self, rules: Sequence[Rule]) -> None:
        """
        Replace the rule pool and restart the cycle from the first rule.

        :raise: ValueError: if ``rules`` is empty; the previously installed rules are kept.
        """  # noqa: DOC501
        new_rules = list(rules)
        if not new_rules:
            raise ValueError("rules cannot be empty")
        # Commit only after validation so a rejected call cannot leave a half-updated state.
        self._rules = new_rules
        self._idx = 0

    def reset(self) -> None:
        """Restart the cycle from the first rule (the rule pool itself is kept)."""
        self._idx = 0

    def select_next(self) -> Sequence[Rule]:
        """
        Return the next rule in cyclic order, wrapped in a one-element list.

        :raise: RuntimeError: if ``set_rules`` has not been called yet.
        """  # noqa: DOC501
        if not self._rules:
            raise RuntimeError("set_rules must be called before select_next")
        rule = self._rules[self._idx]
        self._idx = (self._idx + 1) % len(self._rules)
        return [rule]

    def on_feedback(self, feedback: Feedback) -> None:  # noqa: PLR6301 # not implemented yet, keep as instance method
        # The cyclic strategy does not use feedback.
        return
34
+
35
+
36
@register_strategy("SequentialCyclicWithPriority")
class SequentialCyclicWithPriorityStrategy(SequentialCyclicStrategy):
    """
    Round-robin strategy that visits rules in descending ``priority`` order.

    ``sorted`` is stable even with ``reverse=True``, so rules with equal priority keep
    their original relative order.
    """

    def set_rules(self, rules: Sequence[Rule]) -> None:
        by_priority = sorted(rules, key=lambda rule: rule.priority, reverse=True)
        super().set_rules(by_priority)
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Protocol, runtime_checkable, TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from collections.abc import Sequence
7
+ from kele.syntax import Rule
8
+ from collections.abc import Callable
9
+
10
+
11
@runtime_checkable
class RuleSelectionStrategy(Protocol):
    """
    Common interface for rule-selection strategies.

    An implementation is free to return whatever rules it wants from ``select_next``.
    """

    def __init__(self) -> None:
        """Strategies are constructed without arguments."""

    def set_rules(self, rules: Sequence[Rule]) -> None:
        """Install the pool of rules to select from."""

    def reset(self) -> None:
        """Reset the selection state."""

    def select_next(self) -> Sequence[Rule]:
        """Return the rule(s) to ground next."""

    def on_feedback(self, feedback: Feedback) -> None:
        """Receive feedback about the most recent selection."""
21
+
22
+
23
@dataclass
class Feedback:
    """Optional feedback after one selection; every field has a default and strategies use only what they need."""
    rule: Rule | None = None  # Rule this feedback refers to. HACK: more related info (e.g. facts) may be added later.
27
+
28
+
29
# Global table: strategy name -> rule-selection strategy class, filled in by @register_strategy.
rule_strategy_registry: dict[str, type[RuleSelectionStrategy]] = {}
30
+
31
+
32
def register_strategy(name: str) -> Callable[[type[RuleSelectionStrategy]], type[RuleSelectionStrategy]]:
    """
    Build a class decorator that stores the decorated strategy in ``rule_strategy_registry`` under *name*.

    The class is returned unchanged. Registering a second class under the same name replaces the first.
    """
    def _record(strategy_cls: type[RuleSelectionStrategy]) -> type[RuleSelectionStrategy]:
        rule_strategy_registry[name] = strategy_cls
        return strategy_cls

    return _record
38
+
39
+
40
def get_strategy_class(name: str) -> type[RuleSelectionStrategy]:
    """
    Return the rule-selection strategy class registered under *name*.

    :raise: KeyError: if *name* is not registered; the message lists the registered names.
    """  # noqa: DOC501
    try:
        strategy_cls = rule_strategy_registry[name]
    except KeyError as missing:
        registered = list(rule_strategy_registry.keys())
        message = f"Rule selection strategy '{name}' is not registered. Available strategies: {registered}"
        raise KeyError(message) from missing
    return strategy_cls
48
+
49
+
50
def _available_strategies() -> list[str]:
    """Names of all registered rule-selection strategies, in registration order."""
    return [*rule_strategy_registry]
@@ -0,0 +1,123 @@
1
+ import itertools
2
+ import warnings
3
+ from collections.abc import Sequence
4
+
5
+ from kele.knowledge_bases.builtin_base.builtin_concepts import FREEVARANY_CONCEPT
6
+ from kele.syntax import (FACT_TYPE, TERM_TYPE, Assertion, Formula, CompoundTerm,
7
+ Constant, Variable, ATOM_TYPE, Rule)
8
+ from functools import singledispatch
9
+
10
+
11
@singledispatch
def _unify_all_terms(fact: FACT_TYPE | TERM_TYPE) -> tuple[CompoundTerm | Constant, ...]:
    """
    Decompose a fact (or term) into every ground term it contains.

    - Formulas are split into their operands.
    - Assertions are split into their two sides.
    - Compound terms expand to themselves plus every nested compound sub-term (via ``_split_all_terms``).
    - Constants are returned as-is.
    - Variables yield nothing and emit a warning, since facts are expected to be variable-free.
    - Any unregistered type falls back to this base implementation and yields ``()``.
    """
    return ()


@_unify_all_terms.register(Assertion)
def _(fact: Assertion) -> tuple[CompoundTerm | Constant, ...]:
    # Left-hand side terms first, then right-hand side terms.
    return _unify_all_terms(fact.lhs) + _unify_all_terms(fact.rhs)


@_unify_all_terms.register(Formula)
def _(fact: Formula) -> tuple[CompoundTerm | Constant, ...]:
    tuple_left = _unify_all_terms(fact.formula_left)
    # formula_right is None for unary formulas.
    tuple_right = _unify_all_terms(fact.formula_right) if fact.formula_right is not None else ()
    # NOTE: right-operand terms come before left-operand terms (order kept from the original implementation).
    return tuple_right + tuple_left


@_unify_all_terms.register(CompoundTerm)
def _(fact: CompoundTerm) -> tuple[CompoundTerm | Constant, ...]:
    return tuple(_split_all_terms(fact))


@_unify_all_terms.register(Constant)
def _(fact: Constant) -> tuple[CompoundTerm | Constant, ...]:
    return (fact, )


@_unify_all_terms.register(Variable)
def _(fact: Variable) -> tuple[CompoundTerm | Constant, ...]:
    # stacklevel=3 skips this overload and the singledispatch wrapper frame in functools, so the
    # warning is attributed to the code that called _unify_all_terms instead of to functools.py.
    warnings.warn("Variable should not exist in fact", stacklevel=3)
    return ()
46
+
47
+
48
class FREEVARANY(Constant):
    """
    Wildcard ("ANY") placeholder for free variables, modelled as a special Constant.

    This engine grounds at the flat-term level: terms in rules and facts are broken down into
    flat terms before matching. A nested term is therefore represented by several flat terms,
    and every Term-valued argument inside a nested term is replaced by this wildcard.
    """

    def __init__(self, value: str) -> None:
        # The concept is fixed; only the display value varies.
        super().__init__(value, FREEVARANY_CONCEPT)
58
+
59
+
60
# Shared wildcard instance that flatten_arguments substitutes for compound-term arguments.
FREEANY = FREEVARANY('FREEVARANY')
61
+
62
+
63
def _split_all_terms(term: CompoundTerm) -> list[CompoundTerm]:
    """
    Collect ``term`` and every compound sub-term nested in its arguments, in pre-order.

    The result starts with ``term`` itself. It is followed by the decomposition of each compound
    argument, in argument order. Non-compound arguments are skipped, and Variables get no special
    handling, so this is meant for facts. Constants are handled separately by ``_unify_all_terms``.
    """
    collected: list[CompoundTerm] = []
    pending: list[CompoundTerm] = [term]
    while pending:
        current = pending.pop()
        collected.append(current)
        children = [arg for arg in current.arguments if isinstance(arg, CompoundTerm)]
        # Push in reverse so children are popped, and therefore emitted, left to right.
        pending.extend(reversed(children))
    return collected
77
+
78
+
79
def flatten_arguments(arguments: Sequence[TERM_TYPE]) -> tuple[ATOM_TYPE, ...]:  # public for now
    # NOTE: the singledispatch helpers in this module could become plain if/elif dispatch for readability.
    """
    Replace every compound-term argument with the ``FREEANY`` wildcard ($F).

    Other arguments are kept unchanged and in place. Usable for both facts and rules, because
    variables are simply passed through.
    """
    def _wildcard_if_compound(arg: TERM_TYPE) -> ATOM_TYPE:
        return FREEANY if isinstance(arg, CompoundTerm) else arg

    return tuple(map(_wildcard_if_compound, arguments))
89
+
90
+
91
@singledispatch
def _unify_into_terms(fact: FACT_TYPE | TERM_TYPE) -> tuple[TERM_TYPE, ...]:
    """
    Decompose a fact into its top-level terms, without splitting nested compound terms.

    - Formulas are split into their operands.
    - Assertions become ``(lhs, rhs)``.
    - A bare term becomes a one-element tuple.
    - Unregistered types yield ``()``.

    Unlike ``_unify_all_terms``, nested sub-terms are left intact.
    """
    return ()


@_unify_into_terms.register(Assertion)
def _(fact: Assertion) -> tuple[TERM_TYPE, ...]:
    return (fact.lhs, fact.rhs)


@_unify_into_terms.register(Formula)
def _(fact: Formula) -> tuple[TERM_TYPE, ...]:
    # Recurse with _unify_into_terms (previously _unify_all_terms). Assertions inside a formula are
    # then decomposed exactly like a bare assertion, into (lhs, rhs), without flat-term splitting.
    tuple_left = _unify_into_terms(fact.formula_left)
    tuple_right = _unify_into_terms(fact.formula_right) if fact.formula_right is not None else ()
    # Right-operand terms still precede left-operand terms, as before.
    return tuple_right + tuple_left


@_unify_into_terms.register(TERM_TYPE)
def _(fact: TERM_TYPE) -> tuple[TERM_TYPE, ...]:
    return (fact, )
115
+
116
+
117
def _unify_ground_terms_from_rule(rule: Rule) -> tuple[TERM_TYPE, ...]:
    """Return the variable-free terms of ``rule``: head terms first, then body terms."""
    head_terms = _unify_all_terms(rule.head)
    body_terms = _unify_all_terms(rule.body)
    return tuple(t for t in (*head_terms, *body_terms) if not t.free_variables)
120
+
121
+
122
def _unify_ground_terms_from_rules(rules: Sequence[Rule]) -> tuple[TERM_TYPE, ...]:
    """Concatenate the variable-free terms of every rule in ``rules``, preserving rule order."""
    collected: list[TERM_TYPE] = []
    for rule in rules:
        collected.extend(_unify_ground_terms_from_rule(rule))
    return tuple(collected)
@@ -0,0 +1,24 @@
1
+ # 导入所有term strategies
2
+ import importlib
3
+ import logging
4
+ import pathlib
5
+
6
+ from .strategy_protocol import get_strategy_class
7
+
8
current_dir = pathlib.Path(__file__).resolve().parent
package_name = __package__ or current_dir.name

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


def _import_strategy_modules() -> None:
    """Import every private ``_*_strategy.py`` module that lives next to this file.

    Importing a strategy module executes its ``@register_strategy`` decorators, which is what
    makes the strategy reachable through ``get_strategy_class``. Files are visited in sorted
    order so registration order does not depend on the directory listing order.
    """
    for path in sorted(current_dir.iterdir()):
        stem = path.stem
        if path.suffix != '.py' or not (stem.startswith('_') and stem.endswith('_strategy')):
            continue
        try:
            importlib.import_module(f'{package_name}.{stem}')
        except ImportError:
            # A broken strategy module must not prevent the remaining ones from loading.
            logger.exception('Failed to import %s', stem)
        else:
            # Only report success once the import actually succeeded.
            logger.info('successfully imported module: "%s"', stem)


_import_strategy_modules()

__all__ = ["get_strategy_class"]
@@ -0,0 +1,34 @@
1
+ from collections import defaultdict
2
+ from collections.abc import Sequence
3
+
4
+ from .strategy_protocol import Feedback, register_strategy, TermSelectionStrategy
5
+ from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, Rule, TERM_TYPE
6
+
7
+
8
@register_strategy('Exhausted')
class ExhuastedStrategy(TermSelectionStrategy):
    """
    Exhaustive selection: each call returns every term added since the rule's previous call.

    All terms live in one append-only list shared by every rule. Each rule only remembers how far
    into that list it has consumed. Duplicates are not removed here.
    """

    def __init__(self) -> None:
        # Candidate terms are valid for every rule; rules just consume them at their own pace.
        # TODO: a de-duplication scheme is still needed.
        self._terms: list[GROUNDED_TYPE_FOR_UNIFICATION] = []
        # rule -> index in self._terms of the first term this rule has not received yet (0 for new rules).
        self._rules_used_id: dict[Rule, int] = defaultdict(int)

    def add_terms(self, terms: Sequence[TERM_TYPE]) -> None:
        """Append new candidate terms; every rule receives them on its next ``select_next`` call."""
        self._terms.extend(terms)

    def reset(self) -> None:  # HACK: not used yet
        """Drop all terms and forget every rule's progress."""
        self._terms.clear()
        self._rules_used_id = defaultdict(int)

    def select_next(self, rule: Rule) -> Sequence[GROUNDED_TYPE_FOR_UNIFICATION]:
        """Return the terms added since ``rule`` was last served, and mark them as consumed for it."""
        start_id = self._rules_used_id[rule]
        selected_terms = self._terms[start_id:]  # a slice is already a fresh list
        self._rules_used_id[rule] = len(self._terms)
        return selected_terms

    def on_feedback(self, feedback: Feedback) -> None:  # noqa: PLR6301 # not implemented yet, keep as instance method
        # The exhaustive strategy does not use feedback.
        return
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Protocol, runtime_checkable, TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from collections.abc import Sequence
7
+ from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, Rule, TERM_TYPE
8
+ from collections.abc import Callable
9
+
10
+
11
@runtime_checkable
class TermSelectionStrategy(Protocol):
    """
    Common interface for term-selection strategies.

    Given a rule, a strategy returns the candidate terms to ground it with. It is free to return
    any terms it wants.
    """

    def __init__(self) -> None:
        """Strategies are constructed without arguments."""

    def add_terms(self, terms: Sequence[TERM_TYPE]) -> None:
        """Add new candidate terms to the pool."""

    def reset(self) -> None:
        """Reset the selection state."""

    def select_next(self, rule: Rule) -> Sequence[GROUNDED_TYPE_FOR_UNIFICATION]:
        """Return the candidate terms for grounding ``rule``."""

    def on_feedback(self, feedback: Feedback) -> None:
        """Receive feedback about the most recent selection."""
21
+
22
+
23
@dataclass
class Feedback:
    """Optional feedback after one selection; fields would all be optional and used by strategies as needed. No fields are defined yet."""
26
+
27
+
28
# Global table: strategy name -> term-selection strategy class, filled in by @register_strategy.
term_strategy_registry: dict[str, type[TermSelectionStrategy]] = {}
29
+
30
+
31
def register_strategy(name: str) -> Callable[[type[TermSelectionStrategy]], type[TermSelectionStrategy]]:
    """
    Build a class decorator that stores the decorated strategy in ``term_strategy_registry`` under *name*.

    The class is returned unchanged. Registering a second class under the same name replaces the first.
    """
    def _record(strategy_cls: type[TermSelectionStrategy]) -> type[TermSelectionStrategy]:
        term_strategy_registry[name] = strategy_cls
        return strategy_cls

    return _record
37
+
38
+
39
def get_strategy_class(name: str) -> type[TermSelectionStrategy]:
    """
    Return the term-selection strategy class registered under *name*.

    :raise: KeyError: if *name* is not registered; the message lists the registered names.
    """  # noqa: DOC501
    try:
        strategy_cls = term_strategy_registry[name]
    except KeyError as missing:
        registered = list(term_strategy_registry.keys())
        message = f"Term selection strategy '{name}' is not registered. Available strategies: {registered}"
        raise KeyError(message) from missing
    return strategy_cls
47
+
48
+
49
def _available_strategies() -> list[str]:
    """Names of all registered term-selection strategies, in registration order."""
    return [*term_strategy_registry]
@@ -0,0 +1,98 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from typing import TYPE_CHECKING
5
+ from ._rule_strategies import get_strategy_class
6
+
7
+ if TYPE_CHECKING:
8
+ from ._rule_strategies.strategy_protocol import RuleSelectionStrategy, Feedback
9
+ from collections.abc import Sequence
10
+ from kele.syntax import Rule, _QuestionRule
11
+
12
+
13
class GroundingRuleSelector:
    """
    Public entry point for choosing which rules to ground next; the choice is delegated to a strategy.

    - The strategy can be swapped at runtime (the interface does not require "take a contiguous run").
    - The rule set can be updated.
    - Question rules are returned in place of normal rules once at least ``interval`` normal rules have
      been handed out since the last time, and on every call after a fixpoint has been reported.
    """

    # snake_case spellings mapped to the names the strategies register under. The constructor's
    # default "sequential_cyclic" is not itself a registered name (the strategy registers as
    # "SequentialCyclic"), so without this mapping GroundingRuleSelector() raised KeyError.
    _STRATEGY_ALIASES = {  # noqa: RUF012
        "sequential_cyclic": "SequentialCyclic",
        "sequential_cyclic_with_priority": "SequentialCyclicWithPriority",
    }

    def __init__(self, strategy: str = "sequential_cyclic", question_rule_interval: int = -1) -> None:
        """
        :param strategy: registered strategy name; the snake_case aliases above are also accepted.
        :param question_rule_interval: number of normal rules to hand out between two rounds of question
            rules, or -1 to use the number of normal rules. Validated lazily in ``next_rules``.
        """
        strategy_cls = get_strategy_class(self._STRATEGY_ALIASES.get(strategy, strategy))
        self._strategy = strategy_cls()
        self._normal_rules: list[Rule] | None = None

        self.question_rule_interval = question_rule_interval
        # Normal rules handed out since the question rules were last returned.
        self.used_rule_cnt: int = 0
        self._question_rules: list[Rule] = []
        self._at_fixpoint: bool = False

    def set_at_fixpoint(self, *, at_fixpoint: bool) -> None:
        """Record whether inference has reached a fixpoint."""
        self._at_fixpoint = at_fixpoint

    def next_rules(self) -> Sequence[Rule]:
        """
        Choose the next batch of rules to ground.

        At a fixpoint only the question rules are returned, or ``[]`` if there are none. Otherwise
        the question rules replace the normal selection once ``used_rule_cnt`` reaches the interval,
        so the caller can periodically check whether the question has been answered.

        :raises ValueError: when question_rule_interval is less than 1 and not -1.
        """  # noqa: DOC501
        if self._at_fixpoint:
            if self._question_rules:
                self.used_rule_cnt = 0
                return self._question_rules
            return []

        # Determine the effective check interval.
        if self.question_rule_interval < 1 and self.question_rule_interval != -1:
            raise ValueError(
                "question_rule_interval must be >= 1, or -1 to use the total count of normal rules as the interval"
            )

        interval = self.question_rule_interval
        if interval == -1:
            normal_count = len(self._normal_rules) if self._normal_rules is not None else 0
            interval = normal_count or 1  # never allow an interval of 0

        if self.used_rule_cnt >= interval and self._question_rules:
            self.used_rule_cnt = 0  # start counting the next round
            return self._question_rules

        rules = self._strategy.select_next()
        self.used_rule_cnt += len(rules)

        return rules

    def reset(self) -> None:
        """Reset the selector and its strategy; rules must be set again afterwards."""
        self._strategy.reset()
        self._normal_rules = None
        self._question_rules = []
        self.used_rule_cnt = 0
        self._at_fixpoint = False

    def set_strategy(self, strategy: RuleSelectionStrategy) -> None:
        """Swap in a strategy instance; its internal state is left as-is, the strategy decides what to keep."""
        self._strategy = strategy

        if self._normal_rules is not None:
            self._strategy.set_rules(self._normal_rules)
        else:
            warnings.warn("No given rules, please call set_rules after calling set_strategy.", stacklevel=2)

    def set_rules(self, normal_rules: Sequence[Rule], question_rules: Sequence[_QuestionRule]) -> None:
        """
        Replace the rule sets and pass the normal rules on to the current strategy.

        :raise: ValueError: if ``normal_rules`` is empty.
        """  # noqa: DOC501
        if not normal_rules:
            raise ValueError("rules cannot be empty")

        self._normal_rules = list(normal_rules)
        self._question_rules = list(question_rules)

        self._strategy.set_rules(self._normal_rules)

        # Restart the question-rule countdown.
        self.used_rule_cnt = 0

    def send_feedback(self, feedback: Feedback) -> None:
        """Forward feedback about one selection to the current strategy."""
        self._strategy.on_feedback(feedback)
@@ -0,0 +1,89 @@
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ from typing import TYPE_CHECKING
5
+ from ._term_strategies import get_strategy_class
6
+ from kele.control.grounding_selector._selector_utils import (
7
+ _unify_ground_terms_from_rules,
8
+ _unify_into_terms,
9
+ )
10
+
11
+ if TYPE_CHECKING:
12
+ from kele.config import Config
13
+ from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, Rule
14
+ from kele.equality import Equivalence
15
+ from ._term_strategies.strategy_protocol import Feedback
16
+ from collections.abc import Sequence
17
+ from kele.syntax import FACT_TYPE
18
+
19
+
20
class GroundingFlatTermWithWildCardSelector:
    # Grounding happens at the FlatTerm level here, so the strategy receives more than plain Terms.
    # The strategies themselves suit both Terms and FlatTerms, so their contract stays TERM_TYPE.
    # Wildcards and flat terms cannot currently be separated, hence "WildCard" in the class name.
    """
    Public entry point for choosing grounding terms; the choice is delegated to a strategy.
    """

    def __init__(self,
                 equivalence: Equivalence,
                 args: Config) -> None:

        strategy_cls = get_strategy_class(args.strategy.grounding_term_strategy)
        self._strategy = strategy_cls()
        self._equivalence = equivalence
        self._args = args

    def next_terms(self, rule: Rule) -> list[GROUNDED_TYPE_FOR_UNIFICATION]:
        """
        Select candidate terms for grounding ``rule``.

        The strategy's selection is de-duplicated in first-occurrence order, so the result order
        does not depend on hash values. With ``run.semi_eval_with_equality`` enabled, every term
        returned by ``equivalence.get_related_item`` for a selected term is added too. The selected
        term itself is always kept, even if ``get_related_item`` does not return it.
        """
        init_terms = self._strategy.select_next(rule)

        if not self._args.run.semi_eval_with_equality:
            return list(dict.fromkeys(init_terms))

        # dict used as an insertion-ordered set.
        selected_terms: dict[GROUNDED_TYPE_FOR_UNIFICATION, None] = {}
        for t in init_terms:
            if t in selected_terms:
                # Already pulled in together with an earlier term's related items.
                continue
            selected_terms[t] = None
            # TODO: consider moving _unify_into_flat_terms into the term selector.
            for related in self._equivalence.get_related_item(t):
                selected_terms.setdefault(related, None)

        # Semi-evaluation includes the rest of each equivalence class. This follows jointly from the
        # built-in, mandatory equality axioms of assertion logic and from term-level grounding, which
        # is why it is selector behavior rather than part of a specific strategy.
        return list(selected_terms)

    def reset(self) -> None:
        """Reset the selector's strategy."""
        self._strategy.reset()

    def update_terms(self,
                     terms: Sequence[GROUNDED_TYPE_FOR_UNIFICATION] | None = None,
                     facts: Sequence[FACT_TYPE] | None = None) -> None:
        """
        Add candidate terms to the current strategy.

        A non-empty ``terms`` is used as-is and ``facts`` is then ignored. Otherwise each fact is
        decomposed into its top-level terms with ``_unify_into_terms``.

        :raise: ValueError: if neither ``terms`` nor ``facts`` is non-empty.
        """  # noqa: DOC501

        if terms:
            fact_terms = terms
        elif facts:
            # FIXME: strictly speaking, only Assertions can trigger the equality axioms.
            fact_terms = list(itertools.chain.from_iterable(_unify_into_terms(f) for f in facts))
        else:
            raise ValueError("terms or facts cannot be empty")

        # Terms with action operators are not expected here, so no filter such as
        # fact_terms = [f for f in fact_terms if f.operator.implement_func is not None] is applied.
        self._strategy.add_terms(fact_terms)

    def update_terms_from_rules(self, rules: Sequence[Rule]) -> None:
        """
        Add the variable-free terms of the given rules/questions to the candidate pool.
        """
        rule_terms = _unify_ground_terms_from_rules(rules)
        if not rule_terms:
            return

        self._strategy.add_terms(rule_terms)

    def send_feedback(self, feedback: Feedback) -> None:
        """Forward feedback about one selection to the current strategy."""
        self._strategy.on_feedback(feedback)