kele 0.0.1a1__cp313-cp313-win32.whl → 0.0.1a2__cp313-cp313-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. kele/__init__.py +8 -6
  2. kele/_utils.py +23 -0
  3. kele/_version.py +1 -1
  4. kele/config.py +52 -14
  5. kele/control/__init__.py +9 -4
  6. kele/control/builtin_hooks.py +119 -0
  7. kele/control/callback.py +14 -4
  8. kele/control/grounding_selector/__init__.py +7 -1
  9. kele/control/grounding_selector/_rule_strategies/README.md +25 -1
  10. kele/control/grounding_selector/_rule_strategies/__init__.py +2 -2
  11. kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py +5 -3
  12. kele/control/grounding_selector/_rule_strategies/strategy_protocol.py +3 -27
  13. kele/control/grounding_selector/_selector_utils.py +2 -10
  14. kele/control/grounding_selector/_term_strategies/README.md +34 -0
  15. kele/control/grounding_selector/_term_strategies/__init__.py +2 -2
  16. kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py +5 -3
  17. kele/control/grounding_selector/_term_strategies/strategy_protocol.py +4 -28
  18. kele/control/grounding_selector/rule_selector.py +14 -5
  19. kele/control/grounding_selector/term_selector.py +16 -6
  20. kele/control/infer_path.py +9 -8
  21. kele/control/metrics.py +7 -8
  22. kele/control/registry.py +112 -0
  23. kele/control/status.py +190 -49
  24. kele/egg_equiv.pyd +0 -0
  25. kele/egg_equiv.pyi +3 -3
  26. kele/equality/_equiv_elem.py +2 -1
  27. kele/equality/_utils.py +4 -2
  28. kele/equality/equivalence.py +8 -7
  29. kele/executer/executing.py +40 -24
  30. kele/grounder/__init__.py +2 -7
  31. kele/grounder/grounded_rule_ds/__init__.py +2 -2
  32. kele/grounder/grounded_rule_ds/_nodes/__init__.py +4 -4
  33. kele/grounder/grounded_rule_ds/_nodes/_assertion.py +27 -14
  34. kele/grounder/grounded_rule_ds/_nodes/_conn.py +4 -2
  35. kele/grounder/grounded_rule_ds/_nodes/_op.py +6 -3
  36. kele/grounder/grounded_rule_ds/_nodes/_root.py +4 -4
  37. kele/grounder/grounded_rule_ds/_nodes/_rule.py +5 -5
  38. kele/grounder/grounded_rule_ds/_nodes/_term.py +11 -10
  39. kele/grounder/grounded_rule_ds/_nodes/_tftable.py +1 -0
  40. kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +62 -58
  41. kele/grounder/grounded_rule_ds/_nodes/_typing_polars.py +1 -1
  42. kele/grounder/grounded_rule_ds/grounded_class.py +60 -17
  43. kele/grounder/grounded_rule_ds/grounded_ds_utils.py +3 -4
  44. kele/grounder/grounded_rule_ds/rule_check.py +6 -5
  45. kele/grounder/grounding.py +38 -36
  46. kele/knowledge_bases/__init__.py +1 -1
  47. kele/knowledge_bases/builtin_base/builtin_facts.py +3 -1
  48. kele/knowledge_bases/builtin_base/builtin_operators.py +1 -1
  49. kele/knowledge_bases/builtin_base/builtin_rules.py +1 -0
  50. kele/knowledge_bases/fact_base.py +37 -16
  51. kele/knowledge_bases/ontology_base.py +1 -1
  52. kele/knowledge_bases/rule_base.py +53 -32
  53. kele/main.py +55 -37
  54. kele/syntax/__init__.py +12 -12
  55. kele/syntax/_cnf_converter.py +2 -2
  56. kele/syntax/_sat_solver.py +3 -3
  57. kele/syntax/base_classes.py +92 -20
  58. kele/syntax/dnf_converter.py +10 -5
  59. kele/syntax/external.py +2 -1
  60. kele/syntax/sub_concept.py +5 -4
  61. kele/syntax/syntacticsugar.py +3 -4
  62. {kele-0.0.1a1.dist-info → kele-0.0.1a2.dist-info}/METADATA +13 -3
  63. kele-0.0.1a2.dist-info/RECORD +78 -0
  64. kele-0.0.1a1.dist-info/RECORD +0 -74
  65. {kele-0.0.1a1.dist-info → kele-0.0.1a2.dist-info}/WHEEL +0 -0
  66. {kele-0.0.1a1.dist-info → kele-0.0.1a2.dist-info}/licenses/LICENSE +0 -0
  67. {kele-0.0.1a1.dist-info → kele-0.0.1a2.dist-info}/licenses/licensecheck.json +0 -0
kele/__init__.py CHANGED
@@ -1,15 +1,16 @@
1
1
  """支持断言逻辑的推理引擎"""
2
- from kele.main import EngineRunResult, InferenceEngine, QueryStructure
3
2
  from kele.config import (
4
3
  Config,
5
- RunControlConfig,
6
- InferenceStrategyConfig,
7
- GrounderConfig,
8
4
  ExecutorConfig,
9
- PathConfig,
5
+ GrounderConfig,
6
+ InferenceStrategyConfig,
10
7
  KBConfig,
8
+ PathConfig,
9
+ RunControlConfig,
11
10
  )
12
- from kele.syntax.base_classes import Constant, Concept, Operator, Variable, CompoundTerm, Assertion, Formula, Rule
11
+ from kele.control import register
12
+ from kele.main import EngineRunResult, InferenceEngine, QueryStructure
13
+ from kele.syntax.base_classes import Assertion, CompoundTerm, Concept, Constant, Formula, Operator, Rule, Variable
13
14
 
14
15
  try:
15
16
  from ._version import version as __version__
@@ -35,4 +36,5 @@ __all__ = [
35
36
  'Rule',
36
37
  'RunControlConfig',
37
38
  'Variable',
39
+ 'register',
38
40
  ]
kele/_utils.py ADDED
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, TypeVar
4
+
5
+ if TYPE_CHECKING:
6
+ from collections.abc import Callable, Sequence
7
+
8
+ T = TypeVar("T")
9
+
10
+
11
+ def summarize_items(
12
+ items: Sequence[T],
13
+ *,
14
+ sample_size: int = 5,
15
+ formatter: Callable[[T], str] = str,
16
+ ) -> dict[str, Any]:
17
+ """Summarize a sequence for debug logging."""
18
+ total = len(items)
19
+ sample = [formatter(item) for item in items[:sample_size]]
20
+ return {
21
+ "rows": total,
22
+ "sample": sample,
23
+ }
kele/_version.py CHANGED
@@ -1 +1 @@
1
- version = "0.0.1a1"
1
+ version = "0.0.1a2"
kele/config.py CHANGED
@@ -1,17 +1,17 @@
1
1
  # ruff: noqa: ERA001 # Commented parameters are either not implemented yet or depend on unfinished upstream/downstream modules.
2
- import warnings
3
- from typing import Any, cast, Literal
4
-
2
+ import json
5
3
  import logging
6
- from dataclasses import dataclass, fields, field
7
- from datetime import datetime, UTC
4
+ import warnings
5
+ from dataclasses import dataclass, field, fields
6
+ from datetime import UTC, datetime
8
7
  from pathlib import Path
9
- import yaml
10
- import json
11
- import tyro
12
- from tyro.conf import OmitArgPrefixes
8
+ from typing import Any, Literal, cast
9
+
13
10
  import dacite
11
+ import tyro
12
+ import yaml
14
13
  from dacite.config import Config as daConfig
14
+ from tyro.conf import OmitArgPrefixes
15
15
 
16
16
  RESULT_LEVEL = 25
17
17
  logging.RESULT = RESULT_LEVEL # type: ignore[attr-defined] # This fails mypy; setattr fails ruff.
@@ -59,12 +59,50 @@ class InferenceStrategyConfig:
59
59
  @dataclass
60
60
  class GrounderConfig:
61
61
  """Grounder-related parameters."""
62
- grounding_rules_num_every_step: int | Literal[-1] = -1
63
- grounding_facts_num_for_each_rule: int | Literal[-1] = -1
62
+ grounding_rules_per_step: int | Literal[-1] = -1
63
+ grounding_facts_per_rule: int | Literal[-1] = -1
64
+ grounding_rules_num_every_step: int | Literal[-1] | None = None
65
+ grounding_facts_num_for_each_rule: int | Literal[-1] | None = None
64
66
  allow_unify_with_nested_term: bool = True # Allow Variables to be replaced by CompoundTerms.
65
67
  conceptual_fuzzy_unification: bool = True # Use strict concept constraints to accelerate inference.
66
68
  # This depends on correct concept subsumption and full constant.belong_concepts settings; beginners should use loose matching.
67
69
 
70
+ def __post_init__(self) -> None:
71
+ if self.grounding_rules_num_every_step is not None:
72
+ warnings.warn(
73
+ "grounding_rules_num_every_step is deprecated; use grounding_rules_per_step instead.",
74
+ DeprecationWarning,
75
+ stacklevel=2,
76
+ )
77
+ if self.grounding_rules_per_step == -1:
78
+ self.grounding_rules_per_step = self.grounding_rules_num_every_step
79
+ elif self.grounding_rules_per_step != self.grounding_rules_num_every_step:
80
+ warnings.warn(
81
+ "Both grounding_rules_per_step and grounding_rules_num_every_step are set; using grounding_rules_per_step.",
82
+ DeprecationWarning,
83
+ stacklevel=2,
84
+ )
85
+
86
+ if self.grounding_facts_num_for_each_rule is not None:
87
+ warnings.warn(
88
+ "grounding_facts_num_for_each_rule is deprecated; use grounding_facts_per_rule instead.",
89
+ DeprecationWarning,
90
+ stacklevel=2,
91
+ )
92
+ if self.grounding_facts_per_rule == -1:
93
+ self.grounding_facts_per_rule = self.grounding_facts_num_for_each_rule
94
+ elif self.grounding_facts_per_rule != self.grounding_facts_num_for_each_rule:
95
+ warnings.warn(
96
+ "Both grounding_facts_per_rule and grounding_facts_num_for_each_rule are set; using grounding_facts_per_rule.",
97
+ DeprecationWarning,
98
+ stacklevel=2,
99
+ )
100
+
101
+ if self.grounding_rules_num_every_step is None:
102
+ self.grounding_rules_num_every_step = self.grounding_rules_per_step
103
+ if self.grounding_facts_num_for_each_rule is None:
104
+ self.grounding_facts_num_for_each_rule = self.grounding_facts_per_rule
105
+
68
106
 
69
107
  @dataclass
70
108
  class ExecutorConfig:
@@ -103,7 +141,7 @@ class Config:
103
141
 
104
142
 
105
143
  def _load_config_file(path: str) -> dict[str, Any]:
106
- with open(path, encoding="utf-8") as f:
144
+ with Path(path).open(encoding="utf-8") as f:
107
145
  if path.endswith(('.yaml', '.yml')):
108
146
  data = yaml.safe_load(f)
109
147
  elif path.endswith('.json'):
@@ -117,7 +155,7 @@ def _load_config_file(path: str) -> dict[str, Any]:
117
155
 
118
156
 
119
157
  def _save_config(config: dict[str, Any], path: str) -> None:
120
- with open(path, 'w', encoding='utf8') as f:
158
+ with Path(path).open('w', encoding='utf8') as f:
121
159
  yaml.dump(config, f, sort_keys=False)
122
160
 
123
161
 
@@ -197,7 +235,7 @@ def _build_config(user_config: Config | None = None,
197
235
 
198
236
  :raises: ValueError: If `user_config` is used together with `config_file_path`
199
237
  or when `user_config.config` is set.
200
- """ # noqa: DOC501
238
+ """
201
239
  if user_config and (user_config.config or config_file_path):
202
240
  raise ValueError("default config instance and config file cannot be used together")
203
241
 
kele/control/__init__.py CHANGED
@@ -1,14 +1,17 @@
1
1
  """用于callbacks和推理路径的记录"""
2
- from .callback import HookMixin, Callback, CallbackManager
2
+ from .callback import Callback, CallbackManager, HookMixin
3
+ from .grounding_selector import GroundingRuleSelector
4
+ from .builtin_hooks import BuiltinHookEnabler, register_assertion_check_hook
5
+ from .infer_path import InferencePath
6
+ from .registry import register
3
7
  from .status import (
4
8
  InferenceStatus,
5
- create_main_loop_manager,
6
9
  create_executor_manager,
10
+ create_main_loop_manager,
7
11
  )
8
- from .grounding_selector import GroundingRuleSelector
9
- from .infer_path import InferencePath
10
12
 
11
13
  __all__ = [
14
+ 'BuiltinHookEnabler',
12
15
  'Callback',
13
16
  'CallbackManager',
14
17
  'GroundingRuleSelector',
@@ -17,4 +20,6 @@ __all__ = [
17
20
  'InferenceStatus',
18
21
  'create_executor_manager',
19
22
  'create_main_loop_manager',
23
+ 'register',
24
+ 'register_assertion_check_hook',
20
25
  ]
kele/control/builtin_hooks.py ADDED
@@ -0,0 +1,119 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING
5
+
6
+ if TYPE_CHECKING:
7
+ from collections.abc import Callable, Mapping
8
+
9
+ from kele.grounder import GroundedRule
10
+ from kele.grounder.grounded_rule_ds._nodes import _AssertionNode
11
+ from kele.syntax import CompoundTerm, Constant, Rule, Variable
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def register_assertion_check_hook(
17
+ grounded_rule: GroundedRule,
18
+ *,
19
+ rule_name: str | None = None,
20
+ vars_filter: Mapping[str, str] | None = None,
21
+ on_match: Callable[[Rule, _AssertionNode, dict[Variable, Constant | CompoundTerm], bool], None] | None = None,
22
+ break_on_match: bool = False,
23
+ ) -> None:
24
+ """
25
+ Register a built-in hook that observes assertion check results.
26
+
27
+ This is a user-facing inspection hook: it lets you capture whether a specific
28
+ assertion evaluation (for a given variable binding) is True/False, and optionally
29
+ stop on matches for deep debugging.
30
+
31
+ Example:
32
+ def log_match(rule, assertion, combination, result):
33
+ print(rule.name, assertion.content, combination, result)
34
+
35
+ register_assertion_check_hook(
36
+ grounded_rule,
37
+ rule_name="rule_3",
38
+ vars_filter={"p1": "f", "p2": "a", "p13": "b", "p14": "c"},
39
+ on_match=log_match,
40
+ break_on_match=True,
41
+ )
42
+ """
43
+ normalized_vars_filter = dict(vars_filter) if vars_filter is not None else None
44
+
45
+ def _hook(
46
+ *,
47
+ rule: Rule,
48
+ assertion: _AssertionNode,
49
+ combination: dict[Variable, Constant | CompoundTerm],
50
+ result: bool,
51
+ ) -> None:
52
+ if rule_name and rule.name != rule_name:
53
+ return
54
+
55
+ combination_str = {str(k): str(v) for k, v in combination.items()}
56
+ if normalized_vars_filter:
57
+ for key, value in normalized_vars_filter.items():
58
+ if combination_str.get(key) != value:
59
+ return
60
+
61
+ if on_match is None:
62
+ logger.debug(
63
+ "Assertion check hook: rule=%s content=%s combination=%s result=%s",
64
+ rule,
65
+ assertion.content,
66
+ combination_str,
67
+ result,
68
+ )
69
+ else:
70
+ on_match(rule, assertion, combination, result)
71
+
72
+ if break_on_match:
73
+ breakpoint() # noqa: T100
74
+
75
+ grounded_rule.register_hook("assertion_check", _hook)
76
+
77
+
78
+ class BuiltinHookEnabler:
79
+ """
80
+ Enabler for built-in hooks by name.
81
+
82
+ This class does not register hooks by itself. Create an instance and call
83
+ ``enable`` to attach a built-in hook to a grounded rule.
84
+
85
+ Example:
86
+ hooks = BuiltinHookEnabler()
87
+ hooks.enable(
88
+ grounded_rule,
89
+ "assertion_check",
90
+ rule_name="rule_3",
91
+ vars_filter={"p1": "f"},
92
+ )
93
+ """
94
+
95
+ def __init__(self) -> None:
96
+ self._hooks: dict[str, Callable[..., None]] = {
97
+ "assertion_check": register_assertion_check_hook,
98
+ }
99
+
100
+ def available_hooks(self) -> list[str]:
101
+ """
102
+ Return the names of available built-in hooks.
103
+ """
104
+ return sorted(self._hooks)
105
+
106
+ def enable(self, grounded_rule: GroundedRule, name: str, **kwargs: object) -> None:
107
+ """
108
+ Enable a built-in hook by name.
109
+ """
110
+ if name not in self._hooks:
111
+ raise KeyError(f"Unknown built-in hook: {name}")
112
+ self._hooks[name](grounded_rule, **kwargs)
113
+
114
+ def enable_many(self, grounded_rule: GroundedRule, names: list[str], **kwargs: object) -> None:
115
+ """
116
+ Enable multiple built-in hooks by name.
117
+ """
118
+ for hook_name in names:
119
+ self.enable(grounded_rule, hook_name, **kwargs)
kele/control/callback.py CHANGED
@@ -1,15 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections import defaultdict
4
- from typing import Any, TYPE_CHECKING
5
-
4
+ from typing import TYPE_CHECKING, Any
6
5
 
7
6
  if TYPE_CHECKING:
7
+ from collections.abc import Callable
8
+
8
9
  from kele.equality import Equivalence
9
10
  from kele.grounder import GroundedRule
10
11
  from kele.knowledge_bases import FactBase, RuleBase
11
- from kele.syntax import Question, Rule, FACT_TYPE, Variable, Constant, CompoundTerm
12
- from collections.abc import Callable
12
+ from kele.syntax import FACT_TYPE, CompoundTerm, Constant, Question, Rule, Variable
13
13
 
14
14
 
15
15
  class HookMixin:
@@ -41,6 +41,16 @@ class HookMixin:
41
41
  for hook in self._hooks.get(event_name, []):
42
42
  hook(*args, **kwargs)
43
43
 
44
+ def run_hooks(self, event_name: str, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
45
+ """
46
+ 对外公开的钩子触发入口。
47
+
48
+ :param event_name: 事件名称。
49
+ :param args: 传递给钩子的所有位置参数。
50
+ :param kwargs: 传递给钩子的所有关键字参数。
51
+ """
52
+ self._run_hooks(event_name, *args, **kwargs)
53
+
44
54
 
45
55
  class Callback:
46
56
  """回调接口——在推理各阶段采集指标的Hook"""
kele/control/grounding_selector/__init__.py CHANGED
@@ -1,5 +1,11 @@
1
1
  """grounding相关的选择器"""
2
+ from kele.control.registry import register
3
+
2
4
  from .rule_selector import GroundingRuleSelector
3
5
  from .term_selector import GroundingFlatTermWithWildCardSelector
4
6
 
5
- __all__ = ["GroundingFlatTermWithWildCardSelector", "GroundingRuleSelector"]
7
+ __all__ = [
8
+ "GroundingFlatTermWithWildCardSelector",
9
+ "GroundingRuleSelector",
10
+ "register",
11
+ ]
kele/control/grounding_selector/_rule_strategies/README.md CHANGED
@@ -9,5 +9,29 @@
9
9
  ## 创建自己的strategy
10
10
  1. 创建一个py文件,命名要求为`_<name>_strategy.py`;
11
11
  2. 继承RuleSelectionStrategy类,并至少声明此Protocol要求的函数;
12
- 3. 使用`@register_strategy('<name>')`注册你的策略类,后续即可通过`grounding_rule_strategy`使用策略;
12
+ 3. 从`kele.control`导入并使用`@register.rule_selector('<name>')`注册你的策略类,后续即可通过`grounding_rule_strategy`使用策略;
13
13
  4. 注意调整`grounding_rule_strategy`的类型标注(增加Literal的候选值)。
14
+
15
+ 示例:
16
+ ```python
17
+ from kele.control import register
18
+ from kele.control.grounding_selector._rule_strategies.strategy_protocol import RuleSelectionStrategy
19
+
20
+
21
+ @register.rule_selector("MyRuleStrategy")
22
+ class MyRuleStrategy:
23
+ def __init__(self) -> None:
24
+ self._rules = []
25
+
26
+ def set_rules(self, rules):
27
+ self._rules = list(rules)
28
+
29
+ def reset(self) -> None:
30
+ self._rules = []
31
+
32
+ def select_next(self):
33
+ return self._rules[:1]
34
+
35
+ def on_feedback(self, feedback):
36
+ return None
37
+ ```
kele/control/grounding_selector/_rule_strategies/__init__.py CHANGED
@@ -3,7 +3,7 @@ import importlib
3
3
  import logging
4
4
  import pathlib
5
5
 
6
- from .strategy_protocol import get_strategy_class
6
+ from kele.control.registry import get_rule_strategy_class
7
7
 
8
8
  current_dir = pathlib.Path(__file__).resolve().parent
9
9
  package_name = __package__ or current_dir.name
@@ -21,4 +21,4 @@ for filename in current_dir.iterdir():
21
21
  logger.exception('Failed to import %s', module_name)
22
22
  continue
23
23
 
24
- __all__ = ["get_strategy_class"]
24
+ __all__ = ["get_rule_strategy_class"]
kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py CHANGED
@@ -1,10 +1,12 @@
1
1
  from collections.abc import Sequence
2
2
 
3
- from .strategy_protocol import Feedback, register_strategy, RuleSelectionStrategy
3
+ from kele.control.registry import register
4
4
  from kele.syntax import Rule
5
5
 
6
+ from .strategy_protocol import Feedback, RuleSelectionStrategy
6
7
 
7
- @register_strategy('SequentialCyclic')
8
+
9
+ @register.rule_selector('SequentialCyclic')
8
10
  class SequentialCyclicStrategy(RuleSelectionStrategy):
9
11
  """
10
12
  按顺序循环遍历策略:
@@ -33,7 +35,7 @@ class SequentialCyclicStrategy(RuleSelectionStrategy):
33
35
  return
34
36
 
35
37
 
36
- @register_strategy("SequentialCyclicWithPriority")
38
+ @register.rule_selector("SequentialCyclicWithPriority")
37
39
  class SequentialCyclicWithPriorityStrategy(SequentialCyclicStrategy):
38
40
  """将规则按优先级排序,优先级高的先取"""
39
41
 
kele/control/grounding_selector/_rule_strategies/strategy_protocol.py CHANGED
@@ -1,11 +1,12 @@
1
1
  from __future__ import annotations
2
+
2
3
  from dataclasses import dataclass
3
- from typing import Protocol, runtime_checkable, TYPE_CHECKING
4
+ from typing import TYPE_CHECKING, Protocol, runtime_checkable
4
5
 
5
6
  if TYPE_CHECKING:
6
7
  from collections.abc import Sequence
8
+
7
9
  from kele.syntax import Rule
8
- from collections.abc import Callable
9
10
 
10
11
 
11
12
  @runtime_checkable
@@ -24,28 +25,3 @@ class RuleSelectionStrategy(Protocol):
24
25
  class Feedback:
25
26
  """一次选择后的可选反馈信息;字段都可缺省,策略按需使用。"""
26
27
  rule: Rule | None = None # 这次反馈关联到的规则;hack: 后面增加其他的相关信息,可能有facts等等
27
-
28
-
29
- rule_strategy_registry: dict[str, type[RuleSelectionStrategy]] = {}
30
-
31
-
32
- def register_strategy(name: str) -> Callable[[type[RuleSelectionStrategy]], type[RuleSelectionStrategy]]:
33
- """装饰器:将策略类注册到全局表中"""
34
- def decorator(cls: type[RuleSelectionStrategy]) -> type[RuleSelectionStrategy]:
35
- rule_strategy_registry[name] = cls
36
- return cls
37
- return decorator
38
-
39
-
40
- def get_strategy_class(name: str) -> type[RuleSelectionStrategy]:
41
- try:
42
- return rule_strategy_registry[name]
43
- except KeyError as err:
44
- raise KeyError(
45
- f"Rule selection strategy '{name}' is not registered. Available strategies: "
46
- f"{list(rule_strategy_registry.keys())}"
47
- ) from err
48
-
49
-
50
- def _available_strategies() -> list[str]:
51
- return list(rule_strategy_registry.keys())
kele/control/grounding_selector/_selector_utils.py CHANGED
@@ -1,11 +1,9 @@
1
1
  import itertools
2
- import warnings
3
2
  from collections.abc import Sequence
3
+ from functools import singledispatch
4
4
 
5
5
  from kele.knowledge_bases.builtin_base.builtin_concepts import FREEVARANY_CONCEPT
6
- from kele.syntax import (FACT_TYPE, TERM_TYPE, Assertion, Formula, CompoundTerm,
7
- Constant, Variable, ATOM_TYPE, Rule)
8
- from functools import singledispatch
6
+ from kele.syntax import ATOM_TYPE, FACT_TYPE, TERM_TYPE, Assertion, CompoundTerm, Constant, Formula, Rule
9
7
 
10
8
 
11
9
  @singledispatch
@@ -39,12 +37,6 @@ def _(fact: Constant) -> tuple[CompoundTerm | Constant, ...]:
39
37
  return (fact, )
40
38
 
41
39
 
42
- @_unify_all_terms.register(Variable)
43
- def _(fact: Variable) -> tuple[CompoundTerm | Constant, ...]:
44
- warnings.warn("Variable should not exist in fact", stacklevel=2)
45
- return ()
46
-
47
-
48
40
  class FREEVARANY(Constant):
49
41
  """
50
42
  ANY标签,在free_variables用于占位,暂定为一种特殊的Constant。
kele/control/grounding_selector/_term_strategies/README.md ADDED
@@ -0,0 +1,34 @@
1
+ ## 默认strategy
2
+
3
+ ### Exhausted
4
+ 穷尽所有可用term的策略实现。
5
+
6
+ ## 创建自己的strategy
7
+ 1. 创建一个py文件,命名要求为`_<name>_strategy.py`;
8
+ 2. 继承TermSelectionStrategy类,并至少声明此Protocol要求的函数;
9
+ 3. 从`kele.control`导入并使用`@register.term_selector('<name>')`注册你的策略类,后续即可通过`grounding_term_strategy`使用策略;
10
+ 4. 注意调整`grounding_term_strategy`的类型标注(增加Literal的候选值)。
11
+
12
+ 示例:
13
+ ```python
14
+ from kele.control import register
15
+ from kele.control.grounding_selector._term_strategies.strategy_protocol import TermSelectionStrategy
16
+
17
+
18
+ @register.term_selector("MyTermStrategy")
19
+ class MyTermStrategy:
20
+ def __init__(self) -> None:
21
+ self._terms = []
22
+
23
+ def add_terms(self, terms):
24
+ self._terms.extend(terms)
25
+
26
+ def reset(self) -> None:
27
+ self._terms = []
28
+
29
+ def select_next(self, rule):
30
+ return list(self._terms)
31
+
32
+ def on_feedback(self, feedback):
33
+ return None
34
+ ```
kele/control/grounding_selector/_term_strategies/__init__.py CHANGED
@@ -3,7 +3,7 @@ import importlib
3
3
  import logging
4
4
  import pathlib
5
5
 
6
- from .strategy_protocol import get_strategy_class
6
+ from kele.control.registry import get_term_strategy_class
7
7
 
8
8
  current_dir = pathlib.Path(__file__).resolve().parent
9
9
  package_name = __package__ or current_dir.name
@@ -21,4 +21,4 @@ for filename in current_dir.iterdir():
21
21
  logger.exception('Failed to import %s', module_name)
22
22
  continue
23
23
 
24
- __all__ = ["get_strategy_class"]
24
+ __all__ = ["get_term_strategy_class"]
kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py CHANGED
@@ -1,11 +1,13 @@
1
1
  from collections import defaultdict
2
2
  from collections.abc import Sequence
3
3
 
4
- from .strategy_protocol import Feedback, register_strategy, TermSelectionStrategy
5
- from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, Rule, TERM_TYPE
4
+ from kele.control.registry import register
5
+ from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, TERM_TYPE, Rule
6
6
 
7
+ from .strategy_protocol import Feedback, TermSelectionStrategy
7
8
 
8
- @register_strategy('Exhausted')
9
+
10
+ @register.term_selector('Exhausted')
9
11
  class ExhuastedStrategy(TermSelectionStrategy):
10
12
  """
11
13
  每次选择剩余的所有terms
kele/control/grounding_selector/_term_strategies/strategy_protocol.py CHANGED
@@ -1,11 +1,12 @@
1
1
  from __future__ import annotations
2
+
2
3
  from dataclasses import dataclass
3
- from typing import Protocol, runtime_checkable, TYPE_CHECKING
4
+ from typing import TYPE_CHECKING, Protocol, runtime_checkable
4
5
 
5
6
  if TYPE_CHECKING:
6
7
  from collections.abc import Sequence
7
- from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, Rule, TERM_TYPE
8
- from collections.abc import Callable
8
+
9
+ from kele.syntax import GROUNDED_TYPE_FOR_UNIFICATION, TERM_TYPE, Rule
9
10
 
10
11
 
11
12
  @runtime_checkable
@@ -23,28 +24,3 @@ class TermSelectionStrategy(Protocol):
23
24
  @dataclass
24
25
  class Feedback:
25
26
  """一次选择后的可选反馈信息;字段都可缺省,策略按需使用。"""
26
-
27
-
28
- term_strategy_registry: dict[str, type[TermSelectionStrategy]] = {}
29
-
30
-
31
- def register_strategy(name: str) -> Callable[[type[TermSelectionStrategy]], type[TermSelectionStrategy]]:
32
- """装饰器:将策略类注册到全局表中"""
33
- def decorator(cls: type[TermSelectionStrategy]) -> type[TermSelectionStrategy]:
34
- term_strategy_registry[name] = cls
35
- return cls
36
- return decorator
37
-
38
-
39
- def get_strategy_class(name: str) -> type[TermSelectionStrategy]:
40
- try:
41
- return term_strategy_registry[name]
42
- except KeyError as err:
43
- raise KeyError(
44
- f"Term selection strategy '{name}' is not registered. Available strategies: "
45
- f"{list(term_strategy_registry.keys())}"
46
- ) from err
47
-
48
-
49
- def _available_strategies() -> list[str]:
50
- return list(term_strategy_registry.keys())
kele/control/grounding_selector/rule_selector.py CHANGED
@@ -2,13 +2,16 @@ from __future__ import annotations
2
2
 
3
3
  import warnings
4
4
  from typing import TYPE_CHECKING
5
- from ._rule_strategies import get_strategy_class
5
+
6
+ from ._rule_strategies import get_rule_strategy_class
6
7
 
7
8
  if TYPE_CHECKING:
8
- from ._rule_strategies.strategy_protocol import RuleSelectionStrategy, Feedback
9
9
  from collections.abc import Sequence
10
+
10
11
  from kele.syntax import Rule, _QuestionRule
11
12
 
13
+ from ._rule_strategies.strategy_protocol import Feedback, RuleSelectionStrategy
14
+
12
15
 
13
16
  class GroundingRuleSelector:
14
17
  """
@@ -17,7 +20,7 @@ class GroundingRuleSelector:
17
20
  - 允许更新规则集合
18
21
  """
19
22
  def __init__(self, strategy: str = "sequential_cyclic", question_rule_interval: int = -1) -> None:
20
- strategy_cls = get_strategy_class(strategy)
23
+ strategy_cls = get_rule_strategy_class(strategy)
21
24
  self._strategy = strategy_cls()
22
25
  self._normal_rules: list[Rule] | None = None
23
26
 
@@ -33,7 +36,7 @@ class GroundingRuleSelector:
33
36
  def next_rules(self) -> Sequence[Rule]:
34
37
  """选择一定数量的规则用于grounding,在一定轮次后会查看一次question是否被解决
35
38
  :raises ValueError: 当 question_rule_interval 小于1且不为-1时抛出。
36
- """ # noqa: DOC501
39
+ """
37
40
  if self._at_fixpoint:
38
41
  if self._question_rules:
39
42
  self.used_rule_cnt = 0
@@ -60,6 +63,12 @@ class GroundingRuleSelector:
60
63
 
61
64
  return rules
62
65
 
66
+ def active_rules_count(self) -> int | None:
67
+ """返回当前候选规则数量,若未初始化则返回 None。"""
68
+ if self._normal_rules is None:
69
+ return None
70
+ return len(self._normal_rules)
71
+
63
72
  def reset(self) -> None:
64
73
  """重置选择器"""
65
74
  self._strategy.reset()
@@ -81,7 +90,7 @@ class GroundingRuleSelector:
81
90
  """
82
91
  更新规则集合,并同步给当前策略
83
92
  :raise: ValueError: 没有可选rules时报错
84
- """ # noqa: DOC501
93
+ """
85
94
  if not normal_rules:
86
95
  raise ValueError("rules cannot be empty")
87
96