kele 0.0.1a2__cp314-cp314-win_amd64.whl → 0.0.1b1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kele/_version.py +1 -1
- kele/config.py +6 -1
- kele/control/builtin_hooks.py +22 -16
- kele/control/callback.py +33 -2
- kele/control/grounding_selector/_rule_strategies/_scc_sort_strategy.py +110 -0
- kele/egg_equiv.pyd +0 -0
- kele/executer/executing.py +1 -1
- kele/grounder/grounded_rule_ds/_nodes/_assertion.py +224 -52
- kele/grounder/grounded_rule_ds/_nodes/_term.py +1 -1
- kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +23 -12
- kele/grounder/grounded_rule_ds/grounded_class.py +446 -140
- kele/grounder/grounded_rule_ds/rule_check.py +8 -8
- kele/knowledge_bases/builtin_base/builtin_operators.py +23 -1
- kele/main.py +3 -3
- kele/syntax/base_classes.py +65 -31
- kele/utils.py +60 -0
- {kele-0.0.1a2.dist-info → kele-0.0.1b1.dist-info}/METADATA +11 -2
- {kele-0.0.1a2.dist-info → kele-0.0.1b1.dist-info}/RECORD +21 -19
- {kele-0.0.1a2.dist-info → kele-0.0.1b1.dist-info}/WHEEL +0 -0
- {kele-0.0.1a2.dist-info → kele-0.0.1b1.dist-info}/licenses/LICENSE +0 -0
- {kele-0.0.1a2.dist-info → kele-0.0.1b1.dist-info}/licenses/licensecheck.json +0 -0
kele/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
version = "0.0.
|
|
1
|
+
version = "0.0.1b1"
|
kele/config.py
CHANGED
|
@@ -50,7 +50,12 @@ class InferenceStrategyConfig:
|
|
|
50
50
|
select_rules_num: int | Literal[-1] = -1 # Number of rules to select.
|
|
51
51
|
select_facts_num: int | Literal[-1] = -1 # Number of facts to select; -1 means all facts.
|
|
52
52
|
# premise_selection_strategy: Literal[''] = '' # Premise selection algorithm. TODO: Unused.
|
|
53
|
-
grounding_rule_strategy: Literal[
|
|
53
|
+
grounding_rule_strategy: Literal[
|
|
54
|
+
'SequentialCyclic',
|
|
55
|
+
'SequentialCyclicWithPriority',
|
|
56
|
+
'SccSort',
|
|
57
|
+
'ReverseSccSort'
|
|
58
|
+
] = "SequentialCyclic" # Rule selection strategy in grounding.
|
|
54
59
|
# executing_sort_strategy: Literal[''] = '' # Execution order strategy. TODO: Unused.
|
|
55
60
|
grounding_term_strategy: Literal['Exhausted'] = "Exhausted" # Term selection strategy in grounding.
|
|
56
61
|
question_rule_interval: int = 1 # Insert a question rule every N rules; -1 uses total rule count as the interval.
|
kele/control/builtin_hooks.py
CHANGED
|
@@ -1,24 +1,30 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
from typing import TYPE_CHECKING
|
|
4
|
+
from typing import TYPE_CHECKING, Protocol
|
|
5
|
+
|
|
6
|
+
from kele.utils import format_mapping
|
|
5
7
|
|
|
6
8
|
if TYPE_CHECKING:
|
|
7
9
|
from collections.abc import Callable, Mapping
|
|
8
10
|
|
|
9
|
-
from kele.grounder import GroundedRule
|
|
10
11
|
from kele.grounder.grounded_rule_ds._nodes import _AssertionNode
|
|
11
|
-
from kele.syntax import CompoundTerm, Constant, Rule, Variable
|
|
12
|
+
from kele.syntax import CompoundTerm, Constant, Rule, Variable, Assertion
|
|
12
13
|
|
|
13
14
|
logger = logging.getLogger(__name__)
|
|
14
15
|
|
|
15
16
|
|
|
17
|
+
class AssertionCheckHookMatch(Protocol):
    """
    Call signature expected for ``on_match`` in ``register_assertion_check_hook``.

    The assertion check hook invokes the callback positionally with the rule, the checked
    assertion and the variable binding, and passes ``result`` as a keyword argument.
    """

    def __call__(
        self,
        rule: Rule,
        assertion: Assertion,
        combination: dict[Variable, Constant | CompoundTerm],
        *,
        result: bool,
    ) -> None:
        """
        Handle one assertion check that passed the hook's filters.

        :param rule: The rule whose assertion was checked.
        :param assertion: The checked assertion (the hook passes ``assertion.content`` of the node it inspected).
        :param combination: Variable-to-term binding the assertion was checked under.
        :param result: Outcome of the check.
        """
|
|
21
|
+
|
|
22
|
+
|
|
16
23
|
def register_assertion_check_hook(
|
|
17
|
-
grounded_rule: GroundedRule,
|
|
18
24
|
*,
|
|
19
25
|
rule_name: str | None = None,
|
|
20
26
|
vars_filter: Mapping[str, str] | None = None,
|
|
21
|
-
on_match:
|
|
27
|
+
on_match: AssertionCheckHookMatch | None = None,
|
|
22
28
|
break_on_match: bool = False,
|
|
23
29
|
) -> None:
|
|
24
30
|
"""
|
|
@@ -33,7 +39,6 @@ def register_assertion_check_hook(
|
|
|
33
39
|
print(rule.name, assertion.content, combination, result)
|
|
34
40
|
|
|
35
41
|
register_assertion_check_hook(
|
|
36
|
-
grounded_rule,
|
|
37
42
|
rule_name="rule_3",
|
|
38
43
|
vars_filter={"p1": "f", "p2": "a", "p13": "b", "p14": "c"},
|
|
39
44
|
on_match=log_match,
|
|
@@ -55,7 +60,7 @@ def register_assertion_check_hook(
|
|
|
55
60
|
combination_str = {str(k): str(v) for k, v in combination.items()}
|
|
56
61
|
if normalized_vars_filter:
|
|
57
62
|
for key, value in normalized_vars_filter.items():
|
|
58
|
-
if combination_str.get(key) != value:
|
|
63
|
+
if combination_str.get(key) != str(value):
|
|
59
64
|
return
|
|
60
65
|
|
|
61
66
|
if on_match is None:
|
|
@@ -63,16 +68,18 @@ def register_assertion_check_hook(
|
|
|
63
68
|
"Assertion check hook: rule=%s content=%s combination=%s result=%s",
|
|
64
69
|
rule,
|
|
65
70
|
assertion.content,
|
|
66
|
-
combination_str,
|
|
71
|
+
format_mapping(combination_str),
|
|
67
72
|
result,
|
|
68
73
|
)
|
|
69
74
|
else:
|
|
70
|
-
on_match(rule, assertion, combination, result)
|
|
75
|
+
on_match(rule, assertion.content, combination, result=result)
|
|
71
76
|
|
|
72
77
|
if break_on_match:
|
|
73
78
|
breakpoint() # noqa: T100
|
|
74
79
|
|
|
75
|
-
|
|
80
|
+
from kele.grounder import GroundedRule # noqa: PLC0415
|
|
81
|
+
|
|
82
|
+
GroundedRule.register_class_hook("assertion_check", _hook)
|
|
76
83
|
|
|
77
84
|
|
|
78
85
|
class BuiltinHookEnabler:
|
|
@@ -80,12 +87,11 @@ class BuiltinHookEnabler:
|
|
|
80
87
|
Enabler for built-in hooks by name.
|
|
81
88
|
|
|
82
89
|
This class does not register hooks by itself. Create an instance and call
|
|
83
|
-
``enable`` to attach a built-in hook to
|
|
90
|
+
``enable`` to attach a built-in hook to every grounded rule instance.
|
|
84
91
|
|
|
85
92
|
Example:
|
|
86
93
|
hooks = BuiltinHookEnabler()
|
|
87
94
|
hooks.enable(
|
|
88
|
-
grounded_rule,
|
|
89
95
|
"assertion_check",
|
|
90
96
|
rule_name="rule_3",
|
|
91
97
|
vars_filter={"p1": "f"},
|
|
@@ -103,17 +109,17 @@ class BuiltinHookEnabler:
|
|
|
103
109
|
"""
|
|
104
110
|
return sorted(self._hooks)
|
|
105
111
|
|
|
106
|
-
def enable(self,
|
|
112
|
+
def enable(self, name: str, **kwargs: object) -> None:
    """
    Enable one built-in hook by name.

    :param name: Name under which the built-in hook is registered.
    :param kwargs: Keyword arguments forwarded unchanged to that hook's enabler
        (e.g. ``rule_name`` / ``vars_filter`` for ``"assertion_check"``).
    :raises KeyError: If ``name`` is not a known built-in hook.
    """
    if name in self._hooks:
        self._hooks[name](**kwargs)
        return
    raise KeyError(f"Unknown built-in hook: {name}")
|
|
113
119
|
|
|
114
|
-
def enable_many(self,
|
|
120
|
+
def enable_many(self, names: list[str], **kwargs: object) -> None:
    """
    Enable several built-in hooks by name, in order.

    All names are validated before any hook is enabled, so an unknown name raises
    ``KeyError`` without leaving the preceding hooks half-enabled. This matters because
    enabling a hook twice may register it twice, so retrying after a partial failure
    would otherwise create duplicates.

    :param names: Names of the built-in hooks to enable (any iterable of names is accepted).
    :param kwargs: Keyword arguments forwarded unchanged to every hook's enabler.
    :raises KeyError: If any name is not a known built-in hook.
    """
    # Materialise once: the names are iterated twice (validation, then enabling).
    hook_names = list(names)
    unknown = [hook_name for hook_name in hook_names if hook_name not in self._hooks]
    if unknown:
        # Same message format as ``enable`` uses, reporting the first unknown name.
        raise KeyError(f"Unknown built-in hook: {unknown[0]}")
    for hook_name in hook_names:
        self.enable(hook_name, **kwargs)
|
kele/control/callback.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from collections import defaultdict
|
|
4
|
-
from
|
|
4
|
+
from threading import RLock
|
|
5
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
5
6
|
|
|
6
7
|
if TYPE_CHECKING:
|
|
7
8
|
from collections.abc import Callable
|
|
@@ -18,9 +19,36 @@ class HookMixin:
|
|
|
18
19
|
|
|
19
20
|
:ivar _hooks: 事件名称到钩子函数列表的映射。
|
|
20
21
|
"""
|
|
22
|
+
_class_hooks: ClassVar[dict[str, list[Callable[..., None]]]]
|
|
23
|
+
_hooks_lock: ClassVar[RLock] = RLock()
|
|
24
|
+
|
|
25
|
+
def __init_subclass__(cls, **kwargs: object) -> None:
    """
    Give every subclass its own ``_class_hooks`` mapping, seeded with a copy of the parent's.

    Each event's hook list is copied. Registering class hooks on the subclass later
    therefore does not affect the parent. Hooks registered on the parent after this
    point are not inherited.
    """
    super().__init_subclass__(**kwargs)
    # ``_class_hooks`` has not been assigned on ``cls`` yet, so this resolves to the
    # nearest ancestor's mapping (or one defined in the subclass body itself).
    inherited = getattr(cls, "_class_hooks", None) or {}
    own_hooks: dict[str, list[Callable[..., None]]] = defaultdict(list)
    own_hooks.update((event_name, list(hooks)) for event_name, hooks in inherited.items())
    cls._class_hooks = own_hooks
|
|
33
|
+
|
|
21
34
|
def __init__(self) -> None:
|
|
22
35
|
self._hooks: dict[str, list[Callable[..., None]]] = defaultdict(list)
|
|
23
36
|
|
|
37
|
+
@classmethod
def register_class_hook(cls, event_name: str, hook_fn: Callable[..., None], *, unique: bool = True) -> None:
    """
    Register ``hook_fn`` at class level so it runs whenever any instance of ``cls`` triggers ``event_name``.

    Class-level hooks are stored in ``cls._class_hooks`` and run before per-instance hooks.
    Registration happens under the shared ``_hooks_lock``.

    :param event_name: Name of the event to listen for.
    :param hook_fn: Callable invoked with whatever arguments the event is triggered with.
    :param unique: When true, skip registration if this exact function object (compared by
        identity) is already registered for ``event_name``.
    """
    with cls._hooks_lock:
        registered = cls._class_hooks[event_name]
        duplicate = unique and any(existing is hook_fn for existing in registered)
        if not duplicate:
            registered.append(hook_fn)
|
|
51
|
+
|
|
24
52
|
def register_hook(self, event_name: str, hook_fn: Callable[..., None]) -> None:
|
|
25
53
|
"""
|
|
26
54
|
为指定事件注册钩子函数。
|
|
@@ -38,7 +66,10 @@ class HookMixin:
|
|
|
38
66
|
:param args: 传递给钩子的所有位置参数。
|
|
39
67
|
:param kwargs: 传递给钩子的所有关键字参数。
|
|
40
68
|
"""
|
|
41
|
-
|
|
69
|
+
with self._hooks_lock:
|
|
70
|
+
class_hooks = list(self.__class__._class_hooks.get(event_name, [])) # noqa: SLF001
|
|
71
|
+
instance_hooks = list(self._hooks.get(event_name, []))
|
|
72
|
+
for hook in class_hooks + instance_hooks:
|
|
42
73
|
hook(*args, **kwargs)
|
|
43
74
|
|
|
44
75
|
def run_hooks(self, event_name: str, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import typing
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
import networkx as nx
|
|
6
|
+
|
|
7
|
+
from kele.control.registry import register
|
|
8
|
+
from kele.syntax.base_classes import FACT_TYPE, Assertion, CompoundTerm, Constant, Formula, Operator, Rule, Variable
|
|
9
|
+
|
|
10
|
+
from .strategy_protocol import Feedback, RuleSelectionStrategy
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SccSortStrategyBase(RuleSelectionStrategy):
|
|
14
|
+
_reverse_order: bool
|
|
15
|
+
|
|
16
|
+
def __init__(self) -> None:
|
|
17
|
+
warnings.warn("SccSort and ReverseSccSort are not stable features. Use them at your own risk.", stacklevel=3)
|
|
18
|
+
self._rules: list[Rule]
|
|
19
|
+
self._graph: nx.DiGraph
|
|
20
|
+
"""Under current implementation, this is a bipartite graph. There may
|
|
21
|
+
be further performance optimizations with this property?"""
|
|
22
|
+
self._order: list[Rule]
|
|
23
|
+
self._stateful_gen: typing.Generator[Rule]
|
|
24
|
+
|
|
25
|
+
@classmethod
|
|
26
|
+
def _all_operators_in_fact_type(cls, f: FACT_TYPE | CompoundTerm | Constant | Variable | None) -> set[Operator]:
|
|
27
|
+
match f:
|
|
28
|
+
case CompoundTerm(operator=operator, arguments=arguments):
|
|
29
|
+
s = set()
|
|
30
|
+
s.add(operator)
|
|
31
|
+
for argument in arguments:
|
|
32
|
+
s |= cls._all_operators_in_fact_type(argument)
|
|
33
|
+
return s
|
|
34
|
+
case Assertion(lhs=lhs, rhs=rhs):
|
|
35
|
+
return cls._all_operators_in_fact_type(lhs) | cls._all_operators_in_fact_type(rhs)
|
|
36
|
+
case Formula(formula_left=l, formula_right=r):
|
|
37
|
+
return cls._all_operators_in_fact_type(l) | cls._all_operators_in_fact_type(r)
|
|
38
|
+
case _:
|
|
39
|
+
return set()
|
|
40
|
+
|
|
41
|
+
def _rules_to_graph(self, rules: list[Rule]) -> nx.DiGraph:
|
|
42
|
+
graph = nx.DiGraph()
|
|
43
|
+
nodes_left = []
|
|
44
|
+
for rule in rules:
|
|
45
|
+
graph.add_node(rule, type_='rule')
|
|
46
|
+
nodes_left.append(rule)
|
|
47
|
+
heads = self._all_operators_in_fact_type(rule.head)
|
|
48
|
+
bodies = self._all_operators_in_fact_type(rule.body)
|
|
49
|
+
|
|
50
|
+
for op in bodies:
|
|
51
|
+
graph.add_edge(op, rule)
|
|
52
|
+
for op in heads:
|
|
53
|
+
graph.add_edge(rule, op)
|
|
54
|
+
return graph
|
|
55
|
+
|
|
56
|
+
@staticmethod
|
|
57
|
+
def _compile_order(graph: nx.DiGraph) -> typing.Generator[Rule]:
|
|
58
|
+
def _order_within_scc(subgraph: nx.DiGraph) -> typing.Generator[Rule]:
|
|
59
|
+
temp_g = subgraph.copy()
|
|
60
|
+
while temp_g.nodes:
|
|
61
|
+
out_d = dict(temp_g.out_degree)
|
|
62
|
+
in_d = dict(temp_g.in_degree)
|
|
63
|
+
n = max(temp_g.nodes, key=lambda x: out_d[x] - in_d[x])
|
|
64
|
+
yield n
|
|
65
|
+
temp_g.remove_node(n)
|
|
66
|
+
|
|
67
|
+
scc_dag = nx.condensation(graph)
|
|
68
|
+
scc_order = nx.topological_sort(scc_dag)
|
|
69
|
+
for scc_id in scc_order:
|
|
70
|
+
members: frozenset[Rule | Operator] = scc_dag.nodes[scc_id]['members']
|
|
71
|
+
subgraph = typing.cast("nx.DiGraph", graph.subgraph(members))
|
|
72
|
+
for m in _order_within_scc(subgraph):
|
|
73
|
+
node = graph.nodes[m]
|
|
74
|
+
# This "isinstance" clause is useless, just to satisfy the LSP
|
|
75
|
+
if node.get('type_', None) == 'rule' and isinstance(m, Rule):
|
|
76
|
+
yield m
|
|
77
|
+
|
|
78
|
+
def _select_next_generator(self) -> typing.Generator[Rule]:
|
|
79
|
+
yield from itertools.cycle(self._order)
|
|
80
|
+
|
|
81
|
+
def set_rules(self, rules: typing.Sequence[Rule]) -> None:
|
|
82
|
+
self._rules = list(rules)
|
|
83
|
+
if not self._rules:
|
|
84
|
+
raise ValueError('rules cannot be empty')
|
|
85
|
+
self._graph = self._rules_to_graph(self._rules)
|
|
86
|
+
self._order = list(self._compile_order(self._graph))
|
|
87
|
+
if self._reverse_order:
|
|
88
|
+
self._order.reverse()
|
|
89
|
+
self._stateful_gen = self._select_next_generator()
|
|
90
|
+
|
|
91
|
+
def reset(self) -> None:
|
|
92
|
+
self._rules = []
|
|
93
|
+
self._graph = nx.DiGraph()
|
|
94
|
+
self._order = []
|
|
95
|
+
|
|
96
|
+
def select_next(self) -> typing.Sequence[Rule]:
|
|
97
|
+
return [next(self._stateful_gen)]
|
|
98
|
+
|
|
99
|
+
def on_feedback(self, feedback: Feedback) -> None: # noqa: PLR6301
|
|
100
|
+
return
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@register.rule_selector('SccSort')
class SccSortStrategy(SccSortStrategyBase):
    """
    ``SccSort`` rule selector: rules are cycled in the compiled SCC order as-is.

    Components earlier in the topological order of the rule/operator condensation come first.
    """

    _reverse_order: bool = False
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@register.rule_selector('ReverseSccSort')
class ReverseSccSortStrategy(SccSortStrategyBase):
    """
    ``ReverseSccSort`` rule selector: rules are cycled in the compiled SCC order reversed.

    Components later in the topological order of the rule/operator condensation come first.
    """

    _reverse_order: bool = True
|
kele/egg_equiv.pyd
CHANGED
|
Binary file
|
kele/executer/executing.py
CHANGED
|
@@ -8,7 +8,7 @@ from kele.equality import Equivalence
|
|
|
8
8
|
from kele.grounder import GroundedRule
|
|
9
9
|
from kele.knowledge_bases import FactBase
|
|
10
10
|
from kele.syntax import FACT_TYPE, CompoundTerm, Constant, Question, SankuManagementSystem, Variable, _QuestionRule
|
|
11
|
-
from kele.
|
|
11
|
+
from kele.utils import summarize_items
|
|
12
12
|
|
|
13
13
|
logger = logging.getLogger(__name__)
|
|
14
14
|
|