kele 0.0.1a2__cp314-cp314-win_amd64.whl → 0.0.1b1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kele/_version.py CHANGED
@@ -1 +1 @@
1
- version = "0.0.1a2"
1
+ version = "0.0.1b1"
kele/config.py CHANGED
@@ -50,7 +50,12 @@ class InferenceStrategyConfig:
50
50
  select_rules_num: int | Literal[-1] = -1 # Number of rules to select.
51
51
  select_facts_num: int | Literal[-1] = -1 # Number of facts to select; -1 means all facts.
52
52
  # premise_selection_strategy: Literal[''] = '' # Premise selection algorithm. TODO: Unused.
53
- grounding_rule_strategy: Literal['SequentialCyclic', 'SequentialCyclicWithPriority'] = "SequentialCyclic" # Rule selection strategy in grounding.
53
+ grounding_rule_strategy: Literal[
54
+ 'SequentialCyclic',
55
+ 'SequentialCyclicWithPriority',
56
+ 'SccSort',
57
+ 'ReverseSccSort'
58
+ ] = "SequentialCyclic" # Rule selection strategy in grounding.
54
59
  # executing_sort_strategy: Literal[''] = '' # Execution order strategy. TODO: Unused.
55
60
  grounding_term_strategy: Literal['Exhausted'] = "Exhausted" # Term selection strategy in grounding.
56
61
  question_rule_interval: int = 1 # Insert a question rule every N rules; -1 uses total rule count as the interval.
@@ -1,24 +1,30 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import TYPE_CHECKING
4
+ from typing import TYPE_CHECKING, Protocol
5
+
6
+ from kele.utils import format_mapping
5
7
 
6
8
  if TYPE_CHECKING:
7
9
  from collections.abc import Callable, Mapping
8
10
 
9
- from kele.grounder import GroundedRule
10
11
  from kele.grounder.grounded_rule_ds._nodes import _AssertionNode
11
- from kele.syntax import CompoundTerm, Constant, Rule, Variable
12
+ from kele.syntax import CompoundTerm, Constant, Rule, Variable, Assertion
12
13
 
13
14
  logger = logging.getLogger(__name__)
14
15
 
15
16
 
17
+ class AssertionCheckHookMatch(Protocol):
18
+ """the protocol of output func (on_match in register_assertion_check_hook) in assertion check hook"""
19
+ def __call__(self, rule: Rule, assertion: Assertion, combination: dict[Variable, Constant | CompoundTerm], *, result: bool) \
20
+ -> None: ... # noqa: D102
21
+
22
+
16
23
  def register_assertion_check_hook(
17
- grounded_rule: GroundedRule,
18
24
  *,
19
25
  rule_name: str | None = None,
20
26
  vars_filter: Mapping[str, str] | None = None,
21
- on_match: Callable[[Rule, _AssertionNode, dict[Variable, Constant | CompoundTerm], bool], None] | None = None,
27
+ on_match: AssertionCheckHookMatch | None = None,
22
28
  break_on_match: bool = False,
23
29
  ) -> None:
24
30
  """
@@ -33,7 +39,6 @@ def register_assertion_check_hook(
33
39
  print(rule.name, assertion.content, combination, result)
34
40
 
35
41
  register_assertion_check_hook(
36
- grounded_rule,
37
42
  rule_name="rule_3",
38
43
  vars_filter={"p1": "f", "p2": "a", "p13": "b", "p14": "c"},
39
44
  on_match=log_match,
@@ -55,7 +60,7 @@ def register_assertion_check_hook(
55
60
  combination_str = {str(k): str(v) for k, v in combination.items()}
56
61
  if normalized_vars_filter:
57
62
  for key, value in normalized_vars_filter.items():
58
- if combination_str.get(key) != value:
63
+ if combination_str.get(key) != str(value):
59
64
  return
60
65
 
61
66
  if on_match is None:
@@ -63,16 +68,18 @@ def register_assertion_check_hook(
63
68
  "Assertion check hook: rule=%s content=%s combination=%s result=%s",
64
69
  rule,
65
70
  assertion.content,
66
- combination_str,
71
+ format_mapping(combination_str),
67
72
  result,
68
73
  )
69
74
  else:
70
- on_match(rule, assertion, combination, result)
75
+ on_match(rule, assertion.content, combination, result=result)
71
76
 
72
77
  if break_on_match:
73
78
  breakpoint() # noqa: T100
74
79
 
75
- grounded_rule.register_hook("assertion_check", _hook)
80
+ from kele.grounder import GroundedRule # noqa: PLC0415
81
+
82
+ GroundedRule.register_class_hook("assertion_check", _hook)
76
83
 
77
84
 
78
85
  class BuiltinHookEnabler:
@@ -80,12 +87,11 @@ class BuiltinHookEnabler:
80
87
  Enabler for built-in hooks by name.
81
88
 
82
89
  This class does not register hooks by itself. Create an instance and call
83
- ``enable`` to attach a built-in hook to a grounded rule.
90
+ ``enable`` to attach a built-in hook to every grounded rule instance.
84
91
 
85
92
  Example:
86
93
  hooks = BuiltinHookEnabler()
87
94
  hooks.enable(
88
- grounded_rule,
89
95
  "assertion_check",
90
96
  rule_name="rule_3",
91
97
  vars_filter={"p1": "f"},
@@ -103,17 +109,17 @@ class BuiltinHookEnabler:
103
109
  """
104
110
  return sorted(self._hooks)
105
111
 
106
- def enable(self, grounded_rule: GroundedRule, name: str, **kwargs: object) -> None:
112
+ def enable(self, name: str, **kwargs: object) -> None:
107
113
  """
108
114
  Enable a built-in hook by name.
109
115
  """
110
116
  if name not in self._hooks:
111
117
  raise KeyError(f"Unknown built-in hook: {name}")
112
- self._hooks[name](grounded_rule, **kwargs)
118
+ self._hooks[name](**kwargs)
113
119
 
114
- def enable_many(self, grounded_rule: GroundedRule, names: list[str], **kwargs: object) -> None:
120
+ def enable_many(self, names: list[str], **kwargs: object) -> None:
115
121
  """
116
122
  Enable multiple built-in hooks by name.
117
123
  """
118
124
  for hook_name in names:
119
- self.enable(grounded_rule, hook_name, **kwargs)
125
+ self.enable(hook_name, **kwargs)
kele/control/callback.py CHANGED
@@ -1,7 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections import defaultdict
4
- from typing import TYPE_CHECKING, Any
4
+ from threading import RLock
5
+ from typing import TYPE_CHECKING, Any, ClassVar
5
6
 
6
7
  if TYPE_CHECKING:
7
8
  from collections.abc import Callable
@@ -18,9 +19,36 @@ class HookMixin:
18
19
 
19
20
  :ivar _hooks: 事件名称到钩子函数列表的映射。
20
21
  """
22
+ _class_hooks: ClassVar[dict[str, list[Callable[..., None]]]]
23
+ _hooks_lock: ClassVar[RLock] = RLock()
24
+
25
+ def __init_subclass__(cls, **kwargs: object) -> None:
26
+ super().__init_subclass__(**kwargs)
27
+ parent_hooks = getattr(cls, "_class_hooks", None)
28
+ class_hooks: dict[str, list[Callable[..., None]]] = defaultdict(list)
29
+ if parent_hooks:
30
+ for event_name, hooks in parent_hooks.items():
31
+ class_hooks[event_name] = list(hooks)
32
+ cls._class_hooks = class_hooks
33
+
21
34
  def __init__(self) -> None:
22
35
  self._hooks: dict[str, list[Callable[..., None]]] = defaultdict(list)
23
36
 
37
+ @classmethod
38
+ def register_class_hook(cls, event_name: str, hook_fn: Callable[..., None], *, unique: bool = True) -> None:
39
+ """
40
+ 在类级别注册 hook,所有实例触发事件时都会执行该 hook。
41
+
42
+ :param event_name: 要监听的事件名称。
43
+ :param hook_fn: 接受任意参数的可调用钩子函数。
44
+ :param unique: 是否防止同一 hook 重复注册。
45
+ """
46
+ with cls._hooks_lock:
47
+ hooks = cls._class_hooks[event_name]
48
+ if unique and any(fn is hook_fn for fn in hooks):
49
+ return
50
+ hooks.append(hook_fn)
51
+
24
52
  def register_hook(self, event_name: str, hook_fn: Callable[..., None]) -> None:
25
53
  """
26
54
  为指定事件注册钩子函数。
@@ -38,7 +66,10 @@ class HookMixin:
38
66
  :param args: 传递给钩子的所有位置参数。
39
67
  :param kwargs: 传递给钩子的所有关键字参数。
40
68
  """
41
- for hook in self._hooks.get(event_name, []):
69
+ with self._hooks_lock:
70
+ class_hooks = list(self.__class__._class_hooks.get(event_name, [])) # noqa: SLF001
71
+ instance_hooks = list(self._hooks.get(event_name, []))
72
+ for hook in class_hooks + instance_hooks:
42
73
  hook(*args, **kwargs)
43
74
 
44
75
  def run_hooks(self, event_name: str, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
@@ -0,0 +1,110 @@
1
+ import itertools
2
+ import typing
3
+ import warnings
4
+
5
+ import networkx as nx
6
+
7
+ from kele.control.registry import register
8
+ from kele.syntax.base_classes import FACT_TYPE, Assertion, CompoundTerm, Constant, Formula, Operator, Rule, Variable
9
+
10
+ from .strategy_protocol import Feedback, RuleSelectionStrategy
11
+
12
+
13
+ class SccSortStrategyBase(RuleSelectionStrategy):
14
+ _reverse_order: bool
15
+
16
+ def __init__(self) -> None:
17
+ warnings.warn("SccSort and ReverseSccSort are not stable features. Use them at your own risk.", stacklevel=3)
18
+ self._rules: list[Rule]
19
+ self._graph: nx.DiGraph
20
+ """Under current implementation, this is a bipartite graph. There may
21
+ be further performance optimizations with this property?"""
22
+ self._order: list[Rule]
23
+ self._stateful_gen: typing.Generator[Rule]
24
+
25
+ @classmethod
26
+ def _all_operators_in_fact_type(cls, f: FACT_TYPE | CompoundTerm | Constant | Variable | None) -> set[Operator]:
27
+ match f:
28
+ case CompoundTerm(operator=operator, arguments=arguments):
29
+ s = set()
30
+ s.add(operator)
31
+ for argument in arguments:
32
+ s |= cls._all_operators_in_fact_type(argument)
33
+ return s
34
+ case Assertion(lhs=lhs, rhs=rhs):
35
+ return cls._all_operators_in_fact_type(lhs) | cls._all_operators_in_fact_type(rhs)
36
+ case Formula(formula_left=l, formula_right=r):
37
+ return cls._all_operators_in_fact_type(l) | cls._all_operators_in_fact_type(r)
38
+ case _:
39
+ return set()
40
+
41
+ def _rules_to_graph(self, rules: list[Rule]) -> nx.DiGraph:
42
+ graph = nx.DiGraph()
43
+ nodes_left = []
44
+ for rule in rules:
45
+ graph.add_node(rule, type_='rule')
46
+ nodes_left.append(rule)
47
+ heads = self._all_operators_in_fact_type(rule.head)
48
+ bodies = self._all_operators_in_fact_type(rule.body)
49
+
50
+ for op in bodies:
51
+ graph.add_edge(op, rule)
52
+ for op in heads:
53
+ graph.add_edge(rule, op)
54
+ return graph
55
+
56
+ @staticmethod
57
+ def _compile_order(graph: nx.DiGraph) -> typing.Generator[Rule]:
58
+ def _order_within_scc(subgraph: nx.DiGraph) -> typing.Generator[Rule]:
59
+ temp_g = subgraph.copy()
60
+ while temp_g.nodes:
61
+ out_d = dict(temp_g.out_degree)
62
+ in_d = dict(temp_g.in_degree)
63
+ n = max(temp_g.nodes, key=lambda x: out_d[x] - in_d[x])
64
+ yield n
65
+ temp_g.remove_node(n)
66
+
67
+ scc_dag = nx.condensation(graph)
68
+ scc_order = nx.topological_sort(scc_dag)
69
+ for scc_id in scc_order:
70
+ members: frozenset[Rule | Operator] = scc_dag.nodes[scc_id]['members']
71
+ subgraph = typing.cast("nx.DiGraph", graph.subgraph(members))
72
+ for m in _order_within_scc(subgraph):
73
+ node = graph.nodes[m]
74
+ # This "isinstance" clause is useless, just to satisfy the LSP
75
+ if node.get('type_', None) == 'rule' and isinstance(m, Rule):
76
+ yield m
77
+
78
+ def _select_next_generator(self) -> typing.Generator[Rule]:
79
+ yield from itertools.cycle(self._order)
80
+
81
+ def set_rules(self, rules: typing.Sequence[Rule]) -> None:
82
+ self._rules = list(rules)
83
+ if not self._rules:
84
+ raise ValueError('rules cannot be empty')
85
+ self._graph = self._rules_to_graph(self._rules)
86
+ self._order = list(self._compile_order(self._graph))
87
+ if self._reverse_order:
88
+ self._order.reverse()
89
+ self._stateful_gen = self._select_next_generator()
90
+
91
+ def reset(self) -> None:
92
+ self._rules = []
93
+ self._graph = nx.DiGraph()
94
+ self._order = []
95
+
96
+ def select_next(self) -> typing.Sequence[Rule]:
97
+ return [next(self._stateful_gen)]
98
+
99
+ def on_feedback(self, feedback: Feedback) -> None: # noqa: PLR6301
100
+ return
101
+
102
+
103
+ @register.rule_selector('SccSort')
104
+ class SccSortStrategy(SccSortStrategyBase):
105
+ _reverse_order = False
106
+
107
+
108
+ @register.rule_selector('ReverseSccSort')
109
+ class ReverseSccSortStrategy(SccSortStrategyBase):
110
+ _reverse_order = True
kele/egg_equiv.pyd CHANGED
Binary file
@@ -8,7 +8,7 @@ from kele.equality import Equivalence
8
8
  from kele.grounder import GroundedRule
9
9
  from kele.knowledge_bases import FactBase
10
10
  from kele.syntax import FACT_TYPE, CompoundTerm, Constant, Question, SankuManagementSystem, Variable, _QuestionRule
11
- from kele._utils import summarize_items
11
+ from kele.utils import summarize_items
12
12
 
13
13
  logger = logging.getLogger(__name__)
14
14