kele 0.0.1a1__cp314-cp314-win32.whl → 0.0.1a2__cp314-cp314-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kele/__init__.py +8 -6
- kele/_utils.py +23 -0
- kele/_version.py +1 -1
- kele/config.py +52 -14
- kele/control/__init__.py +9 -4
- kele/control/builtin_hooks.py +119 -0
- kele/control/callback.py +14 -4
- kele/control/grounding_selector/__init__.py +7 -1
- kele/control/grounding_selector/_rule_strategies/README.md +25 -1
- kele/control/grounding_selector/_rule_strategies/__init__.py +2 -2
- kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py +5 -3
- kele/control/grounding_selector/_rule_strategies/strategy_protocol.py +3 -27
- kele/control/grounding_selector/_selector_utils.py +2 -10
- kele/control/grounding_selector/_term_strategies/README.md +34 -0
- kele/control/grounding_selector/_term_strategies/__init__.py +2 -2
- kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py +5 -3
- kele/control/grounding_selector/_term_strategies/strategy_protocol.py +4 -28
- kele/control/grounding_selector/rule_selector.py +14 -5
- kele/control/grounding_selector/term_selector.py +16 -6
- kele/control/infer_path.py +9 -8
- kele/control/metrics.py +7 -8
- kele/control/registry.py +112 -0
- kele/control/status.py +190 -49
- kele/egg_equiv.pyd +0 -0
- kele/egg_equiv.pyi +3 -3
- kele/equality/_equiv_elem.py +2 -1
- kele/equality/_utils.py +4 -2
- kele/equality/equivalence.py +8 -7
- kele/executer/executing.py +40 -24
- kele/grounder/__init__.py +2 -7
- kele/grounder/grounded_rule_ds/__init__.py +2 -2
- kele/grounder/grounded_rule_ds/_nodes/__init__.py +4 -4
- kele/grounder/grounded_rule_ds/_nodes/_assertion.py +27 -14
- kele/grounder/grounded_rule_ds/_nodes/_conn.py +4 -2
- kele/grounder/grounded_rule_ds/_nodes/_op.py +6 -3
- kele/grounder/grounded_rule_ds/_nodes/_root.py +4 -4
- kele/grounder/grounded_rule_ds/_nodes/_rule.py +5 -5
- kele/grounder/grounded_rule_ds/_nodes/_term.py +11 -10
- kele/grounder/grounded_rule_ds/_nodes/_tftable.py +1 -0
- kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +62 -58
- kele/grounder/grounded_rule_ds/_nodes/_typing_polars.py +1 -1
- kele/grounder/grounded_rule_ds/grounded_class.py +60 -17
- kele/grounder/grounded_rule_ds/grounded_ds_utils.py +3 -4
- kele/grounder/grounded_rule_ds/rule_check.py +6 -5
- kele/grounder/grounding.py +38 -36
- kele/knowledge_bases/__init__.py +1 -1
- kele/knowledge_bases/builtin_base/builtin_facts.py +3 -1
- kele/knowledge_bases/builtin_base/builtin_operators.py +1 -1
- kele/knowledge_bases/builtin_base/builtin_rules.py +1 -0
- kele/knowledge_bases/fact_base.py +37 -16
- kele/knowledge_bases/ontology_base.py +1 -1
- kele/knowledge_bases/rule_base.py +53 -32
- kele/main.py +55 -37
- kele/syntax/__init__.py +12 -12
- kele/syntax/_cnf_converter.py +2 -2
- kele/syntax/_sat_solver.py +3 -3
- kele/syntax/base_classes.py +92 -20
- kele/syntax/dnf_converter.py +10 -5
- kele/syntax/external.py +2 -1
- kele/syntax/sub_concept.py +5 -4
- kele/syntax/syntacticsugar.py +3 -4
- {kele-0.0.1a1.dist-info → kele-0.0.1a2.dist-info}/METADATA +13 -3
- kele-0.0.1a2.dist-info/RECORD +78 -0
- kele-0.0.1a1.dist-info/RECORD +0 -74
- {kele-0.0.1a1.dist-info → kele-0.0.1a2.dist-info}/WHEEL +0 -0
- {kele-0.0.1a1.dist-info → kele-0.0.1a2.dist-info}/licenses/LICENSE +0 -0
- {kele-0.0.1a1.dist-info → kele-0.0.1a2.dist-info}/licenses/licensecheck.json +0 -0
kele/__init__.py
CHANGED
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
"""支持断言逻辑的推理引擎"""
|
|
2
|
-
from kele.main import EngineRunResult, InferenceEngine, QueryStructure
|
|
3
2
|
from kele.config import (
|
|
4
3
|
Config,
|
|
5
|
-
RunControlConfig,
|
|
6
|
-
InferenceStrategyConfig,
|
|
7
|
-
GrounderConfig,
|
|
8
4
|
ExecutorConfig,
|
|
9
|
-
|
|
5
|
+
GrounderConfig,
|
|
6
|
+
InferenceStrategyConfig,
|
|
10
7
|
KBConfig,
|
|
8
|
+
PathConfig,
|
|
9
|
+
RunControlConfig,
|
|
11
10
|
)
|
|
12
|
-
from kele.
|
|
11
|
+
from kele.control import register
|
|
12
|
+
from kele.main import EngineRunResult, InferenceEngine, QueryStructure
|
|
13
|
+
from kele.syntax.base_classes import Assertion, CompoundTerm, Concept, Constant, Formula, Operator, Rule, Variable
|
|
13
14
|
|
|
14
15
|
try:
|
|
15
16
|
from ._version import version as __version__
|
|
@@ -35,4 +36,5 @@ __all__ = [
|
|
|
35
36
|
'Rule',
|
|
36
37
|
'RunControlConfig',
|
|
37
38
|
'Variable',
|
|
39
|
+
'register',
|
|
38
40
|
]
|
kele/_utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any, TypeVar
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from collections.abc import Callable, Sequence
|
|
7
|
+
|
|
8
|
+
T = TypeVar("T")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def summarize_items(
|
|
12
|
+
items: Sequence[T],
|
|
13
|
+
*,
|
|
14
|
+
sample_size: int = 5,
|
|
15
|
+
formatter: Callable[[T], str] = str,
|
|
16
|
+
) -> dict[str, Any]:
|
|
17
|
+
"""Summarize a sequence for debug logging."""
|
|
18
|
+
total = len(items)
|
|
19
|
+
sample = [formatter(item) for item in items[:sample_size]]
|
|
20
|
+
return {
|
|
21
|
+
"rows": total,
|
|
22
|
+
"sample": sample,
|
|
23
|
+
}
|
kele/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
version = "0.0.1a1"
|
|
1
|
+
version = "0.0.1a2"
|
kele/config.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
|
1
1
|
# ruff: noqa: ERA001 # Commented parameters are either not implemented yet or depend on unfinished upstream/downstream modules.
|
|
2
|
-
import
|
|
3
|
-
from typing import Any, cast, Literal
|
|
4
|
-
|
|
2
|
+
import json
|
|
5
3
|
import logging
|
|
6
|
-
|
|
7
|
-
from
|
|
4
|
+
import warnings
|
|
5
|
+
from dataclasses import dataclass, field, fields
|
|
6
|
+
from datetime import UTC, datetime
|
|
8
7
|
from pathlib import Path
|
|
9
|
-
import
|
|
10
|
-
|
|
11
|
-
import tyro
|
|
12
|
-
from tyro.conf import OmitArgPrefixes
|
|
8
|
+
from typing import Any, Literal, cast
|
|
9
|
+
|
|
13
10
|
import dacite
|
|
11
|
+
import tyro
|
|
12
|
+
import yaml
|
|
14
13
|
from dacite.config import Config as daConfig
|
|
14
|
+
from tyro.conf import OmitArgPrefixes
|
|
15
15
|
|
|
16
16
|
RESULT_LEVEL = 25
|
|
17
17
|
logging.RESULT = RESULT_LEVEL # type: ignore[attr-defined] # This fails mypy; setattr fails ruff.
|
|
@@ -59,12 +59,50 @@ class InferenceStrategyConfig:
|
|
|
59
59
|
@dataclass
|
|
60
60
|
class GrounderConfig:
|
|
61
61
|
"""Grounder-related parameters."""
|
|
62
|
-
|
|
63
|
-
|
|
62
|
+
grounding_rules_per_step: int | Literal[-1] = -1
|
|
63
|
+
grounding_facts_per_rule: int | Literal[-1] = -1
|
|
64
|
+
grounding_rules_num_every_step: int | Literal[-1] | None = None
|
|
65
|
+
grounding_facts_num_for_each_rule: int | Literal[-1] | None = None
|
|
64
66
|
allow_unify_with_nested_term: bool = True # Allow Variables to be replaced by CompoundTerms.
|
|
65
67
|
conceptual_fuzzy_unification: bool = True # Use strict concept constraints to accelerate inference.
|
|
66
68
|
# This depends on correct concept subsumption and full constant.belong_concepts settings; beginners should use loose matching.
|
|
67
69
|
|
|
70
|
+
def __post_init__(self) -> None:
|
|
71
|
+
if self.grounding_rules_num_every_step is not None:
|
|
72
|
+
warnings.warn(
|
|
73
|
+
"grounding_rules_num_every_step is deprecated; use grounding_rules_per_step instead.",
|
|
74
|
+
DeprecationWarning,
|
|
75
|
+
stacklevel=2,
|
|
76
|
+
)
|
|
77
|
+
if self.grounding_rules_per_step == -1:
|
|
78
|
+
self.grounding_rules_per_step = self.grounding_rules_num_every_step
|
|
79
|
+
elif self.grounding_rules_per_step != self.grounding_rules_num_every_step:
|
|
80
|
+
warnings.warn(
|
|
81
|
+
"Both grounding_rules_per_step and grounding_rules_num_every_step are set; using grounding_rules_per_step.",
|
|
82
|
+
DeprecationWarning,
|
|
83
|
+
stacklevel=2,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
if self.grounding_facts_num_for_each_rule is not None:
|
|
87
|
+
warnings.warn(
|
|
88
|
+
"grounding_facts_num_for_each_rule is deprecated; use grounding_facts_per_rule instead.",
|
|
89
|
+
DeprecationWarning,
|
|
90
|
+
stacklevel=2,
|
|
91
|
+
)
|
|
92
|
+
if self.grounding_facts_per_rule == -1:
|
|
93
|
+
self.grounding_facts_per_rule = self.grounding_facts_num_for_each_rule
|
|
94
|
+
elif self.grounding_facts_per_rule != self.grounding_facts_num_for_each_rule:
|
|
95
|
+
warnings.warn(
|
|
96
|
+
"Both grounding_facts_per_rule and grounding_facts_num_for_each_rule are set; using grounding_facts_per_rule.",
|
|
97
|
+
DeprecationWarning,
|
|
98
|
+
stacklevel=2,
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
if self.grounding_rules_num_every_step is None:
|
|
102
|
+
self.grounding_rules_num_every_step = self.grounding_rules_per_step
|
|
103
|
+
if self.grounding_facts_num_for_each_rule is None:
|
|
104
|
+
self.grounding_facts_num_for_each_rule = self.grounding_facts_per_rule
|
|
105
|
+
|
|
68
106
|
|
|
69
107
|
@dataclass
|
|
70
108
|
class ExecutorConfig:
|
|
@@ -103,7 +141,7 @@ class Config:
|
|
|
103
141
|
|
|
104
142
|
|
|
105
143
|
def _load_config_file(path: str) -> dict[str, Any]:
|
|
106
|
-
with open(
|
|
144
|
+
with Path(path).open(encoding="utf-8") as f:
|
|
107
145
|
if path.endswith(('.yaml', '.yml')):
|
|
108
146
|
data = yaml.safe_load(f)
|
|
109
147
|
elif path.endswith('.json'):
|
|
@@ -117,7 +155,7 @@ def _load_config_file(path: str) -> dict[str, Any]:
|
|
|
117
155
|
|
|
118
156
|
|
|
119
157
|
def _save_config(config: dict[str, Any], path: str) -> None:
|
|
120
|
-
with open(
|
|
158
|
+
with Path(path).open('w', encoding='utf8') as f:
|
|
121
159
|
yaml.dump(config, f, sort_keys=False)
|
|
122
160
|
|
|
123
161
|
|
|
@@ -197,7 +235,7 @@ def _build_config(user_config: Config | None = None,
|
|
|
197
235
|
|
|
198
236
|
:raises: ValueError: If `user_config` is used together with `config_file_path`
|
|
199
237
|
or when `user_config.config` is set.
|
|
200
|
-
"""
|
|
238
|
+
"""
|
|
201
239
|
if user_config and (user_config.config or config_file_path):
|
|
202
240
|
raise ValueError("default config instance and config file cannot be used together")
|
|
203
241
|
|
kele/control/__init__.py
CHANGED
|
@@ -1,14 +1,17 @@
|
|
|
1
1
|
"""用于callbacks和推理路径的记录"""
|
|
2
|
-
from .callback import
|
|
2
|
+
from .callback import Callback, CallbackManager, HookMixin
|
|
3
|
+
from .grounding_selector import GroundingRuleSelector
|
|
4
|
+
from .builtin_hooks import BuiltinHookEnabler, register_assertion_check_hook
|
|
5
|
+
from .infer_path import InferencePath
|
|
6
|
+
from .registry import register
|
|
3
7
|
from .status import (
|
|
4
8
|
InferenceStatus,
|
|
5
|
-
create_main_loop_manager,
|
|
6
9
|
create_executor_manager,
|
|
10
|
+
create_main_loop_manager,
|
|
7
11
|
)
|
|
8
|
-
from .grounding_selector import GroundingRuleSelector
|
|
9
|
-
from .infer_path import InferencePath
|
|
10
12
|
|
|
11
13
|
__all__ = [
|
|
14
|
+
'BuiltinHookEnabler',
|
|
12
15
|
'Callback',
|
|
13
16
|
'CallbackManager',
|
|
14
17
|
'GroundingRuleSelector',
|
|
@@ -17,4 +20,6 @@ __all__ = [
|
|
|
17
20
|
'InferenceStatus',
|
|
18
21
|
'create_executor_manager',
|
|
19
22
|
'create_main_loop_manager',
|
|
23
|
+
'register',
|
|
24
|
+
'register_assertion_check_hook',
|
|
20
25
|
]
|
kele/control/builtin_hooks.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from collections.abc import Callable, Mapping
|
|
8
|
+
|
|
9
|
+
from kele.grounder import GroundedRule
|
|
10
|
+
from kele.grounder.grounded_rule_ds._nodes import _AssertionNode
|
|
11
|
+
from kele.syntax import CompoundTerm, Constant, Rule, Variable
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def register_assertion_check_hook(
|
|
17
|
+
grounded_rule: GroundedRule,
|
|
18
|
+
*,
|
|
19
|
+
rule_name: str | None = None,
|
|
20
|
+
vars_filter: Mapping[str, str] | None = None,
|
|
21
|
+
on_match: Callable[[Rule, _AssertionNode, dict[Variable, Constant | CompoundTerm], bool], None] | None = None,
|
|
22
|
+
break_on_match: bool = False,
|
|
23
|
+
) -> None:
|
|
24
|
+
"""
|
|
25
|
+
Register a built-in hook that observes assertion check results.
|
|
26
|
+
|
|
27
|
+
This is a user-facing inspection hook: it lets you capture whether a specific
|
|
28
|
+
assertion evaluation (for a given variable binding) is True/False, and optionally
|
|
29
|
+
stop on matches for deep debugging.
|
|
30
|
+
|
|
31
|
+
Example:
|
|
32
|
+
def log_match(rule, assertion, combination, result):
|
|
33
|
+
print(rule.name, assertion.content, combination, result)
|
|
34
|
+
|
|
35
|
+
register_assertion_check_hook(
|
|
36
|
+
grounded_rule,
|
|
37
|
+
rule_name="rule_3",
|
|
38
|
+
vars_filter={"p1": "f", "p2": "a", "p13": "b", "p14": "c"},
|
|
39
|
+
on_match=log_match,
|
|
40
|
+
break_on_match=True,
|
|
41
|
+
)
|
|
42
|
+
"""
|
|
43
|
+
normalized_vars_filter = dict(vars_filter) if vars_filter is not None else None
|
|
44
|
+
|
|
45
|
+
def _hook(
|
|
46
|
+
*,
|
|
47
|
+
rule: Rule,
|
|
48
|
+
assertion: _AssertionNode,
|
|
49
|
+
combination: dict[Variable, Constant | CompoundTerm],
|
|
50
|
+
result: bool,
|
|
51
|
+
) -> None:
|
|
52
|
+
if rule_name and rule.name != rule_name:
|
|
53
|
+
return
|
|
54
|
+
|
|
55
|
+
combination_str = {str(k): str(v) for k, v in combination.items()}
|
|
56
|
+
if normalized_vars_filter:
|
|
57
|
+
for key, value in normalized_vars_filter.items():
|
|
58
|
+
if combination_str.get(key) != value:
|
|
59
|
+
return
|
|
60
|
+
|
|
61
|
+
if on_match is None:
|
|
62
|
+
logger.debug(
|
|
63
|
+
"Assertion check hook: rule=%s content=%s combination=%s result=%s",
|
|
64
|
+
rule,
|
|
65
|
+
assertion.content,
|
|
66
|
+
combination_str,
|
|
67
|
+
result,
|
|
68
|
+
)
|
|
69
|
+
else:
|
|
70
|
+
on_match(rule, assertion, combination, result)
|
|
71
|
+
|
|
72
|
+
if break_on_match:
|
|
73
|
+
breakpoint() # noqa: T100
|
|
74
|
+
|
|
75
|
+
grounded_rule.register_hook("assertion_check", _hook)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class BuiltinHookEnabler:
|
|
79
|
+
"""
|
|
80
|
+
Enabler for built-in hooks by name.
|
|
81
|
+
|
|
82
|
+
This class does not register hooks by itself. Create an instance and call
|
|
83
|
+
``enable`` to attach a built-in hook to a grounded rule.
|
|
84
|
+
|
|
85
|
+
Example:
|
|
86
|
+
hooks = BuiltinHookEnabler()
|
|
87
|
+
hooks.enable(
|
|
88
|
+
grounded_rule,
|
|
89
|
+
"assertion_check",
|
|
90
|
+
rule_name="rule_3",
|
|
91
|
+
vars_filter={"p1": "f"},
|
|
92
|
+
)
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
def __init__(self) -> None:
|
|
96
|
+
self._hooks: dict[str, Callable[..., None]] = {
|
|
97
|
+
"assertion_check": register_assertion_check_hook,
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
def available_hooks(self) -> list[str]:
|
|
101
|
+
"""
|
|
102
|
+
Return the names of available built-in hooks.
|
|
103
|
+
"""
|
|
104
|
+
return sorted(self._hooks)
|
|
105
|
+
|
|
106
|
+
def enable(self, grounded_rule: GroundedRule, name: str, **kwargs: object) -> None:
|
|
107
|
+
"""
|
|
108
|
+
Enable a built-in hook by name.
|
|
109
|
+
"""
|
|
110
|
+
if name not in self._hooks:
|
|
111
|
+
raise KeyError(f"Unknown built-in hook: {name}")
|
|
112
|
+
self._hooks[name](grounded_rule, **kwargs)
|
|
113
|
+
|
|
114
|
+
def enable_many(self, grounded_rule: GroundedRule, names: list[str], **kwargs: object) -> None:
|
|
115
|
+
"""
|
|
116
|
+
Enable multiple built-in hooks by name.
|
|
117
|
+
"""
|
|
118
|
+
for hook_name in names:
|
|
119
|
+
self.enable(grounded_rule, hook_name, **kwargs)
|
kele/control/callback.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from collections import defaultdict
|
|
4
|
-
from typing import
|
|
5
|
-
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
6
5
|
|
|
7
6
|
if TYPE_CHECKING:
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
|
|
8
9
|
from kele.equality import Equivalence
|
|
9
10
|
from kele.grounder import GroundedRule
|
|
10
11
|
from kele.knowledge_bases import FactBase, RuleBase
|
|
11
|
-
from kele.syntax import
|
|
12
|
-
from collections.abc import Callable
|
|
12
|
+
from kele.syntax import FACT_TYPE, CompoundTerm, Constant, Question, Rule, Variable
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class HookMixin:
|
|
@@ -41,6 +41,16 @@ class HookMixin:
|
|
|
41
41
|
for hook in self._hooks.get(event_name, []):
|
|
42
42
|
hook(*args, **kwargs)
|
|
43
43
|
|
|
44
|
+
def run_hooks(self, event_name: str, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
|
|
45
|
+
"""
|
|
46
|
+
对外公开的钩子触发入口。
|
|
47
|
+
|
|
48
|
+
:param event_name: 事件名称。
|
|
49
|
+
:param args: 传递给钩子的所有位置参数。
|
|
50
|
+
:param kwargs: 传递给钩子的所有关键字参数。
|
|
51
|
+
"""
|
|
52
|
+
self._run_hooks(event_name, *args, **kwargs)
|
|
53
|
+
|
|
44
54
|
|
|
45
55
|
class Callback:
|
|
46
56
|
"""回调接口——在推理各阶段采集指标的Hook"""
|
kele/control/grounding_selector/__init__.py
CHANGED
|
@@ -1,5 +1,11 @@
|
|
|
1
1
|
"""grounding相关的选择器"""
|
|
2
|
+
from kele.control.registry import register
|
|
3
|
+
|
|
2
4
|
from .rule_selector import GroundingRuleSelector
|
|
3
5
|
from .term_selector import GroundingFlatTermWithWildCardSelector
|
|
4
6
|
|
|
5
|
-
__all__ = [
|
|
7
|
+
__all__ = [
|
|
8
|
+
"GroundingFlatTermWithWildCardSelector",
|
|
9
|
+
"GroundingRuleSelector",
|
|
10
|
+
"register",
|
|
11
|
+
]
|
kele/control/grounding_selector/_rule_strategies/README.md
CHANGED
|
@@ -9,5 +9,29 @@
|
|
|
9
9
|
## 创建自己的strategy
|
|
10
10
|
1. 创建一个py文件,命名要求为`_<name>_strategy.py`;
|
|
11
11
|
2. 继承RuleSelectionStrategy类,并至少声明此Protocol要求的函数;
|
|
12
|
-
3.
|
|
12
|
+
3. 从`kele.control`导入并使用`@register.rule_selector('<name>')`注册你的策略类,后续即可通过`grounding_rule_strategy`使用策略;
|
|
13
13
|
4. 注意调整`grounding_rule_strategy`的类型标注(增加Literal的候选值)。
|
|
14
|
+
|
|
15
|
+
示例:
|
|
16
|
+
```python
|
|
17
|
+
from kele.control import register
|
|
18
|
+
from kele.control.grounding_selector._rule_strategies.strategy_protocol import RuleSelectionStrategy
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register.rule_selector("MyRuleStrategy")
|
|
22
|
+
class MyRuleStrategy:
|
|
23
|
+
def __init__(self) -> None:
|
|
24
|
+
self._rules = []
|
|
25
|
+
|
|
26
|
+
def set_rules(self, rules):
|
|
27
|
+
self._rules = list(rules)
|
|
28
|
+
|
|
29
|
+
def reset(self) -> None:
|
|
30
|
+
self._rules = []
|
|
31
|
+
|
|
32
|
+
def select_next(self):
|
|
33
|
+
return self._rules[:1]
|
|
34
|
+
|
|
35
|
+
def on_feedback(self, feedback):
|
|
36
|
+
return None
|
|
37
|
+
```
|
kele/control/grounding_selector/_rule_strategies/__init__.py
CHANGED
|
@@ -3,7 +3,7 @@ import importlib
|
|
|
3
3
|
import logging
|
|
4
4
|
import pathlib
|
|
5
5
|
|
|
6
|
-
from .
|
|
6
|
+
from kele.control.registry import get_rule_strategy_class
|
|
7
7
|
|
|
8
8
|
current_dir = pathlib.Path(__file__).resolve().parent
|
|
9
9
|
package_name = __package__ or current_dir.name
|
|
@@ -21,4 +21,4 @@ for filename in current_dir.iterdir():
|
|
|
21
21
|
logger.exception('Failed to import %s', module_name)
|
|
22
22
|
continue
|
|
23
23
|
|
|
24
|
-
__all__ = ["
|
|
24
|
+
__all__ = ["get_rule_strategy_class"]
|
kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
|
|
3
|
-
from .
|
|
3
|
+
from kele.control.registry import register
|
|
4
4
|
from kele.syntax import Rule
|
|
5
5
|
|
|
6
|
+
from .strategy_protocol import Feedback, RuleSelectionStrategy
|
|
6
7
|
|
|
7
|
-
|
|
8
|
+
|
|
9
|
+
@register.rule_selector('SequentialCyclic')
|
|
8
10
|
class SequentialCyclicStrategy(RuleSelectionStrategy):
|
|
9
11
|
"""
|
|
10
12
|
按顺序循环遍历策略:
|
|
@@ -33,7 +35,7 @@ class SequentialCyclicStrategy(RuleSelectionStrategy):
|
|
|
33
35
|
return
|
|
34
36
|
|
|
35
37
|
|
|
36
|
-
@
|
|
38
|
+
@register.rule_selector("SequentialCyclicWithPriority")
|
|
37
39
|
class SequentialCyclicWithPriorityStrategy(SequentialCyclicStrategy):
|
|
38
40
|
"""将规则按优先级排序,优先级高的先取"""
|
|
39
41
|
|
kele/control/grounding_selector/_rule_strategies/strategy_protocol.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
from dataclasses import dataclass
|
|
3
|
-
from typing import Protocol, runtime_checkable
|
|
4
|
+
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
|
4
5
|
|
|
5
6
|
if TYPE_CHECKING:
|
|
6
7
|
from collections.abc import Sequence
|
|
8
|
+
|
|
7
9
|
from kele.syntax import Rule
|
|
8
|
-
from collections.abc import Callable
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
@runtime_checkable
|
|
@@ -24,28 +25,3 @@ class RuleSelectionStrategy(Protocol):
|
|
|
24
25
|
class Feedback:
|
|
25
26
|
"""一次选择后的可选反馈信息;字段都可缺省,策略按需使用。"""
|
|
26
27
|
rule: Rule | None = None # 这次反馈关联到的规则;hack: 后面增加其他的相关信息,可能有facts等等
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
rule_strategy_registry: dict[str, type[RuleSelectionStrategy]] = {}
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def register_strategy(name: str) -> Callable[[type[RuleSelectionStrategy]], type[RuleSelectionStrategy]]:
|
|
33
|
-
"""装饰器:将策略类注册到全局表中"""
|
|
34
|
-
def decorator(cls: type[RuleSelectionStrategy]) -> type[RuleSelectionStrategy]:
|
|
35
|
-
rule_strategy_registry[name] = cls
|
|
36
|
-
return cls
|
|
37
|
-
return decorator
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def get_strategy_class(name: str) -> type[RuleSelectionStrategy]:
|
|
41
|
-
try:
|
|
42
|
-
return rule_strategy_registry[name]
|
|
43
|
-
except KeyError as err:
|
|
44
|
-
raise KeyError(
|
|
45
|
-
f"Rule selection strategy '{name}' is not registered. Available strategies: "
|
|
46
|
-
f"{list(rule_strategy_registry.keys())}"
|
|
47
|
-
) from err
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def _available_strategies() -> list[str]:
|
|
51
|
-
return list(rule_strategy_registry.keys())
|
kele/control/grounding_selector/_selector_utils.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
|
1
1
|
import itertools
|
|
2
|
-
import warnings
|
|
3
2
|
from collections.abc import Sequence
|
|
3
|
+
from functools import singledispatch
|
|
4
4
|
|
|
5
5
|
from kele.knowledge_bases.builtin_base.builtin_concepts import FREEVARANY_CONCEPT
|
|
6
|
-
from kele.syntax import
|
|
7
|
-
Constant, Variable, ATOM_TYPE, Rule)
|
|
8
|
-
from functools import singledispatch
|
|
6
|
+
from kele.syntax import ATOM_TYPE, FACT_TYPE, TERM_TYPE, Assertion, CompoundTerm, Constant, Formula, Rule
|
|
9
7
|
|
|
10
8
|
|
|
11
9
|
@singledispatch
|
|
@@ -39,12 +37,6 @@ def _(fact: Constant) -> tuple[CompoundTerm | Constant, ...]:
|
|
|
39
37
|
return (fact, )
|
|
40
38
|
|
|
41
39
|
|
|
42
|
-
@_unify_all_terms.register(Variable)
|
|
43
|
-
def _(fact: Variable) -> tuple[CompoundTerm | Constant, ...]:
|
|
44
|
-
warnings.warn("Variable should not exist in fact", stacklevel=2)
|
|
45
|
-
return ()
|
|
46
|
-
|
|
47
|
-
|
|
48
40
|
class FREEVARANY(Constant):
|
|
49
41
|
"""
|
|
50
42
|
ANY标签,在free_variables用于占位,暂定为一种特殊的Constant。
|
kele/control/grounding_selector/_term_strategies/README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
## 默认strategy
|
|
2
|
+
|
|
3
|
+
### Exhausted
|
|
4
|
+
穷尽所有可用term的策略实现。
|
|
5
|
+
|
|
6
|
+
## 创建自己的strategy
|
|
7
|
+
1. 创建一个py文件,命名要求为`_<name>_strategy.py`;
|
|
8
|
+
2. 继承TermSelectionStrategy类,并至少声明此Protocol要求的函数;
|
|
9
|
+
3. 从`kele.control`导入并使用`@register.term_selector('<name>')`注册你的策略类,后续即可通过`grounding_term_strategy`使用策略;
|
|
10
|
+
4. 注意调整`grounding_term_strategy`的类型标注(增加Literal的候选值)。
|
|
11
|
+
|
|
12
|
+
示例:
|
|
13
|
+
```python
|
|
14
|
+
from kele.control import register
|
|
15
|
+
from kele.control.grounding_selector._term_strategies.strategy_protocol import TermSelectionStrategy
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@register.term_selector("MyTermStrategy")
|
|
19
|
+
class MyTermStrategy:
|
|
20
|
+
def __init__(self) -> None:
|
|
21
|
+
self._terms = []
|
|
22
|
+
|
|
23
|
+
def add_terms(self, terms):
|
|
24
|
+
self._terms.extend(terms)
|
|
25
|
+
|
|
26
|
+
def reset(self) -> None:
|
|
27
|
+
self._terms = []
|
|
28
|
+
|
|
29
|
+
def select_next(self, rule):
|
|
30
|
+
return list(self._terms)
|
|
31
|
+
|
|
32
|
+
def on_feedback(self, feedback):
|
|
33
|
+
return None
|
|
34
|
+
```
|
kele/control/grounding_selector/_term_strategies/__init__.py
CHANGED
|
@@ -3,7 +3,7 @@ import importlib
|
|
|
3
3
|
import logging
|
|
4
4
|
import pathlib
|
|
5
5
|
|
|
6
|
-
from .
|
|
6
|
+
from kele.control.registry import get_term_strategy_class
|
|
7
7
|
|
|
8
8
|
current_dir = pathlib.Path(__file__).resolve().parent
|
|
9
9
|
package_name = __package__ or current_dir.name
|
|
@@ -21,4 +21,4 @@ for filename in current_dir.iterdir():
|
|
|
21
21
|
logger.exception('Failed to import %s', module_name)
|
|
22
22
|
continue
|
|
23
23
|
|
|
24
|
-
__all__ = ["
|
|
24
|
+
__all__ = ["get_term_strategy_class"]
|
kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
2
|
from collections.abc import Sequence
|
|
3
3
|
|
|
4
|
-
from .
|
|
5
|
-
from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION,
|
|
4
|
+
from kele.control.registry import register
|
|
5
|
+
from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, TERM_TYPE, Rule
|
|
6
6
|
|
|
7
|
+
from .strategy_protocol import Feedback, TermSelectionStrategy
|
|
7
8
|
|
|
8
|
-
|
|
9
|
+
|
|
10
|
+
@register.term_selector('Exhausted')
|
|
9
11
|
class ExhuastedStrategy(TermSelectionStrategy):
|
|
10
12
|
"""
|
|
11
13
|
每次选择剩余的所有terms
|
kele/control/grounding_selector/_term_strategies/strategy_protocol.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
from dataclasses import dataclass
|
|
3
|
-
from typing import Protocol, runtime_checkable
|
|
4
|
+
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
|
4
5
|
|
|
5
6
|
if TYPE_CHECKING:
|
|
6
7
|
from collections.abc import Sequence
|
|
7
|
-
|
|
8
|
-
from
|
|
8
|
+
|
|
9
|
+
from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, TERM_TYPE, Rule
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
@runtime_checkable
|
|
@@ -23,28 +24,3 @@ class TermSelectionStrategy(Protocol):
|
|
|
23
24
|
@dataclass
|
|
24
25
|
class Feedback:
|
|
25
26
|
"""一次选择后的可选反馈信息;字段都可缺省,策略按需使用。"""
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
term_strategy_registry: dict[str, type[TermSelectionStrategy]] = {}
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def register_strategy(name: str) -> Callable[[type[TermSelectionStrategy]], type[TermSelectionStrategy]]:
|
|
32
|
-
"""装饰器:将策略类注册到全局表中"""
|
|
33
|
-
def decorator(cls: type[TermSelectionStrategy]) -> type[TermSelectionStrategy]:
|
|
34
|
-
term_strategy_registry[name] = cls
|
|
35
|
-
return cls
|
|
36
|
-
return decorator
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def get_strategy_class(name: str) -> type[TermSelectionStrategy]:
|
|
40
|
-
try:
|
|
41
|
-
return term_strategy_registry[name]
|
|
42
|
-
except KeyError as err:
|
|
43
|
-
raise KeyError(
|
|
44
|
-
f"Term selection strategy '{name}' is not registered. Available strategies: "
|
|
45
|
-
f"{list(term_strategy_registry.keys())}"
|
|
46
|
-
) from err
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
def _available_strategies() -> list[str]:
|
|
50
|
-
return list(term_strategy_registry.keys())
|
kele/control/grounding_selector/rule_selector.py
CHANGED
|
@@ -2,13 +2,16 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
4
|
from typing import TYPE_CHECKING
|
|
5
|
-
|
|
5
|
+
|
|
6
|
+
from ._rule_strategies import get_rule_strategy_class
|
|
6
7
|
|
|
7
8
|
if TYPE_CHECKING:
|
|
8
|
-
from ._rule_strategies.strategy_protocol import RuleSelectionStrategy, Feedback
|
|
9
9
|
from collections.abc import Sequence
|
|
10
|
+
|
|
10
11
|
from kele.syntax import Rule, _QuestionRule
|
|
11
12
|
|
|
13
|
+
from ._rule_strategies.strategy_protocol import Feedback, RuleSelectionStrategy
|
|
14
|
+
|
|
12
15
|
|
|
13
16
|
class GroundingRuleSelector:
|
|
14
17
|
"""
|
|
@@ -17,7 +20,7 @@ class GroundingRuleSelector:
|
|
|
17
20
|
- 允许更新规则集合
|
|
18
21
|
"""
|
|
19
22
|
def __init__(self, strategy: str = "sequential_cyclic", question_rule_interval: int = -1) -> None:
|
|
20
|
-
strategy_cls =
|
|
23
|
+
strategy_cls = get_rule_strategy_class(strategy)
|
|
21
24
|
self._strategy = strategy_cls()
|
|
22
25
|
self._normal_rules: list[Rule] | None = None
|
|
23
26
|
|
|
@@ -33,7 +36,7 @@ class GroundingRuleSelector:
|
|
|
33
36
|
def next_rules(self) -> Sequence[Rule]:
|
|
34
37
|
"""选择一定数量的规则用于grounding,在一定轮次后会查看一次question是否被解决
|
|
35
38
|
:raises ValueError: 当 question_rule_interval 小于1且不为-1时抛出。
|
|
36
|
-
"""
|
|
39
|
+
"""
|
|
37
40
|
if self._at_fixpoint:
|
|
38
41
|
if self._question_rules:
|
|
39
42
|
self.used_rule_cnt = 0
|
|
@@ -60,6 +63,12 @@ class GroundingRuleSelector:
|
|
|
60
63
|
|
|
61
64
|
return rules
|
|
62
65
|
|
|
66
|
+
def active_rules_count(self) -> int | None:
|
|
67
|
+
"""返回当前候选规则数量,若未初始化则返回 None。"""
|
|
68
|
+
if self._normal_rules is None:
|
|
69
|
+
return None
|
|
70
|
+
return len(self._normal_rules)
|
|
71
|
+
|
|
63
72
|
def reset(self) -> None:
|
|
64
73
|
"""重置选择器"""
|
|
65
74
|
self._strategy.reset()
|
|
@@ -81,7 +90,7 @@ class GroundingRuleSelector:
|
|
|
81
90
|
"""
|
|
82
91
|
更新规则集合,并同步给当前策略
|
|
83
92
|
:raise: ValueError: 没有可选rules时报错
|
|
84
|
-
"""
|
|
93
|
+
"""
|
|
85
94
|
if not normal_rules:
|
|
86
95
|
raise ValueError("rules cannot be empty")
|
|
87
96
|
|