kele 0.0.1a1__cp314-cp314-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kele/__init__.py +38 -0
- kele/_version.py +1 -0
- kele/config.py +243 -0
- kele/control/README_metrics.md +102 -0
- kele/control/__init__.py +20 -0
- kele/control/callback.py +255 -0
- kele/control/grounding_selector/__init__.py +5 -0
- kele/control/grounding_selector/_rule_strategies/README.md +13 -0
- kele/control/grounding_selector/_rule_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_rule_strategies/_sequential_strategy.py +42 -0
- kele/control/grounding_selector/_rule_strategies/strategy_protocol.py +51 -0
- kele/control/grounding_selector/_selector_utils.py +123 -0
- kele/control/grounding_selector/_term_strategies/__init__.py +24 -0
- kele/control/grounding_selector/_term_strategies/_exhausted_strategy.py +34 -0
- kele/control/grounding_selector/_term_strategies/strategy_protocol.py +50 -0
- kele/control/grounding_selector/rule_selector.py +98 -0
- kele/control/grounding_selector/term_selector.py +89 -0
- kele/control/infer_path.py +306 -0
- kele/control/metrics.py +357 -0
- kele/control/status.py +286 -0
- kele/egg_equiv.pyd +0 -0
- kele/egg_equiv.pyi +11 -0
- kele/equality/README.md +8 -0
- kele/equality/__init__.py +4 -0
- kele/equality/_egg_equiv/src/lib.rs +267 -0
- kele/equality/_equiv_elem.py +67 -0
- kele/equality/_utils.py +36 -0
- kele/equality/equivalence.py +141 -0
- kele/executer/__init__.py +4 -0
- kele/executer/executing.py +139 -0
- kele/grounder/README.md +83 -0
- kele/grounder/__init__.py +17 -0
- kele/grounder/grounded_rule_ds/__init__.py +6 -0
- kele/grounder/grounded_rule_ds/_nodes/__init__.py +24 -0
- kele/grounder/grounded_rule_ds/_nodes/_assertion.py +353 -0
- kele/grounder/grounded_rule_ds/_nodes/_conn.py +116 -0
- kele/grounder/grounded_rule_ds/_nodes/_op.py +57 -0
- kele/grounder/grounded_rule_ds/_nodes/_root.py +71 -0
- kele/grounder/grounded_rule_ds/_nodes/_rule.py +119 -0
- kele/grounder/grounded_rule_ds/_nodes/_term.py +390 -0
- kele/grounder/grounded_rule_ds/_nodes/_tftable.py +15 -0
- kele/grounder/grounded_rule_ds/_nodes/_tupletable.py +444 -0
- kele/grounder/grounded_rule_ds/_nodes/_typing_polars.py +26 -0
- kele/grounder/grounded_rule_ds/grounded_class.py +461 -0
- kele/grounder/grounded_rule_ds/grounded_ds_utils.py +91 -0
- kele/grounder/grounded_rule_ds/rule_check.py +373 -0
- kele/grounder/grounding.py +118 -0
- kele/knowledge_bases/README.md +112 -0
- kele/knowledge_bases/__init__.py +6 -0
- kele/knowledge_bases/builtin_base/__init__.py +1 -0
- kele/knowledge_bases/builtin_base/builtin_concepts.py +13 -0
- kele/knowledge_bases/builtin_base/builtin_facts.py +43 -0
- kele/knowledge_bases/builtin_base/builtin_operators.py +105 -0
- kele/knowledge_bases/builtin_base/builtin_rules.py +14 -0
- kele/knowledge_bases/fact_base.py +158 -0
- kele/knowledge_bases/ontology_base.py +67 -0
- kele/knowledge_bases/rule_base.py +194 -0
- kele/main.py +464 -0
- kele/py.typed +0 -0
- kele/syntax/CONCEPT_README.md +117 -0
- kele/syntax/__init__.py +40 -0
- kele/syntax/_cnf_converter.py +161 -0
- kele/syntax/_sat_solver.py +116 -0
- kele/syntax/base_classes.py +1482 -0
- kele/syntax/connectives.py +20 -0
- kele/syntax/dnf_converter.py +145 -0
- kele/syntax/external.py +17 -0
- kele/syntax/sub_concept.py +87 -0
- kele/syntax/syntacticsugar.py +201 -0
- kele-0.0.1a1.dist-info/METADATA +166 -0
- kele-0.0.1a1.dist-info/RECORD +74 -0
- kele-0.0.1a1.dist-info/WHEEL +4 -0
- kele-0.0.1a1.dist-info/licenses/LICENSE +28 -0
- kele-0.0.1a1.dist-info/licenses/licensecheck.json +20 -0
kele/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""支持断言逻辑的推理引擎"""
|
|
2
|
+
from kele.main import EngineRunResult, InferenceEngine, QueryStructure
|
|
3
|
+
from kele.config import (
|
|
4
|
+
Config,
|
|
5
|
+
RunControlConfig,
|
|
6
|
+
InferenceStrategyConfig,
|
|
7
|
+
GrounderConfig,
|
|
8
|
+
ExecutorConfig,
|
|
9
|
+
PathConfig,
|
|
10
|
+
KBConfig,
|
|
11
|
+
)
|
|
12
|
+
from kele.syntax.base_classes import Constant, Concept, Operator, Variable, CompoundTerm, Assertion, Formula, Rule
|
|
13
|
+
|
|
14
|
+
# Package version. Fall back to a placeholder when `_version.py` cannot be imported
# (presumably a source checkout where the file was not generated — TODO confirm).
try:
    from ._version import version as __version__
except ImportError:
    __version__ = '0.0.0'

# Public API re-exported at package level (kept in alphabetical order).
__all__ = [
    'Assertion',
    'CompoundTerm',
    'Concept',
    'Config',
    'Constant',
    'EngineRunResult',
    'ExecutorConfig',
    'Formula',
    'GrounderConfig',
    'InferenceEngine',
    'InferenceStrategyConfig',
    'KBConfig',
    'Operator',
    'PathConfig',
    'QueryStructure',
    'Rule',
    'RunControlConfig',
    'Variable',
]
|
kele/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
version = "0.0.1a1"  # Package version; `kele/__init__.py` imports it as `__version__`.
|
kele/config.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
# ruff: noqa: ERA001 # Commented parameters are either not implemented yet or depend on unfinished upstream/downstream modules.
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import Any, cast, Literal
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from dataclasses import dataclass, fields, field
|
|
7
|
+
from datetime import datetime, UTC
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
import yaml
|
|
10
|
+
import json
|
|
11
|
+
import tyro
|
|
12
|
+
from tyro.conf import OmitArgPrefixes
|
|
13
|
+
import dacite
|
|
14
|
+
from dacite.config import Config as daConfig
|
|
15
|
+
|
|
16
|
+
RESULT_LEVEL = 25
|
|
17
|
+
logging.RESULT = RESULT_LEVEL # type: ignore[attr-defined] # This fails mypy; setattr fails ruff.
|
|
18
|
+
|
|
19
|
+
logging.addLevelName(RESULT_LEVEL, "RESULT")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _result(self: logging.Logger, message: str, *args: object) -> None:
|
|
23
|
+
if self.isEnabledFor(RESULT_LEVEL):
|
|
24
|
+
self._log(RESULT_LEVEL, message, args)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
logging.Logger.result = _result # type: ignore[attr-defined]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Configuration sections. Each one is nested in `Config`; tyro turns them into CLI flags
# and dacite fills them from YAML/JSON config files.
@dataclass
class RunControlConfig:
    """Runtime control: iteration budget, logging, tracing, and how query solutions are reported."""
    iteration_limit: int = 300  # Timeout iterations (one grounder-executor cycle is one iteration).
    # time_limit: int = 3000  # Timeout-based termination is not wired in yet; field kept as a placeholder.
    log_level: Literal['DEBUG', 'INFO', 'RESULT', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO'  # Log level; RESULT is the custom level 25 registered in this module.
    # grounding_steps: int = 4  # Per-grounding step limit. TODO: Unused; may no longer be needed.
    trace: bool = False  # Enable inference path tracing.
    semi_eval_with_equality: bool = True  # Consider equality axioms in semi-evaluation. Disable to reduce overhead.
    # This only partially disables related behavior. TODO: Possibly rename to inference_with_equality.
    interactive_query_mode: Literal["interactive", "first", "all"] = "first"  # Control interactive printing of solutions.
    # "interactive" = interactive mode, "first" = print only the first solution, "all" = print all solutions.
    save_solutions: bool = False  # Record and return solutions; if False, solutions only go to the terminal and the logs.
    include_final_facts: bool = False  # Include final facts in EngineRunResult; fact_num is always reported.
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class InferenceStrategyConfig:
    """Inference strategy and model behavior: how rules/facts are selected for grounding."""
    select_rules_num: int | Literal[-1] = -1  # Number of rules to select.
    select_facts_num: int | Literal[-1] = -1  # Number of facts to select; -1 means all facts.
    # premise_selection_strategy: Literal[''] = ''  # Premise selection algorithm. TODO: Unused.
    grounding_rule_strategy: Literal['SequentialCyclic', 'SequentialCyclicWithPriority'] = "SequentialCyclic"  # Rule selection strategy in grounding.
    # executing_sort_strategy: Literal[''] = ''  # Execution order strategy. TODO: Unused.
    grounding_term_strategy: Literal['Exhausted'] = "Exhausted"  # Term selection strategy in grounding.
    question_rule_interval: int = 1  # Insert a question rule every N rules; -1 uses total rule count as the interval.
    # NOTE(review): -1 for select_rules_num presumably means "all rules", mirroring select_facts_num — confirm.
    # New rule strategies must also be added to the grounding_rule_strategy Literal above.
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class GrounderConfig:
    """Grounder-related parameters."""
    grounding_rules_num_every_step: int | Literal[-1] = -1
    grounding_facts_num_for_each_rule: int | Literal[-1] = -1
    allow_unify_with_nested_term: bool = True  # Allow Variables to be replaced by CompoundTerms.
    conceptual_fuzzy_unification: bool = True  # Use strict concept constraints to accelerate inference.
    # Strict matching depends on correct concept subsumption and complete constant.belong_concepts settings;
    # beginners should use loose matching.
    # NOTE(review): the field name suggests True = fuzzy (loose) concept matching, which reads opposite to
    # the inline comment above; confirm which value enables the strict constraints.
    # NOTE(review): -1 for the two *_num fields presumably means "no limit" — confirm against the grounder.
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
class ExecutorConfig:
    """Executor-related parameters."""
    executing_rule_num: int | Literal[-1] = -1
    executing_max_steps: int | Literal[-1] = -1
    anti_join_used_facts: bool = True  # Drop facts that were already produced (default True).
    # This records the previous true results and anti-joins the current results against them to drop facts.
    # It can be inefficient when duplicates are rare, but speeds things up under heavy duplication.
    # NOTE(review): -1 for the two fields without comments presumably means "no limit" — confirm.
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass
class PathConfig:
    """Paths and resource dependencies."""
    rule_dir: str = './'
    fact_dir: str = './'
    log_dir: str = './log'  # Directory for run log files; created if it does not exist.
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class KBConfig:
    """Knowledge-base related parameters (stored under the `engineering` key of `Config`)."""
    fact_cache_size: int | Literal[-1] = -1
    # NOTE(review): -1 presumably means an unbounded fact cache — confirm against the fact base.
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
class Config:
    """Main entry point for kele hyperparameters.

    Each section is wrapped in ``OmitArgPrefixes``, so tyro exposes its fields as CLI flags
    without the section prefix. Field names should therefore stay unique across sections.
    In YAML/JSON config files the sections are keyed by the attribute names below
    (note that the knowledge-base section is called ``engineering``).
    """
    run: OmitArgPrefixes[RunControlConfig] = field(default_factory=RunControlConfig)
    strategy: OmitArgPrefixes[InferenceStrategyConfig] = field(default_factory=InferenceStrategyConfig)
    grounder: OmitArgPrefixes[GrounderConfig] = field(default_factory=GrounderConfig)
    executor: OmitArgPrefixes[ExecutorConfig] = field(default_factory=ExecutorConfig)
    path: OmitArgPrefixes[PathConfig] = field(default_factory=PathConfig)
    engineering: OmitArgPrefixes[KBConfig] = field(default_factory=KBConfig)
    config: str | None = None  # Config file path (.yaml, .yml or .json).
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _load_config_file(path: str) -> dict[str, Any]:
|
|
106
|
+
with open(path, encoding="utf-8") as f:
|
|
107
|
+
if path.endswith(('.yaml', '.yml')):
|
|
108
|
+
data = yaml.safe_load(f)
|
|
109
|
+
elif path.endswith('.json'):
|
|
110
|
+
data = json.load(f)
|
|
111
|
+
else:
|
|
112
|
+
raise ValueError("Unsupported config file format: must be .yaml, .yml, or .json")
|
|
113
|
+
|
|
114
|
+
if not isinstance(data, dict):
|
|
115
|
+
raise TypeError(f"Config file must contain a dict, got {type(data)}")
|
|
116
|
+
return cast("dict[str, Any]", data)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _save_config(config: dict[str, Any], path: str) -> None:
    """Serialize ``config`` to ``path`` as YAML, overwriting the file and keeping the dict's key order."""
    with open(path, mode='w', encoding='utf8') as out_stream:
        yaml.dump(config, stream=out_stream, sort_keys=False)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _init_logger(log_path: str | None = None,
|
|
125
|
+
run_id: str | None = None,
|
|
126
|
+
log_name: str = "run.log",
|
|
127
|
+
log_level: int = logging.INFO) -> logging.Logger: # Literal is best but too verbose here.
|
|
128
|
+
"""
|
|
129
|
+
Initialize the logging system; paths come from config.log_dir.
|
|
130
|
+
Supports per-run log files by run_id.
|
|
131
|
+
"""
|
|
132
|
+
if run_id is None:
|
|
133
|
+
run_id = datetime.now(UTC).astimezone().strftime("%Y%m%d_%H%M%S")
|
|
134
|
+
|
|
135
|
+
log_dir = Path(log_path) if log_path else Path("./log")
|
|
136
|
+
log_dir.mkdir(parents=True, exist_ok=True)
|
|
137
|
+
|
|
138
|
+
log_file = log_dir / f"{run_id}_{log_name}"
|
|
139
|
+
|
|
140
|
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
|
141
|
+
|
|
142
|
+
# Get the root logger and clear old handlers (avoid duplicate output).
|
|
143
|
+
logger = logging.getLogger()
|
|
144
|
+
logger.setLevel(logging.INFO)
|
|
145
|
+
logger.handlers.clear()
|
|
146
|
+
|
|
147
|
+
# File log handler.
|
|
148
|
+
file_handler = logging.FileHandler(log_file, mode='a')
|
|
149
|
+
file_handler.setFormatter(formatter)
|
|
150
|
+
logger.addHandler(file_handler)
|
|
151
|
+
|
|
152
|
+
# Console handler.
|
|
153
|
+
console_handler = logging.StreamHandler()
|
|
154
|
+
console_handler.setFormatter(formatter)
|
|
155
|
+
logger.addHandler(console_handler)
|
|
156
|
+
|
|
157
|
+
logger.setLevel(log_level)
|
|
158
|
+
|
|
159
|
+
logger.info("Logger initialized at %s", log_file)
|
|
160
|
+
|
|
161
|
+
return logger
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _init_config_cli() -> Config:
    """Build the configuration purely from the command line.

    The CLI is parsed twice. The first pass only discovers ``--config`` and reports stray
    arguments. The second pass applies CLI overrides on top of the config-file values
    (or on top of the defaults when no file is given).
    """
    # HACK: CLI parsing is delegated to tyro; it may be replaced by manual parsing later.
    first_pass, leftover = tyro.cli(Config, return_unknown_args=True)
    if leftover:
        warnings.warn(f"Unknown kele parameters ignored: {leftover}", stacklevel=2)

    # A field whose default and default_factory are the same object has neither (both MISSING).
    # Seed such sections with {} so a config file that omits them still converts.
    # Sub-configs with new required fields would still fail; recursing is deemed unnecessary for now.
    required_sections = [fld.name for fld in fields(Config) if fld.default is fld.default_factory]
    file_config: dict[str, Any] = {name: {} for name in required_sections}
    if first_pass.config:
        file_config = file_config | _load_config_file(first_pass.config)

    # Second pass: merge the config-file values with the CLI parameters.
    defaults = dacite.from_dict(Config, file_config, config=daConfig(strict=True))
    final_config, _ = tyro.cli(Config, return_unknown_args=True, default=defaults)
    return final_config
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _build_config(user_config: Config | None = None,
                  config_file_path: str | None = None) -> Config:
    """Build a `Config` from CLI arguments, an in-code default, or a config file.

    This function provides a single entry point to construct the runtime configuration:
    - If `user_config` is provided, parse CLI overrides on top of it.
    - If `config_file_path` is provided, load the file config and allow CLI overrides.
    - Otherwise, initialize purely from CLI (and its config argument).

    In every mode, command-line arguments that kele does not recognise are ignored with a
    warning instead of aborting the program. This lets kele be embedded in programs (or test
    runners) whose ``sys.argv`` carries their own flags.

    :param user_config: A default `Config` instance to be used as the CLI default.
        Mutually exclusive with `config_file_path` (and also incompatible with
        `user_config.config` being set).
    :param config_file_path: Path to a config file (e.g., YAML/JSON). Mutually
        exclusive with `user_config`.
    :return: The final merged `Config` instance.
    :raises ValueError: If `user_config` is used together with `config_file_path`
        or when `user_config.config` is set.
    """  # noqa: DOC501
    if user_config is not None and (user_config.config or config_file_path):
        raise ValueError("default config instance and config file cannot be used together")

    if user_config is not None:
        default = user_config  # Parameters passed programmatically.
    elif config_file_path:
        default = dacite.from_dict(Config,
                                   _load_config_file(config_file_path),
                                   config=daConfig(strict=True))
    else:
        return _init_config_cli()

    # Same tolerant parsing as _init_config_cli: without return_unknown_args, tyro would
    # abort on any flag meant for the host program.
    merged_config, unknown = tyro.cli(Config, return_unknown_args=True, default=default)
    if unknown:
        warnings.warn(f"Unknown kele parameters ignored: {unknown}", stacklevel=2)
    return merged_config
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def init_config_logger(user_config: Config | None = None,
                       config_file_path: str | None = None,
                       *,
                       run_id: str | None = None,
                       log_name: str = "run.log") -> Config:
    """Initialize configuration and logger.

    This is the public entry point. It builds the final `Config` (from CLI, code, or file),
    initializes the root logger under `config.path.log_dir` at `config.run.log_level`
    (unrecognised level names fall back to INFO), and then logs the final config as YAML.
    A `ValueError` raised while building the config (conflicting sources) propagates.

    :param user_config: A default `Config` instance to be used as the CLI default.
        Mutually exclusive with `config_file_path`.
    :param config_file_path: Path to a config file. Mutually exclusive with `user_config`.
    :param run_id: Optional run identifier used by the logger initializer.
    :param log_name: Suffix of log file name ({run_id}_{log_name}). Defaults to "run.log".
    :return: The final merged `Config` instance.
    """
    from dataclasses import asdict  # Local import: only needed for the plain-dict config dump below.

    config = _build_config(user_config, config_file_path)

    # "RESULT" resolves too, because this module sets logging.RESULT.
    logger = _init_logger(config.path.log_dir,
                          run_id=run_id, log_name=log_name,
                          log_level=getattr(logging, config.run.log_level.upper(), logging.INFO))

    # asdict() recursively turns the nested section dataclasses into plain dicts. Dumping
    # config.__dict__ directly makes yaml.dump emit a `!!python/object:...` tag per section.
    logger.info("Final Config:\n%s", yaml.dump(asdict(config), sort_keys=False, allow_unicode=True))

    return config
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
本文档由GPT5撰写。
|
|
2
|
+
|
|
3
|
+
# Metrics 使用说明
|
|
4
|
+
|
|
5
|
+
本模块提供了推理过程中的性能指标采集与记录功能,包括 **速度**、**CPU 占用**、**内存使用** 等。
|
|
6
|
+
虽然目前速度是最主要的关注点,但也支持记录其他资源指标。
|
|
7
|
+
|
|
8
|
+
## 功能概览
|
|
9
|
+
- **指标类型**
|
|
10
|
+
- `duration_seconds`:代码或函数运行耗时(秒)
|
|
11
|
+
- `cpu_percent`:CPU 占用百分比
|
|
12
|
+
- `rss_mib`:进程占用内存(MiB)
|
|
13
|
+
- 迭代次数、规则数、事实数等离散计数指标
|
|
14
|
+
|
|
15
|
+
- **记录方式**
|
|
16
|
+
- **代码级**:使用 `PhaseTimer` 上下文管理器,手动分段计时
|
|
17
|
+
- **函数级**:使用 `@measure` 装饰器,自动记录函数执行时间和资源占用
|
|
18
|
+
- 可通过 `observe_counts`、`sample_process_gauges` 等函数采集自定义指标
|
|
19
|
+
|
|
20
|
+
- **结果存储**
|
|
21
|
+
默认保存在 **`metrics_logs`** 文件夹下,文件名为 `<run_id>.json`,包含运行元信息(`meta`)和事件记录(`events`)。
|
|
22
|
+
|
|
23
|
+
## 快速开始
|
|
24
|
+
|
|
25
|
+
### 1. 初始化
|
|
26
|
+
```python
|
|
27
|
+
from kele.control.metrics import init_metrics
|
|
28
|
+
init_metrics(port=14233, job="al_inference", grouping={"env": "dev"})
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
### 2. 开始与结束一次运行
|
|
32
|
+
```python
|
|
33
|
+
from kele.control.metrics import start_run, end_run
|
|
34
|
+
|
|
35
|
+
run_id = start_run(log_dir="metrics_logs")
|
|
36
|
+
# ... 推理或业务逻辑 ...
|
|
37
|
+
end_run(extra_meta={"facts_final": 100, "rules_total": 50})
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
### 3. 代码级计时(阶段)
|
|
41
|
+
```python
|
|
42
|
+
from kele.control.metrics import PhaseTimer
|
|
43
|
+
|
|
44
|
+
with PhaseTimer("grounding", module="pipeline"):
|
|
45
|
+
do_grounding()
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
### 4. 函数级计时
|
|
49
|
+
```python
|
|
50
|
+
from kele.control.metrics import measure
|
|
51
|
+
|
|
52
|
+
@measure("infer_step", module="inference")
|
|
53
|
+
def infer_step():
|
|
54
|
+
# 推理步骤逻辑
|
|
55
|
+
pass
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
### 5. 采集计数与进程状态
|
|
59
|
+
```python
|
|
60
|
+
from kele.control.metrics import observe_counts, sample_process_gauges
|
|
61
|
+
|
|
62
|
+
observe_counts(grounded_rules=10, facts_count=200)
|
|
63
|
+
sample_process_gauges()
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
## 日志文件示例
|
|
67
|
+
|
|
68
|
+
一次运行的 `metrics_logs/20230801-120000-ab12cd.json` 文件内容示例:
|
|
69
|
+
|
|
70
|
+
```json
|
|
71
|
+
{
|
|
72
|
+
"meta": {
|
|
73
|
+
"run_id": "20230801-120000-ab12cd",
|
|
74
|
+
"started_at": "2023-08-01T12:00:00+08:00",
|
|
75
|
+
"ended_at": "2023-08-01T12:00:10+08:00",
|
|
76
|
+
"cpu_percent_max": 85.3,
|
|
77
|
+
"cpu_percent_mean": 42.7,
|
|
78
|
+
"rss_max_mib": 512.4,
|
|
79
|
+
"phase_durations_seconds_total": [
|
|
80
|
+
{"module": "pipeline", "phase": "grounding", "duration_seconds_total": 2.53},
|
|
81
|
+
{"module": "pipeline", "phase": "execute", "duration_seconds_total": 5.42}
|
|
82
|
+
],
|
|
83
|
+
"all_phases_duration_seconds_total": 7.95,
|
|
84
|
+
"function_durations_seconds_total": [
|
|
85
|
+
{"module": "inference", "name": "main_infer", "duration_seconds_total": 8.02}
|
|
86
|
+
],
|
|
87
|
+
"all_functions_duration_seconds_total": 8.02,
|
|
88
|
+
"facts_final": 150,
|
|
89
|
+
"rules_total": 200
|
|
90
|
+
},
|
|
91
|
+
"events": [
|
|
92
|
+
{"timestamp": "2023-08-01T12:00:00+08:00", "event": "process_sample", "cpu_percent": 30.5, "rss_mib": 400.2},
|
|
93
|
+
{"timestamp": "2023-08-01T12:00:02+08:00", "event": "phase_timing", "module": "pipeline", "phase": "grounding", "duration_seconds": 2.53},
|
|
94
|
+
{"timestamp": "2023-08-01T12:00:07+08:00", "event": "phase_timing", "module": "pipeline", "phase": "execute", "duration_seconds": 5.42},
|
|
95
|
+
{"timestamp": "2023-08-01T12:00:10+08:00", "event": "func_timing", "module": "inference", "name": "main_infer", "duration_seconds": 8.02}
|
|
96
|
+
]
|
|
97
|
+
}
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
## 备注
|
|
101
|
+
- 该模块可独立用于代码性能测试,也可集成至推理引擎或其他系统。
|
|
102
|
+
- 如需推送到 Prometheus Pushgateway,可在 `init_metrics` 中配置 `pushgateway` 参数。
|
kele/control/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""用于callbacks和推理路径的记录"""
|
|
2
|
+
from .callback import HookMixin, Callback, CallbackManager
|
|
3
|
+
from .status import (
|
|
4
|
+
InferenceStatus,
|
|
5
|
+
create_main_loop_manager,
|
|
6
|
+
create_executor_manager,
|
|
7
|
+
)
|
|
8
|
+
from .grounding_selector import GroundingRuleSelector
|
|
9
|
+
from .infer_path import InferencePath
|
|
10
|
+
|
|
11
|
+
# Public API of ``kele.control``, re-exported from the submodules imported above.
__all__ = [
    'Callback',
    'CallbackManager',
    'GroundingRuleSelector',
    'HookMixin',
    'InferencePath',
    'InferenceStatus',
    'create_executor_manager',
    'create_main_loop_manager',
]
|
kele/control/callback.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from typing import Any, TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from kele.equality import Equivalence
|
|
9
|
+
from kele.grounder import GroundedRule
|
|
10
|
+
from kele.knowledge_bases import FactBase, RuleBase
|
|
11
|
+
from kele.syntax import Question, Rule, FACT_TYPE, Variable, Constant, CompoundTerm
|
|
12
|
+
from collections.abc import Callable
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class HookMixin:
    """
    Mixin that lets a module register and run its own event hooks; modules that need hooks can inherit it directly.

    ``_hooks`` is created in ``__init__``, so subclasses that define their own ``__init__``
    must call ``super().__init__()``.

    :ivar _hooks: Mapping from event name to the hook functions registered for it, in registration order.
    """
    def __init__(self) -> None:
        self._hooks: dict[str, list[Callable[..., None]]] = defaultdict(list)

    def register_hook(self, event_name: str, hook_fn: Callable[..., None]) -> None:
        """
        Register a hook function for the given event.

        :param event_name: Name of the event to listen for.
        :param hook_fn: Callable invoked with whatever arguments the event is run with.
        """
        self._hooks[event_name].append(hook_fn)

    def _run_hooks(self, event_name: str, *args: Any, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Run every hook registered for the given event, in registration order.

        The hook list is snapshotted before dispatch. A hook registered while the event is
        being dispatched therefore takes effect from the next run; it is not called in the
        current one, and a hook that keeps registering hooks cannot make the loop run forever.

        :param event_name: Name of the event.
        :param args: Positional arguments passed to every hook.
        :param kwargs: Keyword arguments passed to every hook.
        """
        # .get() instead of [] so unknown events do not create empty defaultdict entries.
        for hook in tuple(self._hooks.get(event_name, ())):
            hook(*args, **kwargs)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Callback:
    """Callback interface: hooks invoked at each inference stage, e.g. to collect metrics.

    Every hook here is a no-op. Subclasses override only the hooks they need and are
    registered through ``CallbackManager.register_callback``.
    """

    def on_infer_start(
        self,
        question: Question,
        fact_base: FactBase,
        rule_base: RuleBase,
    ) -> None:
        """
        Called when inference starts.

        :param question: The question to be answered by inference.
        :param fact_base: The fact base.
        :param rule_base: The rule base.
        """

    def on_grounder_select_start(
        self,
        question: Question,
        fact_base: FactBase,
        rule_base: RuleBase,
    ) -> None:  # HACK: should the parameters include the grounder's selection strategy? Probably unnecessary, since strategies are not necessarily identified by a str.
        """
        Called before the grounder selects rules and facts.

        :param question: The question to be answered by inference.
        :param fact_base: The fact base.
        :param rule_base: The rule base.
        """

    def on_grounder_select_end(
        self,
        selected_rule_terms_pair: list[tuple[Rule, list[FACT_TYPE]]],
        candidate_rules: RuleBase,
        fact_base: FactBase,
        question: Question,
    ) -> None:
        """
        Called after the grounder has made its selection.

        :param selected_rule_terms_pair: The selected pairs, each a rule and the list of facts matched with it.
        :param candidate_rules: All candidate rules considered by this grounder pass. Selection currently
            happens directly in code, so by convention this is the rule base.
        :param fact_base: The fact base.
        :param question: The question to be answered by inference.
        """

    def on_binding_change(
        self,
        var_name: str,
        var_value: Constant | CompoundTerm
    ) -> None:
        """
        Called on every variable binding/unbinding. For now this hook is mainly advisory:
        if it ends up firing in _RuleNode it may later be replaced by on_rule_activation;
        firing it at other nodes, i.e. at the _TupleTable level, would be far too frequent.

        :param var_name: Variable name.
        :param var_value: Value bound to the variable.
        """

    def on_rule_activation(
        self,
        rule: Rule,
        var_dict: dict[Variable, Constant | CompoundTerm]
    ) -> None:
        """
        Called each time a RuleNode is activated (enters execution).

        :param rule: The rule being activated.
        :param var_dict: One candidate instantiation (variable -> value) per call.
        # risk: consider computing the inference depth here (a `depth: int` parameter) or linking this with the inference path.
        """

    def on_executor_start(
        self,
        grounded_rules: list[GroundedRule],
        question: Question,
        equivalence: Equivalence,
    ) -> None:  # HACK: should the parameters include the executor's selection strategy? Probably unnecessary, since strategies are not necessarily identified by a str.
        """
        Called before the executor runs.

        :param grounded_rules: The grounded (instantiated) rules.
        :param question: The question to be answered by inference.
        :param equivalence: Equivalence classes that support rule checking.
        """

    def on_executor_sorted(
        self,
        sorted_rules: list[GroundedRule],
        original_rules: list[GroundedRule],
        question: Question,
    ) -> None:
        """
        Called after the executor has sorted the rules.

        :param sorted_rules: Rules after sorting.
        :param original_rules: Rules before sorting.
        :param question: The question to be answered by inference.
        """

    def on_executor_post(
        self,
        execution_results: Any,  # Contains newly generated facts, possibly more. # noqa: ANN401 # TODO: details TBD
        question: Question,
    ) -> None:
        """
        Called after the executor has run.

        :param execution_results: Raw results returned by execution (paths, binding information, etc.).
        :param question: The question to be answered by inference.
        """

    def on_infer_end(
        self,
        final_result: Any,  # FIXME: type not settled yet, same as on_executor_post # noqa: ANN401
        question: Question,
        metrics: Any  # noqa: ANN401 # FIXME: details TBD
    ) -> None:
        """
        Called when inference finishes; intended for summarizing whole-run metrics and evaluating accuracy.

        :param final_result: Final inference output.
        :param question: The question to be answered by inference.
        :param metrics: Any evaluation information about the result.
        """
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class CallbackManager:
    """Registry of instantiated `Callback` objects; forwards every inference event to each of them.

    Callbacks are notified in registration order. Each dispatch iterates over a snapshot of the
    registered callbacks, so a callback may register or unregister callbacks (including itself)
    from inside a hook without another callback being skipped in that same event.
    """
    def __init__(self) -> None:
        self._callbacks: list[Callback] = []

    def _snapshot(self) -> tuple[Callback, ...]:
        """Return the callbacks to notify for one event, frozen at dispatch time."""
        return tuple(self._callbacks)

    def register_callback(self, callback: Callback) -> None:
        """
        Register a callback instance.

        :param callback: An instance of a Callback subclass.
        """
        self._callbacks.append(callback)

    def unregister_callback(self, callback: Callback) -> None:
        """
        Unregister a callback instance; does nothing if it is not registered.
        """
        if callback in self._callbacks:
            self._callbacks.remove(callback)

    # Infer start
    def on_infer_start(self, question: Question, fact_base: FactBase, rule_base: RuleBase) -> None:
        """Dispatch Callback.on_infer_start."""
        for cb in self._snapshot():
            cb.on_infer_start(question, fact_base, rule_base)

    # Grounder selection hooks
    def on_grounder_select_start(self, question: Question, fact_base: FactBase, rule_base: RuleBase) -> None:
        """Dispatch Callback.on_grounder_select_start."""
        for cb in self._snapshot():
            cb.on_grounder_select_start(question, fact_base, rule_base)

    def on_grounder_select_end(
        self,
        selected_rule_terms_pair: list[tuple[Rule, list[FACT_TYPE]]],
        candidate_rules: RuleBase,
        fact_base: FactBase,
        question: Question,
    ) -> None:
        """Dispatch Callback.on_grounder_select_end."""
        for cb in self._snapshot():
            cb.on_grounder_select_end(selected_rule_terms_pair, candidate_rules, fact_base, question)

    # Binding hook
    def on_binding_change(self, var_name: str, var_value: Constant | CompoundTerm) -> None:
        """Dispatch Callback.on_binding_change."""
        for cb in self._snapshot():
            cb.on_binding_change(var_name, var_value)

    # Rule activation
    def on_rule_activation(self, rule: Rule, var_dict: dict[Variable, Constant | CompoundTerm]) -> None:
        """Dispatch Callback.on_rule_activation."""
        for cb in self._snapshot():
            cb.on_rule_activation(rule, var_dict)

    # Executor hooks
    def on_executor_start(self,
                          grounded_rules: list[GroundedRule],
                          question: Question,
                          equivalence: Equivalence) -> None:
        """Dispatch Callback.on_executor_start."""
        for cb in self._snapshot():
            cb.on_executor_start(grounded_rules, question, equivalence)

    def on_executor_sorted(
        self,
        sorted_rules: list[GroundedRule],
        original_rules: list[GroundedRule],
        question: Question,
    ) -> None:
        """Dispatch Callback.on_executor_sorted."""
        for cb in self._snapshot():
            cb.on_executor_sorted(sorted_rules, original_rules, question)

    def on_executor_post(self, execution_results: Any, question: Question) -> None:  # noqa: ANN401 # FIXME: same as Callback
        """Dispatch Callback.on_executor_post."""
        for cb in self._snapshot():
            cb.on_executor_post(execution_results, question)

    # Infer end
    def on_infer_end(self, final_result: Any, question: Question, metrics: Any) -> None:  # noqa: ANN401 # FIXME: same as Callback
        """Dispatch Callback.on_infer_end."""
        for cb in self._snapshot():
            cb.on_infer_end(final_result, question, metrics)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
## 默认strategy
|
|
2
|
+
|
|
3
|
+
### SequentialCyclic
|
|
4
|
+
按顺序逐个选择规则,每次选1条。
|
|
5
|
+
|
|
6
|
+
### SequentialCyclicWithPriority
|
|
7
|
+
根据规则的priority排序后,按顺序逐个选择规则,每次选1条。
|
|
8
|
+
|
|
9
|
+
## 创建自己的strategy
|
|
10
|
+
1. 在本目录(`_rule_strategies/`)下创建一个 py 文件,命名须为 `_<name>_strategy.py`(可参照现有的 `_sequential_strategy.py`);
|
|
11
|
+
2. 继承 `RuleSelectionStrategy`(一个 Protocol),并至少实现该 Protocol 要求的方法;
|
|
12
|
+
3. 使用`@register_strategy('<name>')`注册你的策略类,后续即可通过`grounding_rule_strategy`使用策略;
|
|
13
|
+
4. 注意调整 `grounding_rule_strategy` 的类型标注(位于 `kele/config.py` 的 `InferenceStrategyConfig` 中,需在 Literal 中增加新策略名作为候选值)。
|