vulcan-core 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vulcan_core/__init__.py +45 -0
- vulcan_core/actions.py +31 -0
- vulcan_core/ast_utils.py +506 -0
- vulcan_core/conditions.py +432 -0
- vulcan_core/engine.py +287 -0
- vulcan_core/models.py +271 -0
- vulcan_core/reporting.py +595 -0
- vulcan_core/util.py +127 -0
- vulcan_core-1.2.1.dist-info/METADATA +88 -0
- vulcan_core-1.2.1.dist-info/RECORD +11 -0
- vulcan_core-1.2.1.dist-info/WHEEL +4 -0
vulcan_core/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright 2025 Latchfield Technologies http://latchfield.com
|
|
3
|
+
|
|
4
|
+
from vulcan_core.actions import Action, action
|
|
5
|
+
from vulcan_core.ast_utils import (
|
|
6
|
+
ASTProcessingError,
|
|
7
|
+
CallableSignatureError,
|
|
8
|
+
ContractError,
|
|
9
|
+
NotAFactError,
|
|
10
|
+
ScopeAccessError,
|
|
11
|
+
)
|
|
12
|
+
from vulcan_core.conditions import (
|
|
13
|
+
CompoundCondition,
|
|
14
|
+
Condition,
|
|
15
|
+
MissingFactError,
|
|
16
|
+
OnFactChanged,
|
|
17
|
+
Operator,
|
|
18
|
+
condition,
|
|
19
|
+
)
|
|
20
|
+
from vulcan_core.engine import InternalStateError, RecursionLimitError, Rule, RuleEngine
|
|
21
|
+
from vulcan_core.models import ActionReturn, ChunkingStrategy, Fact, Similarity
|
|
22
|
+
|
|
23
|
+
# Public API of the package: every name re-exported from the vulcan_core submodule imports above,
# listed in case-sensitive sorted order.
__all__ = [
    "ASTProcessingError", "Action", "ActionReturn", "CallableSignatureError",
    "ChunkingStrategy", "CompoundCondition", "Condition", "ContractError",
    "Fact", "InternalStateError", "MissingFactError", "NotAFactError",
    "OnFactChanged", "Operator", "RecursionLimitError", "Rule",
    "RuleEngine", "ScopeAccessError", "Similarity", "action",
    "condition",
]
|
vulcan_core/actions.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright 2025 Latchfield Technologies http://latchfield.com
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from functools import partial
|
|
8
|
+
|
|
9
|
+
from vulcan_core.ast_utils import ASTProcessor
|
|
10
|
+
from vulcan_core.models import ActionCallable, ActionReturn, DeclaresFacts, Fact, FactHandler
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True, slots=True)
class Action(FactHandler[ActionCallable, ActionReturn], DeclaresFacts):
    """
    Deferred result calculation of a rule.

    Calling an Action forwards the supplied facts, unchanged and in order, to the wrapped callable
    (`self.func`, presumably a field declared on FactHandler) and returns whatever that callable produces.
    """

    def __call__(self, *args: Fact) -> ActionReturn:
        # Invocation simply delegates to the evaluation step.
        outcome = self._evaluate(*args)
        return outcome

    def _evaluate(self, *args: Fact) -> ActionReturn:
        target = self.func
        return target(*args)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def action(value: ActionCallable | ActionReturn) -> Action:
    """
    Wrap ``value`` in an :class:`Action`.

    A plain callable (function or lambda) is analysed by ASTProcessor, which records the facts it accesses and,
    for lambdas, supplies a rewritten callable. Anything else, including ``functools.partial`` objects, which are
    treated as a literal result rather than analysed, becomes an Action that declares no facts and simply returns
    ``value``.
    """
    treat_as_literal = isinstance(value, partial) or not callable(value)
    if treat_as_literal:
        return Action((), lambda: value)

    # ASTProcessor finds the user's globals by walking up the stack to the frame named after the decorator it
    # is given (this function) and reading that frame's caller's globals.
    processed = ASTProcessor[ActionCallable](value, action, ActionReturn)
    return Action(processed.facts, processed.func)
|
vulcan_core/ast_utils.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright 2025 Latchfield Technologies http://latchfield.com
|
|
3
|
+
|
|
4
|
+
import ast
|
|
5
|
+
import inspect
|
|
6
|
+
import io
|
|
7
|
+
import logging
|
|
8
|
+
import re
|
|
9
|
+
import textwrap
|
|
10
|
+
import tokenize
|
|
11
|
+
from ast import Attribute, Module, Name, NodeTransformer, NodeVisitor
|
|
12
|
+
from collections import OrderedDict
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from functools import cached_property
|
|
16
|
+
from types import MappingProxyType
|
|
17
|
+
from typing import Any, ClassVar, TypeAliasType, get_type_hints
|
|
18
|
+
|
|
19
|
+
from vulcan_core.models import Fact, HasSource
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ASTProcessingError(RuntimeError):
|
|
25
|
+
"""Internal error encountered while processing AST."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ContractError(Exception):
|
|
29
|
+
"""Base exception for callable contract violations."""
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ScopeAccessError(ContractError):
|
|
33
|
+
"""Raised when a callable attempts to access instances not passed as parameters or when decorated functions attempt
|
|
34
|
+
to access class attributes instead of parameter instance attributes."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class NotAFactError(ContractError):
|
|
38
|
+
"""Raised when a callable parameter, or accessed attribute is not a subclass of Fact."""
|
|
39
|
+
|
|
40
|
+
def __init__(self, type_obj: type) -> None:
|
|
41
|
+
message = f"'{type_obj.__name__}' is not a Fact subclass"
|
|
42
|
+
super().__init__(message)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class CallableSignatureError(ContractError):
|
|
46
|
+
"""Raised when a decorated function has any missing type hints, an incorrect return type, or if a lambda that
|
|
47
|
+
requires arguments is provided."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class _AttributeVisitor(NodeVisitor):
    """Walk an AST and record every ``name.attr`` access as a ``(name, attr)`` pair, in traversal order."""

    def __init__(self):
        self.attributes = []

    def visit_Attribute(self, node):
        target = node.value
        if isinstance(target, Name):
            self.attributes += [(target.id, node.attr)]
        # Keep descending so accesses nested inside this node (arguments, chained values) are recorded too
        self.generic_visit(node)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class _NestedAttributeVisitor(NodeVisitor):
    """Detect whether any attribute access is chained directly onto another one (e.g. ``X.y.z``)."""

    def __init__(self):
        self.has_nested = False

    def visit_Attribute(self, node):
        # Once raised, the flag stays set for the remainder of the walk
        self.has_nested = self.has_nested or isinstance(node.value, Attribute)
        self.generic_visit(node)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class AttributeTransformer(NodeTransformer):
    """Rewrite ``Class.attr`` accesses into ``param.attr`` using a class-name to parameter-name mapping.

    Names absent from the mapping, and bare names that are not the value of an attribute access, are left unchanged.
    """

    def __init__(self, class_to_param):
        self.class_to_param = class_to_param

    def visit_Attribute(self, node: Attribute):
        # Transform children first so inner accesses are rewritten before this node is inspected
        node = self.generic_visit(node)  # type: ignore

        target = node.value
        if not isinstance(target, Name) or target.id not in self.class_to_param:
            return node

        replacement = Name(id=self.class_to_param[target.id], ctx=target.ctx)
        return Attribute(value=replacement, attr=node.attr, ctx=node.ctx)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass(slots=True)
|
|
93
|
+
class LambdaTracker:
|
|
94
|
+
"""Index entry for tracking the parsing position of lambda functions in source lines.
|
|
95
|
+
|
|
96
|
+
Attributes:
|
|
97
|
+
source (str): The source code string containing lambda functions
|
|
98
|
+
positions (list[int]): Positions where lambda functions are found in the source
|
|
99
|
+
index (int): The lambda being parsed within the source string.
|
|
100
|
+
in_use (bool): Whether this source is currently being processed or not, making it eligible for cache deletion.
|
|
101
|
+
"""
|
|
102
|
+
|
|
103
|
+
source: str
|
|
104
|
+
positions: list[int]
|
|
105
|
+
index: int = field(default=0)
|
|
106
|
+
in_use: bool = field(default=True)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@dataclass
|
|
110
|
+
class ASTProcessor[T: Callable]:
|
|
111
|
+
"""
|
|
112
|
+
This class extracts source code from functions or lambda expressions, parses them into
|
|
113
|
+
Abstract Syntax Trees (AST), and performs various validations and transformations.
|
|
114
|
+
|
|
115
|
+
The processor validates that:
|
|
116
|
+
- Functions have proper type hints for parameters and return types
|
|
117
|
+
- All parameters are subclasses of Fact
|
|
118
|
+
- No nested attribute access (e.g., X.y.z) is used
|
|
119
|
+
- No async functions are processed
|
|
120
|
+
- Lambda expressions do not contain parameters
|
|
121
|
+
- No duplicate parameter types in function signatures
|
|
122
|
+
|
|
123
|
+
For lambda expressions, it automatically transforms attribute access patterns
|
|
124
|
+
(e.g., ClassName.attribute) into parameterized functions for easier execution.
|
|
125
|
+
|
|
126
|
+
Note: This class is not thread-safe and should not be used concurrently across multiple threads.
|
|
127
|
+
|
|
128
|
+
Type Parameters:
|
|
129
|
+
T: The type signature the processor is working with, this varies based on a condition or action being processed.
|
|
130
|
+
|
|
131
|
+
Attributes:
|
|
132
|
+
func: The callable to process, a lambda or a function
|
|
133
|
+
decorator: The decorator type that initiated the processing (e.g., `condition` or `action`)
|
|
134
|
+
return_type: Expected return type for the callable
|
|
135
|
+
source: Extracted source code of func (set during post-init)
|
|
136
|
+
tree: Parsed AST of the source code (set during post-init)
|
|
137
|
+
facts: Tuple of fact strings discovered in the callable (set during post-init)
|
|
138
|
+
|
|
139
|
+
Properties:
|
|
140
|
+
is_lambda: True if the callable is a lambda expression
|
|
141
|
+
|
|
142
|
+
Raises:
|
|
143
|
+
OSError: When source code cannot be extracted
|
|
144
|
+
ScopeAccessError: When accessing undefined classes or using nested attributes
|
|
145
|
+
CallableSignatureError: When function signature doesn't meet requirements
|
|
146
|
+
NotAFactError: When parameter types are not Fact subclasses
|
|
147
|
+
ASTProcessingError: When AST processing encounters internal errors
|
|
148
|
+
"""
|
|
149
|
+
|
|
150
|
+
func: T
|
|
151
|
+
decorator: Callable
|
|
152
|
+
return_type: type | TypeAliasType
|
|
153
|
+
source: str = field(init=False)
|
|
154
|
+
tree: Module = field(init=False)
|
|
155
|
+
facts: tuple[str, ...] = field(init=False)
|
|
156
|
+
|
|
157
|
+
# Class-level tracking of lambdas across parsing calls to handle multiple lambdas on the same line
|
|
158
|
+
_lambda_cache: ClassVar[OrderedDict[str, LambdaTracker]] = OrderedDict()
|
|
159
|
+
_MAX_LAMBDA_CACHE_SIZE: ClassVar[int] = 1024
|
|
160
|
+
|
|
161
|
+
@cached_property
|
|
162
|
+
def is_lambda(self) -> bool:
|
|
163
|
+
return isinstance(self.func, type(lambda: None)) and self.func.__name__ == "<lambda>"
|
|
164
|
+
|
|
165
|
+
def __post_init__(self):
|
|
166
|
+
# Extract source code and parse AST
|
|
167
|
+
if isinstance(self.func, HasSource):
|
|
168
|
+
self.source = self.func.__source__
|
|
169
|
+
else:
|
|
170
|
+
try:
|
|
171
|
+
if self.is_lambda:
|
|
172
|
+
# As of Python 3.12, there is no way to determine to which lambda self.func refers in an
|
|
173
|
+
# expression containing multiple lambdas. Therefore we use a dict to track the index of each
|
|
174
|
+
# lambda function encountered, as the order will correspond to the order of ASTProcessor
|
|
175
|
+
# invocations for that line. An additional benefit is that we can also use this as a cache to
|
|
176
|
+
# avoid re-reading and parsing the source code for lambda functions sharing the same line.
|
|
177
|
+
source_line = f"{self.func.__code__.co_filename}:{self.func.__code__.co_firstlineno}"
|
|
178
|
+
tracker = self._lambda_cache.get(source_line)
|
|
179
|
+
|
|
180
|
+
if tracker is None:
|
|
181
|
+
self.source = self._get_lambda_source()
|
|
182
|
+
positions = self._find_lambdas(self.source)
|
|
183
|
+
|
|
184
|
+
tracker = LambdaTracker(self.source, positions)
|
|
185
|
+
self._lambda_cache[source_line] = tracker
|
|
186
|
+
self._trim_lambda_cache()
|
|
187
|
+
else:
|
|
188
|
+
tracker.index += 1
|
|
189
|
+
|
|
190
|
+
# Reset the position if it exceeds the count of lambda expressions
|
|
191
|
+
if tracker.index >= len(tracker.positions):
|
|
192
|
+
tracker.index = 0
|
|
193
|
+
|
|
194
|
+
# Extract the next lambda source based on the current tracking state
|
|
195
|
+
self.source = self._extract_next_lambda(tracker)
|
|
196
|
+
|
|
197
|
+
# If all found lambdas have been processed, mark the tracker as not in use
|
|
198
|
+
if tracker.index >= len(tracker.positions) - 1:
|
|
199
|
+
tracker.in_use = False
|
|
200
|
+
|
|
201
|
+
else:
|
|
202
|
+
self.source = textwrap.dedent(inspect.getsource(self.func))
|
|
203
|
+
except OSError as e:
|
|
204
|
+
if str(e) == "could not get source code":
|
|
205
|
+
msg = "could not get source code. Try recursively deleting all __pycache__ folders in your project."
|
|
206
|
+
raise OSError(msg) from e
|
|
207
|
+
else:
|
|
208
|
+
raise
|
|
209
|
+
self.func.__source__ = self.source
|
|
210
|
+
|
|
211
|
+
# Parse the AST with minimal error handling
|
|
212
|
+
self.tree = ast.parse(self.source)
|
|
213
|
+
|
|
214
|
+
# Perform basic AST checks and attribute discovery
|
|
215
|
+
self._validate_ast()
|
|
216
|
+
attributes = self._discover_attributes()
|
|
217
|
+
|
|
218
|
+
if self.is_lambda:
|
|
219
|
+
# Process attributes and create a transformed lambda
|
|
220
|
+
caller_globals = self._get_caller_globals()
|
|
221
|
+
facts, class_to_param = self._resolve_facts(attributes, caller_globals)
|
|
222
|
+
|
|
223
|
+
self.facts = tuple(facts)
|
|
224
|
+
self.func = self._transform_lambda(class_to_param, caller_globals)
|
|
225
|
+
|
|
226
|
+
else:
|
|
227
|
+
# Get function metadata and validate signature
|
|
228
|
+
hints = get_type_hints(self.func)
|
|
229
|
+
params = inspect.signature(self.func).parameters # type: ignore
|
|
230
|
+
self._validate_signature(hints, params)
|
|
231
|
+
|
|
232
|
+
# Process attributes
|
|
233
|
+
facts: list[str] = []
|
|
234
|
+
param_names = list(params)
|
|
235
|
+
|
|
236
|
+
# Create the list of accessed facts and verify they are in the correct scope
|
|
237
|
+
for class_name, attr in attributes:
|
|
238
|
+
if class_name not in param_names:
|
|
239
|
+
msg = f"Accessing class '{class_name}' not passed as parameter"
|
|
240
|
+
raise ScopeAccessError(msg)
|
|
241
|
+
facts.append(f"{hints[class_name].__name__}.{attr}")
|
|
242
|
+
|
|
243
|
+
self.facts = tuple(facts)
|
|
244
|
+
|
|
245
|
+
def _trim_lambda_cache(self) -> None:
|
|
246
|
+
"""Clean up lambda cache by removing oldest unused entries when cache size exceeds limit."""
|
|
247
|
+
if len(self._lambda_cache) <= self._MAX_LAMBDA_CACHE_SIZE:
|
|
248
|
+
return
|
|
249
|
+
|
|
250
|
+
# Calculate how many entries to remove (excess + 20% buffer to avoid thrashing)
|
|
251
|
+
excess_count = len(self._lambda_cache) - self._MAX_LAMBDA_CACHE_SIZE
|
|
252
|
+
buffer_count = int(self._MAX_LAMBDA_CACHE_SIZE * 0.2)
|
|
253
|
+
target_count = excess_count + buffer_count
|
|
254
|
+
|
|
255
|
+
# Find and remove unused entries
|
|
256
|
+
removed_count = 0
|
|
257
|
+
for key in list(self._lambda_cache):
|
|
258
|
+
if removed_count >= target_count:
|
|
259
|
+
break
|
|
260
|
+
if not self._lambda_cache[key].in_use:
|
|
261
|
+
del self._lambda_cache[key]
|
|
262
|
+
removed_count += 1
|
|
263
|
+
|
|
264
|
+
def _find_lambdas(self, source: str) -> list[int]:
|
|
265
|
+
"""Find all lambda expressions in the source code and return their starting positions."""
|
|
266
|
+
tokens = tokenize.generate_tokens(io.StringIO(source).readline)
|
|
267
|
+
lambda_positions = [
|
|
268
|
+
token.start[1] for token in tokens if token.type == tokenize.NAME and token.string == "lambda"
|
|
269
|
+
]
|
|
270
|
+
|
|
271
|
+
return lambda_positions
|
|
272
|
+
|
|
273
|
+
def _get_lambda_source(self) -> str:
|
|
274
|
+
"""Get single and multiline lambda source using AST parsing of the source file."""
|
|
275
|
+
source = None
|
|
276
|
+
|
|
277
|
+
try:
|
|
278
|
+
# Get the source file and line number
|
|
279
|
+
# Avoid reading source from files directly, as it may fail in some cases (e.g., lambdas in REPL)
|
|
280
|
+
file_content = "".join(inspect.findsource(self.func)[0])
|
|
281
|
+
lambda_lineno = self.func.__code__.co_firstlineno
|
|
282
|
+
|
|
283
|
+
# Parse the AST of the source file
|
|
284
|
+
file_ast = ast.parse(file_content)
|
|
285
|
+
|
|
286
|
+
# Find the lambda expression at the specific line number
|
|
287
|
+
class LambdaFinder(ast.NodeVisitor):
|
|
288
|
+
def __init__(self, target_lineno):
|
|
289
|
+
self.target_lineno = target_lineno
|
|
290
|
+
self.found_lambda = None
|
|
291
|
+
|
|
292
|
+
def visit_Lambda(self, node):
|
|
293
|
+
if node.lineno == self.target_lineno:
|
|
294
|
+
self.found_lambda = node
|
|
295
|
+
self.generic_visit(node)
|
|
296
|
+
|
|
297
|
+
finder = LambdaFinder(lambda_lineno)
|
|
298
|
+
finder.visit(file_ast)
|
|
299
|
+
|
|
300
|
+
if finder.found_lambda:
|
|
301
|
+
# Get the source lines that contain this lambda
|
|
302
|
+
lines = file_content.split("\n")
|
|
303
|
+
start_line = finder.found_lambda.lineno - 1
|
|
304
|
+
|
|
305
|
+
# Find the end of the lambda expression
|
|
306
|
+
end_line = start_line
|
|
307
|
+
if hasattr(finder.found_lambda, "end_lineno") and finder.found_lambda.end_lineno:
|
|
308
|
+
end_line = finder.found_lambda.end_lineno - 1
|
|
309
|
+
else:
|
|
310
|
+
# Fallback: find the closing parenthesis
|
|
311
|
+
paren_count = 0
|
|
312
|
+
for i in range(start_line, len(lines)):
|
|
313
|
+
line = lines[i]
|
|
314
|
+
paren_count += line.count("(") - line.count(")")
|
|
315
|
+
if paren_count <= 0 and ")" in line:
|
|
316
|
+
end_line = i
|
|
317
|
+
break
|
|
318
|
+
|
|
319
|
+
source = "\n".join(lines[start_line : end_line + 1])
|
|
320
|
+
|
|
321
|
+
except (OSError, SyntaxError, AttributeError):
|
|
322
|
+
logger.exception("Failed to extract lambda source, attempting fallback.")
|
|
323
|
+
source = inspect.getsource(self.func).strip()
|
|
324
|
+
|
|
325
|
+
if source is None or source == "":
|
|
326
|
+
msg = "Could not extract lambda source code"
|
|
327
|
+
raise ASTProcessingError(msg)
|
|
328
|
+
|
|
329
|
+
# Normalize the source: convert line breaks to spaces, collapse whitespace, and dedent
|
|
330
|
+
source = re.sub(r"\r\n|\r|\n", " ", source)
|
|
331
|
+
source = re.sub(r"\s+", " ", source)
|
|
332
|
+
source = textwrap.dedent(source)
|
|
333
|
+
|
|
334
|
+
return source
|
|
335
|
+
|
|
336
|
+
def _extract_next_lambda(self, src: LambdaTracker) -> str:
|
|
337
|
+
"""Extracts the next lambda expression from source code."""
|
|
338
|
+
source = src.source
|
|
339
|
+
index = src.index
|
|
340
|
+
lambda_start = src.positions[index]
|
|
341
|
+
|
|
342
|
+
# The source may include unrelated code (e.g., assignment and condition() call)
|
|
343
|
+
# So we need to extract just the lambda expression, handling nested structures correctly
|
|
344
|
+
source = source[lambda_start:]
|
|
345
|
+
|
|
346
|
+
# Track depth of various brackets to ensure we don't split inside valid nested structures apart from trailing
|
|
347
|
+
# arguments within the condition() call
|
|
348
|
+
paren_level = 0
|
|
349
|
+
bracket_level = 0
|
|
350
|
+
brace_level = 0
|
|
351
|
+
|
|
352
|
+
for i, char in enumerate(source):
|
|
353
|
+
if char == "(":
|
|
354
|
+
paren_level += 1
|
|
355
|
+
elif char == ")":
|
|
356
|
+
if paren_level > 0:
|
|
357
|
+
paren_level -= 1
|
|
358
|
+
elif paren_level == 0: # End of expression in a function call
|
|
359
|
+
return source[:i]
|
|
360
|
+
elif char == "[":
|
|
361
|
+
bracket_level += 1
|
|
362
|
+
elif char == "]":
|
|
363
|
+
if bracket_level > 0:
|
|
364
|
+
bracket_level -= 1
|
|
365
|
+
elif char == "{":
|
|
366
|
+
brace_level += 1
|
|
367
|
+
elif char == "}":
|
|
368
|
+
if brace_level > 0:
|
|
369
|
+
brace_level -= 1
|
|
370
|
+
# Only consider comma as a separator when not inside any brackets
|
|
371
|
+
elif char == "," and paren_level == 0 and bracket_level == 0 and brace_level == 0:
|
|
372
|
+
return source[:i]
|
|
373
|
+
|
|
374
|
+
return source
|
|
375
|
+
|
|
376
|
+
def _get_caller_globals(self) -> dict[str, Any]:
|
|
377
|
+
"""Find the globals of the caller of the decorator in order to validate accessed types."""
|
|
378
|
+
try:
|
|
379
|
+
decorator_name = self.decorator.__name__
|
|
380
|
+
frame = inspect.currentframe()
|
|
381
|
+
while frame.f_code.co_name != decorator_name: # type: ignore
|
|
382
|
+
frame = frame.f_back # type: ignore
|
|
383
|
+
return frame.f_back.f_globals # type: ignore # noqa: TRY300
|
|
384
|
+
|
|
385
|
+
except AttributeError as err: # pragma: no cover - internal AST error
|
|
386
|
+
msg = f"Unable to locate caller ('{decorator_name}') globals"
|
|
387
|
+
raise ASTProcessingError(msg) from err
|
|
388
|
+
|
|
389
|
+
def _validate_ast(self) -> None:
|
|
390
|
+
# Check for nested attribute access
|
|
391
|
+
visitor = _NestedAttributeVisitor()
|
|
392
|
+
visitor.visit(self.tree)
|
|
393
|
+
if visitor.has_nested:
|
|
394
|
+
msg = "Nested attribute access (X.y.z) is not allowed"
|
|
395
|
+
raise ScopeAccessError(msg)
|
|
396
|
+
|
|
397
|
+
# Checks for async functions
|
|
398
|
+
if isinstance(self.tree.body[0], ast.AsyncFunctionDef):
|
|
399
|
+
msg = "Async functions are not supported"
|
|
400
|
+
raise CallableSignatureError(msg)
|
|
401
|
+
|
|
402
|
+
# Lambda-specific checks
|
|
403
|
+
if self.is_lambda:
|
|
404
|
+
if not isinstance(self.tree, ast.Module) or not isinstance(
|
|
405
|
+
self.tree.body[0], ast.Expr
|
|
406
|
+
): # pragma: no cover - internal AST error
|
|
407
|
+
msg = "Expected an expression in AST body"
|
|
408
|
+
raise ASTProcessingError(msg)
|
|
409
|
+
|
|
410
|
+
lambda_node = self.tree.body[0].value
|
|
411
|
+
if not isinstance(lambda_node, ast.Lambda): # pragma: no cover - internal AST error
|
|
412
|
+
msg = "Expected a lambda expression"
|
|
413
|
+
raise ASTProcessingError(msg)
|
|
414
|
+
|
|
415
|
+
if lambda_node.args.args:
|
|
416
|
+
msg = "Lambda expressions must not have parameters"
|
|
417
|
+
raise CallableSignatureError(msg)
|
|
418
|
+
|
|
419
|
+
def _discover_attributes(self) -> list[tuple[str, str]]:
|
|
420
|
+
"""Discover attributes accessed within the AST."""
|
|
421
|
+
visitor = _AttributeVisitor()
|
|
422
|
+
visitor.visit(self.tree)
|
|
423
|
+
return visitor.attributes
|
|
424
|
+
|
|
425
|
+
def _resolve_facts(self, attributes: list[tuple[str, str]], globals_dict: dict) -> tuple[list[str], dict[str, str]]:
|
|
426
|
+
"""Validate attribute accesses and return normalized fact strings."""
|
|
427
|
+
facts = []
|
|
428
|
+
class_to_param = {}
|
|
429
|
+
param_counter = 0
|
|
430
|
+
|
|
431
|
+
for class_name, attr in attributes:
|
|
432
|
+
# Verify the name refers to a class type
|
|
433
|
+
if class_name not in globals_dict or not isinstance(globals_dict[class_name], type):
|
|
434
|
+
msg = f"Accessing undefined class '{class_name}'"
|
|
435
|
+
raise ScopeAccessError(msg)
|
|
436
|
+
|
|
437
|
+
# Verify it's a Fact subclass
|
|
438
|
+
class_obj = globals_dict[class_name]
|
|
439
|
+
if not issubclass(class_obj, Fact):
|
|
440
|
+
raise NotAFactError(class_obj)
|
|
441
|
+
|
|
442
|
+
facts.append(f"{class_name}.{attr}")
|
|
443
|
+
if class_name not in class_to_param:
|
|
444
|
+
class_to_param[class_name] = f"p{param_counter}"
|
|
445
|
+
param_counter += 1
|
|
446
|
+
|
|
447
|
+
# Deduplicate facts while preserving order
|
|
448
|
+
seen = set()
|
|
449
|
+
facts = [fact for fact in facts if not (fact in seen or seen.add(fact))]
|
|
450
|
+
|
|
451
|
+
return facts, class_to_param
|
|
452
|
+
|
|
453
|
+
def _validate_signature(self, hints: dict, params: MappingProxyType[str, inspect.Parameter]) -> None:
|
|
454
|
+
"""Validate function signature requirements."""
|
|
455
|
+
|
|
456
|
+
# Validate return type
|
|
457
|
+
if "return" not in hints or hints["return"] is not self.return_type:
|
|
458
|
+
msg = f"Return type hint is required and must be {self.return_type!r}"
|
|
459
|
+
raise CallableSignatureError(msg)
|
|
460
|
+
|
|
461
|
+
# Track parameter types to check for duplicates
|
|
462
|
+
param_types = []
|
|
463
|
+
|
|
464
|
+
# Validate parameters
|
|
465
|
+
for param in params.values():
|
|
466
|
+
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
|
|
467
|
+
msg = "Variable arguments (*args, **kwargs) are not supported"
|
|
468
|
+
raise CallableSignatureError(msg)
|
|
469
|
+
|
|
470
|
+
if param.name not in hints:
|
|
471
|
+
msg = "All parameters must have type hints"
|
|
472
|
+
raise CallableSignatureError(msg)
|
|
473
|
+
|
|
474
|
+
if param.name != "return":
|
|
475
|
+
param_type = hints[param.name]
|
|
476
|
+
if not isinstance(param_type, type) or not issubclass(param_type, Fact):
|
|
477
|
+
raise NotAFactError(param_type)
|
|
478
|
+
param_types.append(param_type)
|
|
479
|
+
|
|
480
|
+
# Check for duplicate parameter types
|
|
481
|
+
seen_types = set()
|
|
482
|
+
for param_type in param_types:
|
|
483
|
+
if param_type in seen_types:
|
|
484
|
+
msg = f"Duplicate parameter type '{param_type.__name__}' is not allowed"
|
|
485
|
+
raise CallableSignatureError(msg)
|
|
486
|
+
seen_types.add(param_type)
|
|
487
|
+
|
|
488
|
+
def _transform_lambda(self, class_to_param: dict[str, str], caller_globals: dict[str, Any]) -> T:
|
|
489
|
+
# Transform and create new lambda
|
|
490
|
+
transformer = AttributeTransformer(class_to_param)
|
|
491
|
+
new_tree = transformer.visit(self.tree)
|
|
492
|
+
lambda_body = ast.unparse(new_tree.body[0].value)
|
|
493
|
+
|
|
494
|
+
# The AST unparsing creates a full lambda expression, but we only want its body. This handles edge cases where
|
|
495
|
+
# the transformed AST might generate different lambda syntax than the original source code, ensuring we only
|
|
496
|
+
# get the expression part.
|
|
497
|
+
if lambda_body.startswith("lambda"):
|
|
498
|
+
lambda_body = lambda_body[lambda_body.find(":") + 1 :].strip()
|
|
499
|
+
|
|
500
|
+
# Create a new lambda object with the transformed body
|
|
501
|
+
# TODO: Find a way to avoid using exec or eval here
|
|
502
|
+
lambda_code = f"lambda {', '.join(class_to_param.values())}: {lambda_body}"
|
|
503
|
+
new_func = eval(lambda_code, caller_globals) # noqa: S307 # nosec B307
|
|
504
|
+
new_func.__source__ = self.source
|
|
505
|
+
|
|
506
|
+
return new_func
|