kele 0.0.1a1__cp313-cp313-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kele/__init__.py +38 -0
- kele/_version.py +1 -0
- kele/config.py +243 -0
- kele/control/README_metrics.md +102 -0
- kele/control/__init__.py +20 -0
- kele/control/callback.py +255 -0
- kele/control/grounding_selector/__init__.py +5 -0
- kele/control/grounding_selector/_rule_strategies/README.md +13 -0
- kele/control/grounding_selector/_rule_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py +42 -0
- kele/control/grounding_selector/_rule_strategies/strategy_protocol.py +51 -0
- kele/control/grounding_selector/_selector_utils.py +123 -0
- kele/control/grounding_selector/_term_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py +34 -0
- kele/control/grounding_selector/_term_strategies/strategy_protocol.py +50 -0
- kele/control/grounding_selector/rule_selector.py +98 -0
- kele/control/grounding_selector/term_selector.py +89 -0
- kele/control/infer_path.py +306 -0
- kele/control/metrics.py +357 -0
- kele/control/status.py +286 -0
- kele/egg_equiv.pyd +0 -0
- kele/egg_equiv.pyi +11 -0
- kele/equality/README.md +8 -0
- kele/equality/__init__.py +4 -0
- kele/equality/_egg_equiv/src/lib.rs +267 -0
- kele/equality/_equiv_elem.py +67 -0
- kele/equality/_utils.py +36 -0
- kele/equality/equivalence.py +141 -0
- kele/executer/__init__.py +4 -0
- kele/executer/executing.py +139 -0
- kele/grounder/README.md +83 -0
- kele/grounder/__init__.py +17 -0
- kele/grounder/grounded_rule_ds/__init__.py +6 -0
- kele/grounder/grounded_rule_ds/_nodes/__init__.py +24 -0
- kele/grounder/grounded_rule_ds/_nodes/_assertion.py +353 -0
- kele/grounder/grounded_rule_ds/_nodes/_conn.py +116 -0
- kele/grounder/grounded_rule_ds/_nodes/_op.py +57 -0
- kele/grounder/grounded_rule_ds/_nodes/_root.py +71 -0
- kele/grounder/grounded_rule_ds/_nodes/_rule.py +119 -0
- kele/grounder/grounded_rule_ds/_nodes/_term.py +390 -0
- kele/grounder/grounded_rule_ds/_nodes/_tftable.py +15 -0
- kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +444 -0
- kele/grounder/grounded_rule_ds/_nodes/_typing_polars.py +26 -0
- kele/grounder/grounded_rule_ds/grounded_class.py +461 -0
- kele/grounder/grounded_rule_ds/grounded_ds_utils.py +91 -0
- kele/grounder/grounded_rule_ds/rule_check.py +373 -0
- kele/grounder/grounding.py +118 -0
- kele/knowledge_bases/README.md +112 -0
- kele/knowledge_bases/__init__.py +6 -0
- kele/knowledge_bases/builtin_base/__init__.py +1 -0
- kele/knowledge_bases/builtin_base/builtin_concepts.py +13 -0
- kele/knowledge_bases/builtin_base/builtin_facts.py +43 -0
- kele/knowledge_bases/builtin_base/builtin_operators.py +105 -0
- kele/knowledge_bases/builtin_base/builtin_rules.py +14 -0
- kele/knowledge_bases/fact_base.py +158 -0
- kele/knowledge_bases/ontology_base.py +67 -0
- kele/knowledge_bases/rule_base.py +194 -0
- kele/main.py +464 -0
- kele/py.typed +0 -0
- kele/syntax/CONCEPT_README.md +117 -0
- kele/syntax/__init__.py +40 -0
- kele/syntax/_cnf_converter.py +161 -0
- kele/syntax/_sat_solver.py +116 -0
- kele/syntax/base_classes.py +1482 -0
- kele/syntax/connectives.py +20 -0
- kele/syntax/dnf_converter.py +145 -0
- kele/syntax/external.py +17 -0
- kele/syntax/sub_concept.py +87 -0
- kele/syntax/syntacticsugar.py +201 -0
- kele-0.0.1a1.dist-info/METADATA +166 -0
- kele-0.0.1a1.dist-info/RECORD +74 -0
- kele-0.0.1a1.dist-info/WHEEL +4 -0
- kele-0.0.1a1.dist-info/licenses/LICENSE +28 -0
- kele-0.0.1a1.dist-info/licenses/licensecheck.json +20 -0
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# 导入所有rule strategies
|
|
2
|
+
import importlib
|
|
3
|
+
import logging
|
|
4
|
+
import pathlib
|
|
5
|
+
|
|
6
|
+
from .strategy_protocol import get_strategy_class
|
|
7
|
+
|
|
8
|
+
current_dir = pathlib.Path(__file__).resolve().parent
package_name = __package__ or current_dir.name

logger = logging.getLogger(__name__)
# NOTE(review): setting a level on a library logger overrides the application's
# logging configuration; with WARNING the info message below is never emitted.
logger.setLevel(logging.WARNING)

# Import every private ``_*_strategy.py`` module of this package so that their
# @register_strategy decorators populate the strategy registry.
# The directory listing is sorted: iterdir() order is filesystem-dependent, and
# registration order decides which class wins if two modules use the same name.
for filename in sorted(current_dir.iterdir()):
    if filename.suffix == '.py' and filename.stem.endswith('_strategy') and filename.stem.startswith('_'):
        module_name = filename.stem
        try:
            importlib.import_module(f'{package_name}.{module_name}')
        except ImportError:
            logger.exception('Failed to import %s', module_name)
        else:
            # Report success only after the import actually succeeded.
            logger.info('successfully imported module: "%s"', module_name)
|
|
23
|
+
|
|
24
|
+
# Public API of this package: the registry lookup imported from .strategy_protocol.
__all__ = ["get_strategy_class"]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
from .strategy_protocol import Feedback, register_strategy, RuleSelectionStrategy
|
|
4
|
+
from kele.syntax import Rule
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@register_strategy('SequentialCyclic')
class SequentialCyclicStrategy(RuleSelectionStrategy):
    """Round-robin rule selection: r0, r1, ..., rN-1, r0, r1, ...

    Each call to ``select_next`` returns exactly one rule and advances a cursor
    that wraps around at the end of the rule list. Feedback is ignored.
    """

    def __init__(self) -> None:
        # Initialised here so that calling select_next() before set_rules() fails
        # with a clear message instead of an AttributeError.
        self._rules: list[Rule] = []
        self._idx: int = 0

    def set_rules(self, rules: Sequence[Rule]) -> None:
        """Replace the rule list and restart the cycle from its first rule.

        :raises ValueError: if ``rules`` is empty; the previously set rules are kept.
        """
        new_rules = list(rules)
        if not new_rules:
            # Validate before assigning so a rejected update does not leave the
            # strategy holding an empty list (which would break select_next).
            raise ValueError("rules cannot be empty")
        self._rules = new_rules
        self._idx = 0

    def reset(self) -> None:
        """Restart the cycle from the first rule; the rule list itself is kept."""
        self._idx = 0

    def select_next(self) -> Sequence[Rule]:
        """Return the next rule in cyclic order (as a one-element list).

        :raises RuntimeError: if no rules have been set yet.
        """
        if not self._rules:
            raise RuntimeError("no rules to select from; call set_rules() first")
        rule = self._rules[self._idx]
        self._idx = (self._idx + 1) % len(self._rules)
        return [rule]

    def on_feedback(self, feedback: Feedback) -> None:  # noqa: PLR6301 # not implemented yet; no need to make it static
        """Ignore feedback: sequential cyclic selection does not depend on it."""
        return
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@register_strategy("SequentialCyclicWithPriority")
class SequentialCyclicWithPriorityStrategy(SequentialCyclicStrategy):
    """Round-robin selection over rules ordered by descending ``priority``.

    Rules with equal priority keep their input order, because Python's sort is
    stable (also with ``reverse=True``).
    """

    def set_rules(self, rules: Sequence[Rule]) -> None:
        by_priority = list(rules)
        by_priority.sort(key=lambda rule: rule.priority, reverse=True)
        super().set_rules(by_priority)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Protocol, runtime_checkable, TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from kele.syntax import Rule
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@runtime_checkable
class RuleSelectionStrategy(Protocol):
    """Common interface of all rule-selection strategies.

    A strategy may return any rules it sees fit on each selection.
    """

    def __init__(self) -> None:
        """Strategies are constructed without arguments."""

    def set_rules(self, rules: Sequence[Rule]) -> None:
        """Install the candidate rules to select from."""

    def reset(self) -> None:
        """Reset the strategy's selection state."""

    def select_next(self) -> Sequence[Rule]:
        """Return the rules chosen for the next grounding step."""

    def on_feedback(self, feedback: Feedback) -> None:
        """Receive feedback about a previous selection."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass()
class Feedback:
    """Optional information passed back to a strategy after one selection.

    Every field defaults to ``None``; strategies read only what they need.
    """

    # Rule this feedback refers to. HACK: more context (e.g. facts) may be added later.
    rule: Rule | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Global name -> strategy-class table, filled by @register_strategy at import time.
rule_strategy_registry: dict[str, type[RuleSelectionStrategy]] = {}


def register_strategy(name: str) -> Callable[[type[RuleSelectionStrategy]], type[RuleSelectionStrategy]]:
    """Return a class decorator that records the decorated class under ``name``.

    The class is returned unchanged. Registering the same name again replaces
    the previous entry.
    """
    def _register(strategy_cls: type[RuleSelectionStrategy]) -> type[RuleSelectionStrategy]:
        rule_strategy_registry[name] = strategy_cls
        return strategy_cls

    return _register


def get_strategy_class(name: str) -> type[RuleSelectionStrategy]:
    """Look up a registered rule-selection strategy class by name.

    :raises KeyError: if ``name`` is not registered; the message lists the known names.
    """
    try:
        found = rule_strategy_registry[name]
    except KeyError as missing:
        known = list(rule_strategy_registry)
        message = (
            f"Rule selection strategy '{name}' is not registered. Available strategies: "
            f"{known}"
        )
        raise KeyError(message) from missing
    return found


def _available_strategies() -> list[str]:
    """Names of all registered rule-selection strategies, in registration order."""
    return [*rule_strategy_registry]
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import warnings
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
from kele.knowledge_bases.builtin_base.builtin_concepts import FREEVARANY_CONCEPT
|
|
6
|
+
from kele.syntax import (FACT_TYPE, TERM_TYPE, Assertion, Formula, CompoundTerm,
|
|
7
|
+
Constant, Variable, ATOM_TYPE, Rule)
|
|
8
|
+
from functools import singledispatch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@singledispatch
def _unify_all_terms(fact: FACT_TYPE | TERM_TYPE) -> tuple[CompoundTerm | Constant, ...]:
    """Decompose a fact or term into flat-level pieces (generic fallback).

    Type-specific handlers are registered below: a Formula is split into its
    sub-formulas, an Assertion into its two sides, and a single term down to its
    compound sub-terms / constants. Unregistered types yield an empty tuple.
    """
    return tuple()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@_unify_all_terms.register(Assertion)
def _(fact: Assertion) -> tuple[CompoundTerm | Constant, ...]:
    """Assertion: decompose the left side, then the right side, and concatenate."""
    lhs_terms = _unify_all_terms(fact.lhs)
    rhs_terms = _unify_all_terms(fact.rhs)
    return lhs_terms + rhs_terms
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@_unify_all_terms.register(Formula)
def _(fact: Formula) -> tuple[CompoundTerm | Constant, ...]:
    """Formula: terms of the right sub-formula (if present) followed by those of the left one."""
    # The left side is decomposed before the right side.
    left_terms = _unify_all_terms(fact.formula_left)
    if fact.formula_right is None:
        return left_terms
    right_terms = _unify_all_terms(fact.formula_right)
    return right_terms + left_terms
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@_unify_all_terms.register(CompoundTerm)
def _(fact: CompoundTerm) -> tuple[CompoundTerm | Constant, ...]:
    """CompoundTerm: the term itself followed by all nested compound sub-terms."""
    pieces = _split_all_terms(fact)
    return tuple(pieces)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@_unify_all_terms.register(Constant)
def _(fact: Constant) -> tuple[CompoundTerm | Constant, ...]:
    """Constant: already atomic, so it is returned as a 1-tuple."""
    return fact,
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@_unify_all_terms.register(Variable)
def _(fact: Variable) -> tuple[CompoundTerm | Constant, ...]:
    """Variable: contributes no terms; a warning is issued since facts should be variable-free."""
    warnings.warn("Variable should not exist in fact", stacklevel=2)
    return tuple()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class FREEVARANY(Constant):
    """Wildcard ("ANY") placeholder used for free variables, modelled as a special Constant.

    This engine grounds at the flat-term level: facts in both rules and the fact base
    are decomposed into flat terms before matching. A nested term is therefore split
    into several flat terms, and each Term-valued argument of the nested term is
    replaced by this wildcard.
    """

    def __init__(self, value: str) -> None:
        super().__init__(value, FREEVARANY_CONCEPT)


# Shared wildcard instance that flatten_arguments substitutes for compound arguments.
FREEANY = FREEVARANY('FREEVARANY')
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _split_all_terms(term: CompoundTerm) -> list[CompoundTerm]:
    """List ``term`` followed by every nested CompoundTerm, in depth-first pre-order.

    Arguments that are not CompoundTerms (constants, variables) are skipped; this is
    meant for facts and does not treat Variables specially.
    """
    # Constants are handled at the _unify_all_terms level; only CompoundTerms are split here.
    collected: list[CompoundTerm] = []
    pending: list[CompoundTerm] = [term]
    while pending:
        current = pending.pop()
        collected.append(current)
        compound_args = [arg for arg in current.arguments if isinstance(arg, CompoundTerm)]
        # Push children in reverse so the leftmost one is visited next (pre-order, left to right).
        pending.extend(reversed(compound_args))
    return collected
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def flatten_arguments(arguments: Sequence[TERM_TYPE]) -> tuple[ATOM_TYPE, ...]:  # public for now
    """Replace every compound argument with the ``FREEANY`` wildcard.

    Non-compound arguments are kept unchanged. Usable for both facts and rules,
    since Variables among the arguments do not affect the result.
    """
    # NOTE: the singledispatch functions in this module could arguably be plain
    # if/else dispatch for readability (e.g. with the class placed at the top).
    flattened: list[ATOM_TYPE] = []
    for arg in arguments:
        flattened.append(FREEANY if isinstance(arg, CompoundTerm) else arg)
    return tuple(flattened)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@singledispatch
def _unify_into_terms(fact: FACT_TYPE | TERM_TYPE) -> tuple[TERM_TYPE, ...]:
    """Split a fact into its terms without flattening them (generic fallback).

    Type-specific handlers are registered below (Assertion, Formula, single terms);
    unregistered types yield an empty tuple.
    """
    return tuple()
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@_unify_into_terms.register(Assertion)
def _(fact: Assertion) -> tuple[TERM_TYPE, ...]:
    """Assertion: its two sides, left first."""
    lhs, rhs = fact.lhs, fact.rhs
    return lhs, rhs
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@_unify_into_terms.register(Formula)
def _(fact: Formula) -> tuple[TERM_TYPE, ...]:
    """Formula: collect the terms of both sub-formulas (right sub-formula first).

    Recurses through ``_unify_into_terms`` itself, not ``_unify_all_terms``. A fact
    then decomposes the same way whether it is a bare Assertion or nested inside a
    Formula: each Assertion contributes its ``(lhs, rhs)`` pair, unflattened.
    Previously this handler delegated to ``_unify_all_terms``, so Formula facts were
    flattened into all compound sub-terms/constants while plain Assertions were not.
    """
    tuple_left = _unify_into_terms(fact.formula_left)
    tuple_right = _unify_into_terms(fact.formula_right) if fact.formula_right is not None else ()
    return tuple_right + tuple_left
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@_unify_into_terms.register(TERM_TYPE)
def _(fact: TERM_TYPE) -> tuple[TERM_TYPE, ...]:
    """A bare term is already a term: wrap it in a 1-tuple."""
    return fact,
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _unify_ground_terms_from_rule(rule: Rule) -> tuple[TERM_TYPE, ...]:
    """Collect the variable-free terms found in ``rule``'s head and body.

    Terms are decomposed with ``_unify_all_terms`` (head first, then body), and only
    those whose ``free_variables`` is empty are kept.
    """
    # Rules legitimately contain Variables, but _unify_all_terms was written for facts
    # and warns whenever it meets one. Silence exactly that warning while scanning a rule
    # so users do not see spurious "Variable should not exist in fact" messages.
    # NOTE: catch_warnings() temporarily swaps the process-wide filter list and is
    # not thread-safe.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="Variable should not exist in fact")
        terms = _unify_all_terms(rule.head) + _unify_all_terms(rule.body)
    return tuple(term for term in terms if not term.free_variables)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _unify_ground_terms_from_rules(rules: Sequence[Rule]) -> tuple[TERM_TYPE, ...]:
    """Concatenate the variable-free terms of every rule, preserving rule order."""
    ground_terms: list[TERM_TYPE] = []
    for rule in rules:
        ground_terms.extend(_unify_ground_terms_from_rule(rule))
    return tuple(ground_terms)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# 导入所有term strategies
|
|
2
|
+
import importlib
|
|
3
|
+
import logging
|
|
4
|
+
import pathlib
|
|
5
|
+
|
|
6
|
+
from .strategy_protocol import get_strategy_class
|
|
7
|
+
|
|
8
|
+
current_dir = pathlib.Path(__file__).resolve().parent
package_name = __package__ or current_dir.name

logger = logging.getLogger(__name__)
# NOTE(review): setting a level on a library logger overrides the application's
# logging configuration; with WARNING the info message below is never emitted.
logger.setLevel(logging.WARNING)

# Import every private ``_*_strategy.py`` module of this package so that their
# @register_strategy decorators populate the strategy registry.
# The directory listing is sorted: iterdir() order is filesystem-dependent, and
# registration order decides which class wins if two modules use the same name.
for filename in sorted(current_dir.iterdir()):
    if filename.suffix == '.py' and filename.stem.endswith('_strategy') and filename.stem.startswith('_'):
        module_name = filename.stem
        try:
            importlib.import_module(f'{package_name}.{module_name}')
        except ImportError:
            logger.exception('Failed to import %s', module_name)
        else:
            # Report success only after the import actually succeeded.
            logger.info('successfully imported module: "%s"', module_name)
|
|
23
|
+
|
|
24
|
+
# Public API of this package: the registry lookup imported from .strategy_protocol.
__all__ = ["get_strategy_class"]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
|
|
4
|
+
from .strategy_protocol import Feedback, register_strategy, TermSelectionStrategy
|
|
5
|
+
from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, Rule, TERM_TYPE
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@register_strategy('Exhausted')
class ExhuastedStrategy(TermSelectionStrategy):
    """Exhaustive selection: each call returns every term the given rule has not seen yet.

    All terms live in one shared, append-only list; per rule, a cursor remembers how
    far into that list the rule has already been served.
    """

    def __init__(self) -> None:
        # Shared by all rules; each rule merely consumes a different prefix of it.
        # NOTE: duplicates are not removed yet -- a deduplication scheme is still needed.
        self._terms: list[GROUNDED_TYPE_FOR_UNIFICATION] = []
        # rule -> index of the first term this rule has not been given yet (0 for new rules)
        self._rules_used_id: dict[Rule, int] = defaultdict(int)

    def add_terms(self, terms: Sequence[TERM_TYPE]) -> None:
        """Append ``terms`` to the shared candidate list."""
        self._terms += terms

    def reset(self) -> None:  # HACK: not used yet
        """Forget all terms and every rule's progress."""
        self._terms.clear()
        self._rules_used_id = defaultdict(int)

    def select_next(self, rule: Rule) -> Sequence[GROUNDED_TYPE_FOR_UNIFICATION]:
        """Return the terms added since ``rule`` was last served, and mark them as consumed."""
        cursor = self._rules_used_id[rule]
        end = len(self._terms)
        self._rules_used_id[rule] = end
        return self._terms[cursor:end]

    def on_feedback(self, feedback: Feedback) -> None:  # noqa: PLR6301 # not implemented yet; no need to make it static
        """Ignore feedback: exhaustive selection does not depend on it."""
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Protocol, runtime_checkable, TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, Rule, TERM_TYPE
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@runtime_checkable
class TermSelectionStrategy(Protocol):
    """Common interface of all term-selection strategies.

    A strategy may return any candidate terms it sees fit for a given rule.
    """

    def __init__(self) -> None:
        """Strategies are constructed without arguments."""

    def add_terms(self, terms: Sequence[TERM_TYPE]) -> None:
        """Add new candidate terms."""

    def reset(self) -> None:
        """Reset the strategy's selection state."""

    def select_next(self, rule: Rule) -> Sequence[GROUNDED_TYPE_FOR_UNIFICATION]:
        """Return the candidate terms chosen for grounding ``rule``."""

    def on_feedback(self, feedback: Feedback) -> None:
        """Receive feedback about a previous selection."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass()
class Feedback:
    """Optional information passed back to a term strategy after one selection.

    It currently carries no fields; strategies would read only what they need.
    """
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Global name -> strategy-class table, filled by @register_strategy at import time.
term_strategy_registry: dict[str, type[TermSelectionStrategy]] = {}


def register_strategy(name: str) -> Callable[[type[TermSelectionStrategy]], type[TermSelectionStrategy]]:
    """Return a class decorator that records the decorated class under ``name``.

    The class is returned unchanged. Registering the same name again replaces
    the previous entry.
    """
    def _register(strategy_cls: type[TermSelectionStrategy]) -> type[TermSelectionStrategy]:
        term_strategy_registry[name] = strategy_cls
        return strategy_cls

    return _register


def get_strategy_class(name: str) -> type[TermSelectionStrategy]:
    """Look up a registered term-selection strategy class by name.

    :raises KeyError: if ``name`` is not registered; the message lists the known names.
    """
    try:
        found = term_strategy_registry[name]
    except KeyError as missing:
        known = list(term_strategy_registry)
        message = (
            f"Term selection strategy '{name}' is not registered. Available strategies: "
            f"{known}"
        )
        raise KeyError(message) from missing
    return found


def _available_strategies() -> list[str]:
    """Names of all registered term-selection strategies, in registration order."""
    return [*term_strategy_registry]
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
from ._rule_strategies import get_strategy_class
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from ._rule_strategies.strategy_protocol import RuleSelectionStrategy, Feedback
|
|
9
|
+
from collections.abc import Sequence
|
|
10
|
+
from kele.syntax import Rule, _QuestionRule
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class GroundingRuleSelector:
    """Public entry point for choosing rules to ground; the choice is delegated to a strategy.

    - The strategy can be swapped (the interface does not require "one rule at a time").
    - The rule set can be updated.

    Question rules are handed out periodically (after ``question_rule_interval`` normal
    rules) and whenever a fixpoint has been signalled, so that the caller can check
    whether the question has been answered.
    """

    # The default strategy name must match a name passed to @register_strategy. The
    # former default "sequential_cyclic" matched none ('SequentialCyclic' and
    # 'SequentialCyclicWithPriority' are registered), so GroundingRuleSelector()
    # raised KeyError.
    def __init__(self, strategy: str = "SequentialCyclic", question_rule_interval: int = -1) -> None:
        """
        :param strategy: registered name of the rule-selection strategy to use.
        :param question_rule_interval: number of normal rules handed out between two
            question-rule checks; ``-1`` means "use the number of normal rules".
        :raises KeyError: if ``strategy`` is not a registered strategy name.
        """  # noqa: DOC501
        strategy_cls = get_strategy_class(strategy)
        self._strategy = strategy_cls()
        self._normal_rules: list[Rule] | None = None

        self.question_rule_interval = question_rule_interval
        # Normal rules handed out since question rules were last returned.
        self.used_rule_cnt: int = 0
        self._question_rules: list[Rule] = []
        self._at_fixpoint: bool = False

    def set_at_fixpoint(self, *, at_fixpoint: bool) -> None:
        """Record whether inference has reached a fixpoint."""
        self._at_fixpoint = at_fixpoint

    def next_rules(self) -> Sequence[Rule]:
        """Choose the next rules for grounding, periodically returning the question rules instead.

        - At a fixpoint: the question rules (or ``[]`` if there are none).
        - Otherwise, once ``interval`` normal rules have been handed out: the question
          rules, and the counter is reset.
        - Otherwise: whatever the strategy selects.

        The question rules are returned as a fresh list so callers cannot mutate
        the selector's internal state.

        :raises ValueError: when question_rule_interval is less than 1 and not -1.
        """  # noqa: DOC501
        if self._at_fixpoint:
            if self._question_rules:
                self.used_rule_cnt = 0
                return list(self._question_rules)
            return []

        # Determine the effective check interval.
        if self.question_rule_interval < 1 and self.question_rule_interval != -1:
            raise ValueError(
                "question_rule_interval must be >= 1, or -1 to use the total count of normal rules as the interval"
            )

        interval = self.question_rule_interval
        if interval == -1:
            normal_count = len(self._normal_rules) if self._normal_rules is not None else 0
            interval = normal_count or 1

        if self.used_rule_cnt >= interval and self._question_rules:
            self.used_rule_cnt = 0  # reset the counter
            return list(self._question_rules)

        rules = self._strategy.select_next()
        self.used_rule_cnt += len(rules)

        return rules

    def reset(self) -> None:
        """Reset the selector and its strategy.

        NOTE: the strategy keeps whatever rules it was given; only this selector's
        own rule lists and counters are cleared.
        """
        self._strategy.reset()
        self._normal_rules = None
        self._question_rules = []
        self.used_rule_cnt = 0
        self._at_fixpoint = False

    def set_strategy(self, strategy: RuleSelectionStrategy) -> None:
        """Swap in another strategy instance.

        The new strategy's internal state is not reset here; that is up to the
        strategy. If rules are already known they are passed on to it, otherwise
        a warning asks the caller to call set_rules.
        """
        self._strategy = strategy

        if self._normal_rules is not None:
            self._strategy.set_rules(self._normal_rules)
        else:
            warnings.warn("No given rules, please call set_rules after calling set_strategy.", stacklevel=2)

    def set_rules(self, normal_rules: Sequence[Rule], question_rules: Sequence[_QuestionRule]) -> None:
        """
        Replace the rule sets and pass the normal rules on to the current strategy.

        :raise: ValueError: if there are no normal rules to select from.
        """  # noqa: DOC501
        if not normal_rules:
            raise ValueError("rules cannot be empty")

        self._normal_rules = list(normal_rules)
        self._question_rules = list(question_rules)

        self._strategy.set_rules(self._normal_rules)

        # Reset the counter.
        self.used_rule_cnt = 0

    def send_feedback(self, feedback: Feedback) -> None:
        """Forward feedback about a selection to the current strategy."""
        self._strategy.on_feedback(feedback)
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
from ._term_strategies import get_strategy_class
|
|
6
|
+
from kele.control.grounding_selector._selector_utils import (
|
|
7
|
+
_unify_ground_terms_from_rules,
|
|
8
|
+
_unify_into_terms,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from kele.config import Config
|
|
13
|
+
from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, Rule
|
|
14
|
+
from kele.equality import Equivalence
|
|
15
|
+
from ._term_strategies.strategy_protocol import Feedback
|
|
16
|
+
from collections.abc import Sequence
|
|
17
|
+
from kele.syntax import FACT_TYPE
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class GroundingFlatTermWithWildCardSelector:
    """Public entry point for choosing candidate terms for grounding; delegates to a strategy."""

    # Grounding happens at the flat-term level, so more than plain Terms reach the
    # strategy. Strategies work equally well for Terms and flat terms, so their contract
    # stays TERM_TYPE. Wildcards and flat terms currently cannot be separated, which is
    # why the class name stresses "wild card".

    def __init__(self,
                 equivalence: Equivalence,
                 args: Config) -> None:
        strategy_cls = get_strategy_class(args.strategy.grounding_term_strategy)
        self._strategy = strategy_cls()
        self._equivalence = equivalence
        self._args = args

    def next_terms(self, rule: Rule) -> list[GROUNDED_TYPE_FOR_UNIFICATION]:
        """Select candidate facts/terms for grounding ``rule``.

        Duplicates are removed while keeping first-occurrence order. A ``set``
        round-trip would make the order depend on hash values, which can differ
        between interpreter runs (e.g. for string-based hashes) and make the
        inference trace non-reproducible.
        """
        init_terms = list(self._strategy.select_next(rule))

        if self._args.run.semi_eval_with_equality:
            # In semi-evaluation, also include the other members of each term's
            # equivalence class. This behaviour follows from the equality axioms
            # built into assertion logic together with term-level grounding, so it
            # belongs to the selector itself rather than to any particular strategy.
            # A dict is used as an insertion-ordered set.
            selected_terms: dict[GROUNDED_TYPE_FOR_UNIFICATION, None] = {}

            for t in init_terms:
                if t not in selected_terms:
                    equiv_terms = self._equivalence.get_related_item(t)
                    # TODO: consider moving _unify_into_flat_terms into the term selector later
                    selected_terms.update(dict.fromkeys(equiv_terms))

            return list(selected_terms)

        return list(dict.fromkeys(init_terms))

    def reset(self) -> None:
        """Reset the selector (delegates to the strategy)."""
        self._strategy.reset()

    def update_terms(self,
                     terms: Sequence[GROUNDED_TYPE_FOR_UNIFICATION] | None = None,
                     facts: Sequence[FACT_TYPE] | None = None) -> None:
        """
        Add new candidate terms to the current strategy.

        ``terms`` is used when non-empty; otherwise each fact in ``facts`` is split
        into its terms with ``_unify_into_terms``.

        :raise: ValueError: if neither terms nor facts are provided (or both are empty).
        """  # noqa: DOC501

        if terms:
            fact_terms = terms
        elif facts:
            # FIXME: more precisely, only Assertions can trigger the equality axioms.
            fact_terms = list(itertools.chain.from_iterable(_unify_into_terms(f) for f in facts))
        else:
            raise ValueError("terms or facts cannot be empty")

        # There should be no terms with action operators here. (Considered filter:
        # fact_terms = [f for f in fact_terms if f.operator.implement_func is not None])
        self._strategy.add_terms(fact_terms)

    def update_terms_from_rules(self, rules: Sequence[Rule]) -> None:
        """
        Add the variable-free terms of the given rules/questions to the candidate pool.
        """
        rule_terms = _unify_ground_terms_from_rules(rules)
        if not rule_terms:
            return

        self._strategy.add_terms(rule_terms)

    def send_feedback(self, feedback: Feedback) -> None:
        """Forward feedback about a selection to the current strategy."""
        self._strategy.on_feedback(feedback)
|