vulcan-core 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,45 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Copyright 2025 Latchfield Technologies http://latchfield.com
3
+
4
+ from vulcan_core.actions import Action, action
5
+ from vulcan_core.ast_utils import (
6
+ ASTProcessingError,
7
+ CallableSignatureError,
8
+ ContractError,
9
+ NotAFactError,
10
+ ScopeAccessError,
11
+ )
12
+ from vulcan_core.conditions import (
13
+ CompoundCondition,
14
+ Condition,
15
+ MissingFactError,
16
+ OnFactChanged,
17
+ Operator,
18
+ condition,
19
+ )
20
+ from vulcan_core.engine import InternalStateError, RecursionLimitError, Rule, RuleEngine
21
+ from vulcan_core.models import ActionReturn, ChunkingStrategy, Fact, Similarity
22
+
23
# Public API of the package: every name imported above is re-exported here so that
# `from vulcan_core import *` and documentation tools see exactly this set.
__all__ = [
    "ASTProcessingError", "Action", "ActionReturn",
    "CallableSignatureError", "ChunkingStrategy", "CompoundCondition",
    "Condition", "ContractError", "Fact",
    "InternalStateError", "MissingFactError", "NotAFactError",
    "OnFactChanged", "Operator", "RecursionLimitError",
    "Rule", "RuleEngine", "ScopeAccessError",
    "Similarity", "action", "condition",
]
vulcan_core/actions.py ADDED
@@ -0,0 +1,31 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Copyright 2025 Latchfield Technologies http://latchfield.com
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+
9
+ from vulcan_core.ast_utils import ASTProcessor
10
+ from vulcan_core.models import ActionCallable, ActionReturn, DeclaresFacts, Fact, FactHandler
11
+
12
+
13
@dataclass(frozen=True, slots=True)
class Action(FactHandler[ActionCallable, ActionReturn], DeclaresFacts):
    """
    A deferred computation producing the result of a rule.

    Instances are immutable; calling one forwards the supplied facts to the wrapped callable.
    """

    def __call__(self, *args: Fact) -> ActionReturn:
        """Invoke the action with the given facts and return whatever the wrapped callable returns."""
        outcome = self._evaluate(*args)
        return outcome

    def _evaluate(self, *args: Fact) -> ActionReturn:
        """Call the wrapped callable (`self.func`) with the facts passed positionally."""
        handler = self.func
        return handler(*args)
24
+
25
+
26
def action(value: ActionCallable | ActionReturn) -> Action:
    """
    Build an `Action` from either a callable or a ready-made result.

    A callable (other than a `functools.partial`) is run through `ASTProcessor` so the facts it
    accesses are discovered. Anything else, including a `partial`, is treated as a constant
    result and wrapped in a zero-argument lambda that declares no facts.
    """
    # Constants and partials need no source analysis: wrap them as-is.
    if isinstance(value, partial) or not callable(value):
        return Action((), lambda: value)

    processed = ASTProcessor[ActionCallable](value, action, ActionReturn)
    return Action(processed.facts, processed.func)
@@ -0,0 +1,506 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Copyright 2025 Latchfield Technologies http://latchfield.com
3
+
4
+ import ast
5
+ import inspect
6
+ import io
7
+ import logging
8
+ import re
9
+ import textwrap
10
+ import tokenize
11
+ from ast import Attribute, Module, Name, NodeTransformer, NodeVisitor
12
+ from collections import OrderedDict
13
+ from collections.abc import Callable
14
+ from dataclasses import dataclass, field
15
+ from functools import cached_property
16
+ from types import MappingProxyType
17
+ from typing import Any, ClassVar, TypeAliasType, get_type_hints
18
+
19
+ from vulcan_core.models import Fact, HasSource
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class ASTProcessingError(RuntimeError):
    """Signals an internal failure while extracting or processing a callable's AST."""
26
+
27
+
28
class ContractError(Exception):
    """Root of the exception hierarchy for violations of the callable contract."""
30
+
31
+
32
class ScopeAccessError(ContractError):
    """Raised for out-of-scope access inside a callable.

    This covers a callable that touches instances which were not passed in as parameters, and a
    decorated function that reads class attributes rather than attributes of its parameter instances.
    """
35
+
36
+
37
class NotAFactError(ContractError):
    """Raised when a callable parameter, or accessed attribute is not a subclass of Fact."""

    def __init__(self, type_obj: type) -> None:
        """Build the error message from the offending annotation.

        Args:
            type_obj: The annotation that failed the Fact check. Callers also pass objects that
                are not classes (any annotation for which `isinstance(x, type)` is false), so the
                name lookup must not assume `__name__` exists.
        """
        # Annotations such as `int | None` have no `__name__` on some Python versions; fall back to
        # repr() so the intended NotAFactError is raised instead of an unrelated AttributeError.
        type_name = getattr(type_obj, "__name__", None) or repr(type_obj)
        message = f"'{type_name}' is not a Fact subclass"
        super().__init__(message)
43
+
44
+
45
class CallableSignatureError(ContractError):
    """Raised for an unacceptable callable signature.

    Examples are a decorated function with missing type hints or the wrong return type, and a
    lambda that requires arguments.
    """
48
+
49
+
50
class _AttributeVisitor(NodeVisitor):
    """Collects every `name.attr` access found in an AST as `(name, attr)` pairs, in visit order."""

    def __init__(self):
        self.attributes = []

    def visit_Attribute(self, node):
        owner = node.value
        # Only direct `Name.attr` accesses are recorded; the outer part of a chain is skipped here.
        if isinstance(owner, Name):
            pair = (owner.id, node.attr)
            self.attributes.append(pair)
        # Keep descending so attribute accesses nested inside this node are also seen.
        self.generic_visit(node)
60
+
61
+
62
class _NestedAttributeVisitor(NodeVisitor):
    """Flags whether an AST contains a chained attribute access such as `x.y.z`."""

    def __init__(self):
        self.has_nested = False

    def visit_Attribute(self, node):
        # An attribute whose target is itself an attribute access is a nested chain.
        chained = isinstance(node.value, Attribute)
        if chained:
            self.has_nested = True
        self.generic_visit(node)
72
+
73
+
74
class AttributeTransformer(NodeTransformer):
    """Rewrites `ClassName.attr` accesses into `param.attr` using a class-name to parameter-name map."""

    def __init__(self, class_to_param):
        self.class_to_param = class_to_param

    def visit_Attribute(self, node: Attribute):
        # Transform children first so the rewrite applies to the already-visited node.
        visited = self.generic_visit(node)  # type: ignore
        owner = visited.value

        if not isinstance(owner, Name):
            return visited
        if owner.id not in self.class_to_param:
            return visited

        # Swap the class name for its parameter name, keeping attribute and contexts unchanged.
        param_ref = Name(id=self.class_to_param[owner.id], ctx=owner.ctx)
        return Attribute(value=param_ref, attr=visited.attr, ctx=visited.ctx)
90
+
91
+
92
+ @dataclass(slots=True)
93
+ class LambdaTracker:
94
+ """Index entry for tracking the parsing position of lambda functions in source lines.
95
+
96
+ Attributes:
97
+ source (str): The source code string containing lambda functions
98
+ positions (list[int]): Positions where lambda functions are found in the source
99
+ index (int): The lambda being parsed within the source string.
100
+ in_use (bool): Whether this source is currently being processed or not, making it eligible for cache deletion.
101
+ """
102
+
103
+ source: str
104
+ positions: list[int]
105
+ index: int = field(default=0)
106
+ in_use: bool = field(default=True)
107
+
108
+
109
+ @dataclass
110
+ class ASTProcessor[T: Callable]:
111
+ """
112
+ This class extracts source code from functions or lambda expressions, parses them into
113
+ Abstract Syntax Trees (AST), and performs various validations and transformations.
114
+
115
+ The processor validates that:
116
+ - Functions have proper type hints for parameters and return types
117
+ - All parameters are subclasses of Fact
118
+ - No nested attribute access (e.g., X.y.z) is used
119
+ - No async functions are processed
120
+ - Lambda expressions do not contain parameters
121
+ - No duplicate parameter types in function signatures
122
+
123
+ For lambda expressions, it automatically transforms attribute access patterns
124
+ (e.g., ClassName.attribute) into parameterized functions for easier execution.
125
+
126
+ Note: This class is not thread-safe and should not be used concurrently across multiple threads.
127
+
128
+ Type Parameters:
129
+ T: The type signature the processor is working with, this varies based on a condition or action being processed.
130
+
131
+ Attributes:
132
+ func: The callable to process, a lambda or a function
133
+ decorator: The decorator type that initiated the processing (e.g., `condition` or `action`)
134
+ return_type: Expected return type for the callable
135
+ source: Extracted source code of func (set during post-init)
136
+ tree: Parsed AST of the source code (set during post-init)
137
+ facts: Tuple of fact strings discovered in the callable (set during post-init)
138
+
139
+ Properties:
140
+ is_lambda: True if the callable is a lambda expression
141
+
142
+ Raises:
143
+ OSError: When source code cannot be extracted
144
+ ScopeAccessError: When accessing undefined classes or using nested attributes
145
+ CallableSignatureError: When function signature doesn't meet requirements
146
+ NotAFactError: When parameter types are not Fact subclasses
147
+ ASTProcessingError: When AST processing encounters internal errors
148
+ """
149
+
150
+ func: T
151
+ decorator: Callable
152
+ return_type: type | TypeAliasType
153
+ source: str = field(init=False)
154
+ tree: Module = field(init=False)
155
+ facts: tuple[str, ...] = field(init=False)
156
+
157
+ # Class-level tracking of lambdas across parsing calls to handle multiple lambdas on the same line
158
+ _lambda_cache: ClassVar[OrderedDict[str, LambdaTracker]] = OrderedDict()
159
+ _MAX_LAMBDA_CACHE_SIZE: ClassVar[int] = 1024
160
+
161
+ @cached_property
162
+ def is_lambda(self) -> bool:
163
+ return isinstance(self.func, type(lambda: None)) and self.func.__name__ == "<lambda>"
164
+
165
+ def __post_init__(self):
166
+ # Extract source code and parse AST
167
+ if isinstance(self.func, HasSource):
168
+ self.source = self.func.__source__
169
+ else:
170
+ try:
171
+ if self.is_lambda:
172
+ # As of Python 3.12, there is no way to determine to which lambda self.func refers in an
173
+ # expression containing multiple lambdas. Therefore we use a dict to track the index of each
174
+ # lambda function encountered, as the order will correspond to the order of ASTProcessor
175
+ # invocations for that line. An additional benefit is that we can also use this as a cache to
176
+ # avoid re-reading and parsing the source code for lambda functions sharing the same line.
177
+ source_line = f"{self.func.__code__.co_filename}:{self.func.__code__.co_firstlineno}"
178
+ tracker = self._lambda_cache.get(source_line)
179
+
180
+ if tracker is None:
181
+ self.source = self._get_lambda_source()
182
+ positions = self._find_lambdas(self.source)
183
+
184
+ tracker = LambdaTracker(self.source, positions)
185
+ self._lambda_cache[source_line] = tracker
186
+ self._trim_lambda_cache()
187
+ else:
188
+ tracker.index += 1
189
+
190
+ # Reset the position if it exceeds the count of lambda expressions
191
+ if tracker.index >= len(tracker.positions):
192
+ tracker.index = 0
193
+
194
+ # Extract the next lambda source based on the current tracking state
195
+ self.source = self._extract_next_lambda(tracker)
196
+
197
+ # If all found lambdas have been processed, mark the tracker as not in use
198
+ if tracker.index >= len(tracker.positions) - 1:
199
+ tracker.in_use = False
200
+
201
+ else:
202
+ self.source = textwrap.dedent(inspect.getsource(self.func))
203
+ except OSError as e:
204
+ if str(e) == "could not get source code":
205
+ msg = "could not get source code. Try recursively deleting all __pycache__ folders in your project."
206
+ raise OSError(msg) from e
207
+ else:
208
+ raise
209
+ self.func.__source__ = self.source
210
+
211
+ # Parse the AST with minimal error handling
212
+ self.tree = ast.parse(self.source)
213
+
214
+ # Perform basic AST checks and attribute discovery
215
+ self._validate_ast()
216
+ attributes = self._discover_attributes()
217
+
218
+ if self.is_lambda:
219
+ # Process attributes and create a transformed lambda
220
+ caller_globals = self._get_caller_globals()
221
+ facts, class_to_param = self._resolve_facts(attributes, caller_globals)
222
+
223
+ self.facts = tuple(facts)
224
+ self.func = self._transform_lambda(class_to_param, caller_globals)
225
+
226
+ else:
227
+ # Get function metadata and validate signature
228
+ hints = get_type_hints(self.func)
229
+ params = inspect.signature(self.func).parameters # type: ignore
230
+ self._validate_signature(hints, params)
231
+
232
+ # Process attributes
233
+ facts: list[str] = []
234
+ param_names = list(params)
235
+
236
+ # Create the list of accessed facts and verify they are in the correct scope
237
+ for class_name, attr in attributes:
238
+ if class_name not in param_names:
239
+ msg = f"Accessing class '{class_name}' not passed as parameter"
240
+ raise ScopeAccessError(msg)
241
+ facts.append(f"{hints[class_name].__name__}.{attr}")
242
+
243
+ self.facts = tuple(facts)
244
+
245
+ def _trim_lambda_cache(self) -> None:
246
+ """Clean up lambda cache by removing oldest unused entries when cache size exceeds limit."""
247
+ if len(self._lambda_cache) <= self._MAX_LAMBDA_CACHE_SIZE:
248
+ return
249
+
250
+ # Calculate how many entries to remove (excess + 20% buffer to avoid thrashing)
251
+ excess_count = len(self._lambda_cache) - self._MAX_LAMBDA_CACHE_SIZE
252
+ buffer_count = int(self._MAX_LAMBDA_CACHE_SIZE * 0.2)
253
+ target_count = excess_count + buffer_count
254
+
255
+ # Find and remove unused entries
256
+ removed_count = 0
257
+ for key in list(self._lambda_cache):
258
+ if removed_count >= target_count:
259
+ break
260
+ if not self._lambda_cache[key].in_use:
261
+ del self._lambda_cache[key]
262
+ removed_count += 1
263
+
264
+ def _find_lambdas(self, source: str) -> list[int]:
265
+ """Find all lambda expressions in the source code and return their starting positions."""
266
+ tokens = tokenize.generate_tokens(io.StringIO(source).readline)
267
+ lambda_positions = [
268
+ token.start[1] for token in tokens if token.type == tokenize.NAME and token.string == "lambda"
269
+ ]
270
+
271
+ return lambda_positions
272
+
273
+ def _get_lambda_source(self) -> str:
274
+ """Get single and multiline lambda source using AST parsing of the source file."""
275
+ source = None
276
+
277
+ try:
278
+ # Get the source file and line number
279
+ # Avoid reading source from files directly, as it may fail in some cases (e.g., lambdas in REPL)
280
+ file_content = "".join(inspect.findsource(self.func)[0])
281
+ lambda_lineno = self.func.__code__.co_firstlineno
282
+
283
+ # Parse the AST of the source file
284
+ file_ast = ast.parse(file_content)
285
+
286
+ # Find the lambda expression at the specific line number
287
+ class LambdaFinder(ast.NodeVisitor):
288
+ def __init__(self, target_lineno):
289
+ self.target_lineno = target_lineno
290
+ self.found_lambda = None
291
+
292
+ def visit_Lambda(self, node):
293
+ if node.lineno == self.target_lineno:
294
+ self.found_lambda = node
295
+ self.generic_visit(node)
296
+
297
+ finder = LambdaFinder(lambda_lineno)
298
+ finder.visit(file_ast)
299
+
300
+ if finder.found_lambda:
301
+ # Get the source lines that contain this lambda
302
+ lines = file_content.split("\n")
303
+ start_line = finder.found_lambda.lineno - 1
304
+
305
+ # Find the end of the lambda expression
306
+ end_line = start_line
307
+ if hasattr(finder.found_lambda, "end_lineno") and finder.found_lambda.end_lineno:
308
+ end_line = finder.found_lambda.end_lineno - 1
309
+ else:
310
+ # Fallback: find the closing parenthesis
311
+ paren_count = 0
312
+ for i in range(start_line, len(lines)):
313
+ line = lines[i]
314
+ paren_count += line.count("(") - line.count(")")
315
+ if paren_count <= 0 and ")" in line:
316
+ end_line = i
317
+ break
318
+
319
+ source = "\n".join(lines[start_line : end_line + 1])
320
+
321
+ except (OSError, SyntaxError, AttributeError):
322
+ logger.exception("Failed to extract lambda source, attempting fallback.")
323
+ source = inspect.getsource(self.func).strip()
324
+
325
+ if source is None or source == "":
326
+ msg = "Could not extract lambda source code"
327
+ raise ASTProcessingError(msg)
328
+
329
+ # Normalize the source: convert line breaks to spaces, collapse whitespace, and dedent
330
+ source = re.sub(r"\r\n|\r|\n", " ", source)
331
+ source = re.sub(r"\s+", " ", source)
332
+ source = textwrap.dedent(source)
333
+
334
+ return source
335
+
336
+ def _extract_next_lambda(self, src: LambdaTracker) -> str:
337
+ """Extracts the next lambda expression from source code."""
338
+ source = src.source
339
+ index = src.index
340
+ lambda_start = src.positions[index]
341
+
342
+ # The source may include unrelated code (e.g., assignment and condition() call)
343
+ # So we need to extract just the lambda expression, handling nested structures correctly
344
+ source = source[lambda_start:]
345
+
346
+ # Track depth of various brackets to ensure we don't split inside valid nested structures apart from trailing
347
+ # arguments within the condition() call
348
+ paren_level = 0
349
+ bracket_level = 0
350
+ brace_level = 0
351
+
352
+ for i, char in enumerate(source):
353
+ if char == "(":
354
+ paren_level += 1
355
+ elif char == ")":
356
+ if paren_level > 0:
357
+ paren_level -= 1
358
+ elif paren_level == 0: # End of expression in a function call
359
+ return source[:i]
360
+ elif char == "[":
361
+ bracket_level += 1
362
+ elif char == "]":
363
+ if bracket_level > 0:
364
+ bracket_level -= 1
365
+ elif char == "{":
366
+ brace_level += 1
367
+ elif char == "}":
368
+ if brace_level > 0:
369
+ brace_level -= 1
370
+ # Only consider comma as a separator when not inside any brackets
371
+ elif char == "," and paren_level == 0 and bracket_level == 0 and brace_level == 0:
372
+ return source[:i]
373
+
374
+ return source
375
+
376
+ def _get_caller_globals(self) -> dict[str, Any]:
377
+ """Find the globals of the caller of the decorator in order to validate accessed types."""
378
+ try:
379
+ decorator_name = self.decorator.__name__
380
+ frame = inspect.currentframe()
381
+ while frame.f_code.co_name != decorator_name: # type: ignore
382
+ frame = frame.f_back # type: ignore
383
+ return frame.f_back.f_globals # type: ignore # noqa: TRY300
384
+
385
+ except AttributeError as err: # pragma: no cover - internal AST error
386
+ msg = f"Unable to locate caller ('{decorator_name}') globals"
387
+ raise ASTProcessingError(msg) from err
388
+
389
+ def _validate_ast(self) -> None:
390
+ # Check for nested attribute access
391
+ visitor = _NestedAttributeVisitor()
392
+ visitor.visit(self.tree)
393
+ if visitor.has_nested:
394
+ msg = "Nested attribute access (X.y.z) is not allowed"
395
+ raise ScopeAccessError(msg)
396
+
397
+ # Checks for async functions
398
+ if isinstance(self.tree.body[0], ast.AsyncFunctionDef):
399
+ msg = "Async functions are not supported"
400
+ raise CallableSignatureError(msg)
401
+
402
+ # Lambda-specific checks
403
+ if self.is_lambda:
404
+ if not isinstance(self.tree, ast.Module) or not isinstance(
405
+ self.tree.body[0], ast.Expr
406
+ ): # pragma: no cover - internal AST error
407
+ msg = "Expected an expression in AST body"
408
+ raise ASTProcessingError(msg)
409
+
410
+ lambda_node = self.tree.body[0].value
411
+ if not isinstance(lambda_node, ast.Lambda): # pragma: no cover - internal AST error
412
+ msg = "Expected a lambda expression"
413
+ raise ASTProcessingError(msg)
414
+
415
+ if lambda_node.args.args:
416
+ msg = "Lambda expressions must not have parameters"
417
+ raise CallableSignatureError(msg)
418
+
419
+ def _discover_attributes(self) -> list[tuple[str, str]]:
420
+ """Discover attributes accessed within the AST."""
421
+ visitor = _AttributeVisitor()
422
+ visitor.visit(self.tree)
423
+ return visitor.attributes
424
+
425
+ def _resolve_facts(self, attributes: list[tuple[str, str]], globals_dict: dict) -> tuple[list[str], dict[str, str]]:
426
+ """Validate attribute accesses and return normalized fact strings."""
427
+ facts = []
428
+ class_to_param = {}
429
+ param_counter = 0
430
+
431
+ for class_name, attr in attributes:
432
+ # Verify the name refers to a class type
433
+ if class_name not in globals_dict or not isinstance(globals_dict[class_name], type):
434
+ msg = f"Accessing undefined class '{class_name}'"
435
+ raise ScopeAccessError(msg)
436
+
437
+ # Verify it's a Fact subclass
438
+ class_obj = globals_dict[class_name]
439
+ if not issubclass(class_obj, Fact):
440
+ raise NotAFactError(class_obj)
441
+
442
+ facts.append(f"{class_name}.{attr}")
443
+ if class_name not in class_to_param:
444
+ class_to_param[class_name] = f"p{param_counter}"
445
+ param_counter += 1
446
+
447
+ # Deduplicate facts while preserving order
448
+ seen = set()
449
+ facts = [fact for fact in facts if not (fact in seen or seen.add(fact))]
450
+
451
+ return facts, class_to_param
452
+
453
+ def _validate_signature(self, hints: dict, params: MappingProxyType[str, inspect.Parameter]) -> None:
454
+ """Validate function signature requirements."""
455
+
456
+ # Validate return type
457
+ if "return" not in hints or hints["return"] is not self.return_type:
458
+ msg = f"Return type hint is required and must be {self.return_type!r}"
459
+ raise CallableSignatureError(msg)
460
+
461
+ # Track parameter types to check for duplicates
462
+ param_types = []
463
+
464
+ # Validate parameters
465
+ for param in params.values():
466
+ if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
467
+ msg = "Variable arguments (*args, **kwargs) are not supported"
468
+ raise CallableSignatureError(msg)
469
+
470
+ if param.name not in hints:
471
+ msg = "All parameters must have type hints"
472
+ raise CallableSignatureError(msg)
473
+
474
+ if param.name != "return":
475
+ param_type = hints[param.name]
476
+ if not isinstance(param_type, type) or not issubclass(param_type, Fact):
477
+ raise NotAFactError(param_type)
478
+ param_types.append(param_type)
479
+
480
+ # Check for duplicate parameter types
481
+ seen_types = set()
482
+ for param_type in param_types:
483
+ if param_type in seen_types:
484
+ msg = f"Duplicate parameter type '{param_type.__name__}' is not allowed"
485
+ raise CallableSignatureError(msg)
486
+ seen_types.add(param_type)
487
+
488
+ def _transform_lambda(self, class_to_param: dict[str, str], caller_globals: dict[str, Any]) -> T:
489
+ # Transform and create new lambda
490
+ transformer = AttributeTransformer(class_to_param)
491
+ new_tree = transformer.visit(self.tree)
492
+ lambda_body = ast.unparse(new_tree.body[0].value)
493
+
494
+ # The AST unparsing creates a full lambda expression, but we only want its body. This handles edge cases where
495
+ # the transformed AST might generate different lambda syntax than the original source code, ensuring we only
496
+ # get the expression part.
497
+ if lambda_body.startswith("lambda"):
498
+ lambda_body = lambda_body[lambda_body.find(":") + 1 :].strip()
499
+
500
+ # Create a new lambda object with the transformed body
501
+ # TODO: Find a way to avoid using exec or eval here
502
+ lambda_code = f"lambda {', '.join(class_to_param.values())}: {lambda_body}"
503
+ new_func = eval(lambda_code, caller_globals) # noqa: S307 # nosec B307
504
+ new_func.__source__ = self.source
505
+
506
+ return new_func