vulcan-core 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vulcan-core might be problematic. Click here for more details.
- vulcan_core/__init__.py +45 -0
- vulcan_core/actions.py +28 -0
- vulcan_core/ast_utils.py +296 -0
- vulcan_core/conditions.py +302 -0
- vulcan_core/engine.py +232 -0
- vulcan_core/models.py +260 -0
- vulcan_core/util.py +127 -0
- vulcan_core-1.0.0.dist-info/LICENSE +176 -0
- vulcan_core-1.0.0.dist-info/METADATA +90 -0
- vulcan_core-1.0.0.dist-info/NOTICE +8 -0
- vulcan_core-1.0.0.dist-info/RECORD +12 -0
- vulcan_core-1.0.0.dist-info/WHEEL +4 -0
vulcan_core/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright 2025 Latchfield Technologies http://latchfield.com
|
|
3
|
+
|
|
4
|
+
from vulcan_core.actions import Action, action
|
|
5
|
+
from vulcan_core.ast_utils import (
|
|
6
|
+
ASTProcessingError,
|
|
7
|
+
CallableSignatureError,
|
|
8
|
+
ContractError,
|
|
9
|
+
NotAFactError,
|
|
10
|
+
ScopeAccessError,
|
|
11
|
+
)
|
|
12
|
+
from vulcan_core.conditions import (
|
|
13
|
+
CompoundCondition,
|
|
14
|
+
Condition,
|
|
15
|
+
MissingFactError,
|
|
16
|
+
OnFactChanged,
|
|
17
|
+
Operator,
|
|
18
|
+
condition,
|
|
19
|
+
)
|
|
20
|
+
from vulcan_core.engine import InternalStateError, RecursionLimitError, Rule, RuleEngine
|
|
21
|
+
from vulcan_core.models import ActionReturn, ChunkingStrategy, Fact, Similarity
|
|
22
|
+
|
|
23
|
+
# Public API of the package, re-exported from the submodules imported above (kept in sorted order).
__all__ = [
    "ASTProcessingError",
    "Action",
    "ActionReturn",
    "CallableSignatureError",
    "ChunkingStrategy",
    "CompoundCondition",
    "Condition",
    "ContractError",
    "Fact",
    "InternalStateError",
    "MissingFactError",
    "NotAFactError",
    "OnFactChanged",
    "Operator",
    "RecursionLimitError",
    "Rule",
    "RuleEngine",
    "ScopeAccessError",
    "Similarity",
    "action",
    "condition",
]
|
vulcan_core/actions.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright 2025 Latchfield Technologies http://latchfield.com
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from functools import partial
|
|
8
|
+
|
|
9
|
+
from vulcan_core.ast_utils import ASTProcessor
|
|
10
|
+
from vulcan_core.models import ActionCallable, ActionReturn, DeclaresFacts, Fact, FactHandler
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True, slots=True)
class Action(FactHandler[ActionCallable, ActionReturn], DeclaresFacts):
    """
    A deferred computation that produces the result of a rule.

    Invoking the action forwards the supplied Fact instances, unchanged, to the wrapped callable and returns whatever
    that callable produces.
    """

    def __call__(self, *args: Fact) -> ActionReturn:
        wrapped = self.func
        return wrapped(*args)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def action(value: ActionCallable | ActionReturn) -> Action:
    """
    Wrap a callable or a constant value as an Action.

    Callables (other than ``functools.partial`` objects) are statically analyzed by ASTProcessor to discover the Fact
    attributes they access. Anything else, including ``partial`` objects, is treated as a constant result and wrapped
    in a zero-argument callable that returns it unchanged, with no declared facts.
    """
    if isinstance(value, partial) or not callable(value):
        # Constant result: nothing to analyze, no facts consumed.
        return Action((), lambda: value)

    # ``action`` itself is passed so the processor can locate this function's caller on the stack.
    processed = ASTProcessor[ActionCallable](value, action, ActionReturn)
    return Action(processed.facts, processed.func)
|
vulcan_core/ast_utils.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright 2025 Latchfield Technologies http://latchfield.com
|
|
3
|
+
|
|
4
|
+
import ast
|
|
5
|
+
import inspect
|
|
6
|
+
import textwrap
|
|
7
|
+
from ast import Attribute, Module, Name, NodeTransformer, NodeVisitor
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
from types import MappingProxyType
|
|
12
|
+
from typing import Any, TypeAliasType, get_type_hints
|
|
13
|
+
|
|
14
|
+
from vulcan_core.models import Fact, HasSource
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ASTProcessingError(RuntimeError):
    """Signals an unexpected internal failure while analyzing a callable's syntax tree."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ContractError(Exception):
    """Common base class for errors raised when a callable breaks the contract required of conditions and actions."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ScopeAccessError(ContractError):
    """
    Raised when a callable accesses instances that were not passed as parameters, when a decorated function attempts
    to access class attributes instead of parameter instance attributes, or when nested attribute access is used.
    """
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class NotAFactError(ContractError):
    """Raised when a callable parameter, or an accessed attribute's class, is not a subclass of Fact."""

    def __init__(self, type_obj: object) -> None:
        """
        Args:
            type_obj: The offending annotation or class. This is usually a class, but it can be any object produced by
                ``typing.get_type_hints`` (e.g. a ``X | None`` union), some of which have no ``__name__``. Those fall
                back to their ``repr`` so that this error, not an AttributeError, reaches the caller.
        """
        name = getattr(type_obj, "__name__", None) or repr(type_obj)
        message = f"'{name}' is not a Fact subclass"
        super().__init__(message)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class CallableSignatureError(ContractError):
    """
    Raised when a decorated function is missing type hints or declares the wrong return type, when it is async or
    uses variable/duplicate-typed parameters, or when a lambda declares parameters.
    """
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class _AttributeVisitor(NodeVisitor):
    """Collect every ``name.attribute`` access in a syntax tree, in pre-order traversal order."""

    def __init__(self):
        # Each entry is a (base name, attribute name) pair.
        self.attributes = []

    def visit_Attribute(self, node):  # noqa: N802
        base = node.value
        if isinstance(base, Name):
            self.attributes.append((base.id, node.attr))
        # Keep descending so accesses nested inside sub-expressions are also recorded.
        self.generic_visit(node)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class _NestedAttributeVisitor(NodeVisitor):
    """Detect chained attribute access such as ``a.b.c`` anywhere in a syntax tree."""

    def __init__(self):
        self.has_nested = False

    def visit_Attribute(self, node):  # noqa: N802
        # Once a chain has been seen the flag stays set for the rest of the traversal.
        if not self.has_nested:
            self.has_nested = isinstance(node.value, Attribute)
        self.generic_visit(node)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class AttributeTransformer(NodeTransformer):
    """Rewrite ``ClassName.attr`` references into ``param.attr`` using a class-name to parameter-name mapping."""

    def __init__(self, class_to_param):
        self.class_to_param = class_to_param

    def visit_Attribute(self, node: Attribute):  # noqa: N802
        # Transform the children first so rewriting happens bottom-up.
        node = self.generic_visit(node)  # type: ignore

        base = node.value
        if not isinstance(base, Name) or base.id not in self.class_to_param:
            return node

        renamed_base = Name(id=self.class_to_param[base.id], ctx=base.ctx)
        return Attribute(value=renamed_base, attr=node.attr, ctx=node.ctx)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
|
|
86
|
+
class ASTProcessor[T: Callable]:
|
|
87
|
+
func: T
|
|
88
|
+
decorator: Callable
|
|
89
|
+
return_type: type | TypeAliasType
|
|
90
|
+
source: str = field(init=False)
|
|
91
|
+
tree: Module = field(init=False)
|
|
92
|
+
facts: tuple[str, ...] = field(init=False)
|
|
93
|
+
|
|
94
|
+
@cached_property
|
|
95
|
+
def is_lambda(self) -> bool:
|
|
96
|
+
return isinstance(self.func, type(lambda: None)) and self.func.__name__ == "<lambda>"
|
|
97
|
+
|
|
98
|
+
def __post_init__(self):
|
|
99
|
+
# Extract source code and parse AST
|
|
100
|
+
if isinstance(self.func, HasSource):
|
|
101
|
+
self.source = self.func.__source__
|
|
102
|
+
else:
|
|
103
|
+
try:
|
|
104
|
+
self.source = textwrap.dedent(inspect.getsource(self.func))
|
|
105
|
+
except OSError as e:
|
|
106
|
+
if str(e) == "could not get source code":
|
|
107
|
+
msg = "could not get source code. Try recursively deleting all __pycache__ folders in your project."
|
|
108
|
+
raise OSError(msg) from e
|
|
109
|
+
else:
|
|
110
|
+
raise
|
|
111
|
+
self.func.__source__ = self.source
|
|
112
|
+
|
|
113
|
+
self.source = self._extract_lambda_source() if self.is_lambda else self.source
|
|
114
|
+
self.tree = ast.parse(self.source)
|
|
115
|
+
|
|
116
|
+
# Peform basic AST checks and attribute discovery
|
|
117
|
+
self._validate_ast()
|
|
118
|
+
attributes = self._discover_attributes()
|
|
119
|
+
|
|
120
|
+
if self.is_lambda:
|
|
121
|
+
# Process attributes and create a transformed lambda
|
|
122
|
+
caller_globals = self._get_caller_globals()
|
|
123
|
+
facts, class_to_param = self._resolve_facts(attributes, caller_globals)
|
|
124
|
+
|
|
125
|
+
self.facts = tuple(facts)
|
|
126
|
+
self.func = self._transform_lambda(class_to_param, caller_globals)
|
|
127
|
+
|
|
128
|
+
else:
|
|
129
|
+
# Get function metadata and validate signature
|
|
130
|
+
hints = get_type_hints(self.func)
|
|
131
|
+
params = inspect.signature(self.func).parameters
|
|
132
|
+
self._validate_signature(hints, params)
|
|
133
|
+
|
|
134
|
+
# Process attributes
|
|
135
|
+
facts: list[str] = []
|
|
136
|
+
param_names = list(params)
|
|
137
|
+
|
|
138
|
+
# Create the list of accessed facts and verify they are in the correct scope
|
|
139
|
+
for class_name, attr in attributes:
|
|
140
|
+
if class_name not in param_names:
|
|
141
|
+
msg = f"Accessing class '{class_name}' not passed as parameter"
|
|
142
|
+
raise ScopeAccessError(msg)
|
|
143
|
+
facts.append(f"{hints[class_name].__name__}.{attr}")
|
|
144
|
+
|
|
145
|
+
self.facts = tuple(facts)
|
|
146
|
+
|
|
147
|
+
def _extract_lambda_source(self) -> str:
|
|
148
|
+
"""Extracts just the lambda expression from source code."""
|
|
149
|
+
lambda_start = self.source.find("lambda")
|
|
150
|
+
if lambda_start == -1: # pragma: no cover - internal AST error
|
|
151
|
+
msg = "Could not find lambda expression in source"
|
|
152
|
+
raise ASTProcessingError(msg)
|
|
153
|
+
|
|
154
|
+
# The source includes the entire line of code (e.g., assignment and condition() call)
|
|
155
|
+
# We need to parse parentheses to extract just the lambda expression, handling any
|
|
156
|
+
# nested parentheses in the lambda's body correctly
|
|
157
|
+
source = self.source[lambda_start:]
|
|
158
|
+
paren_level = 0
|
|
159
|
+
for i, char in enumerate(source):
|
|
160
|
+
if char == "(":
|
|
161
|
+
paren_level += 1
|
|
162
|
+
elif char == ")" and paren_level > 0:
|
|
163
|
+
paren_level -= 1
|
|
164
|
+
elif char == ")" and paren_level == 0:
|
|
165
|
+
return source[:i]
|
|
166
|
+
|
|
167
|
+
return source
|
|
168
|
+
|
|
169
|
+
def _get_caller_globals(self) -> dict[str, Any]:
|
|
170
|
+
"""Find the globals of the caller of the decorator in order to validate accessed types."""
|
|
171
|
+
try:
|
|
172
|
+
decorator_name = self.decorator.__name__
|
|
173
|
+
frame = inspect.currentframe()
|
|
174
|
+
while frame.f_code.co_name != decorator_name: # type: ignore
|
|
175
|
+
frame = frame.f_back # type: ignore
|
|
176
|
+
return frame.f_back.f_globals # type: ignore # noqa: TRY300
|
|
177
|
+
|
|
178
|
+
except AttributeError as err: # pragma: no cover - internal AST error
|
|
179
|
+
msg = f"Unable to locate caller ('{decorator_name}') globals"
|
|
180
|
+
raise ASTProcessingError(msg) from err
|
|
181
|
+
|
|
182
|
+
def _validate_ast(self) -> None:
|
|
183
|
+
# Check for nested attribute access
|
|
184
|
+
visitor = _NestedAttributeVisitor()
|
|
185
|
+
visitor.visit(self.tree)
|
|
186
|
+
if visitor.has_nested:
|
|
187
|
+
msg = "Nested attribute access (X.y.z) is not allowed"
|
|
188
|
+
raise ScopeAccessError(msg)
|
|
189
|
+
|
|
190
|
+
# Checks for async functions
|
|
191
|
+
if isinstance(self.tree.body[0], ast.AsyncFunctionDef):
|
|
192
|
+
msg = "Async functions are not supported"
|
|
193
|
+
raise CallableSignatureError(msg)
|
|
194
|
+
|
|
195
|
+
# Lambda-specific checks
|
|
196
|
+
if self.is_lambda:
|
|
197
|
+
if not isinstance(self.tree, ast.Module) or not isinstance(
|
|
198
|
+
self.tree.body[0], ast.Expr
|
|
199
|
+
): # pragma: no cover - internal AST error
|
|
200
|
+
msg = "Expected an expression in AST body"
|
|
201
|
+
raise ASTProcessingError(msg)
|
|
202
|
+
|
|
203
|
+
lambda_node = self.tree.body[0].value
|
|
204
|
+
if not isinstance(lambda_node, ast.Lambda): # pragma: no cover - internal AST error
|
|
205
|
+
msg = "Expected a lambda expression"
|
|
206
|
+
raise ASTProcessingError(msg)
|
|
207
|
+
|
|
208
|
+
if lambda_node.args.args:
|
|
209
|
+
msg = "Lambda expressions must not have parameters"
|
|
210
|
+
raise CallableSignatureError(msg)
|
|
211
|
+
|
|
212
|
+
def _discover_attributes(self) -> list[tuple[str, str]]:
|
|
213
|
+
"""Discover attribute accessed within the AST."""
|
|
214
|
+
visitor = _AttributeVisitor()
|
|
215
|
+
visitor.visit(self.tree)
|
|
216
|
+
return visitor.attributes
|
|
217
|
+
|
|
218
|
+
def _resolve_facts(self, attributes: list[tuple[str, str]], globals_dict: dict) -> tuple[list[str], dict[str, str]]:
|
|
219
|
+
"""Validate attribute accesses and return normalized fact strings."""
|
|
220
|
+
facts = []
|
|
221
|
+
class_to_param = {}
|
|
222
|
+
param_counter = 0
|
|
223
|
+
|
|
224
|
+
for class_name, attr in attributes:
|
|
225
|
+
# Verify name refers to a class type
|
|
226
|
+
if class_name not in globals_dict or not isinstance(globals_dict[class_name], type):
|
|
227
|
+
msg = f"Accessing undefined class '{class_name}'"
|
|
228
|
+
raise ScopeAccessError(msg)
|
|
229
|
+
|
|
230
|
+
# Verify it's a Fact subclass
|
|
231
|
+
class_obj = globals_dict[class_name]
|
|
232
|
+
if not issubclass(class_obj, Fact):
|
|
233
|
+
raise NotAFactError(class_obj)
|
|
234
|
+
|
|
235
|
+
facts.append(f"{class_name}.{attr}")
|
|
236
|
+
if class_name not in class_to_param:
|
|
237
|
+
class_to_param[class_name] = f"p{param_counter}"
|
|
238
|
+
param_counter += 1
|
|
239
|
+
|
|
240
|
+
# Deduplicate facts while preserving order
|
|
241
|
+
seen = set()
|
|
242
|
+
facts = [fact for fact in facts if not (fact in seen or seen.add(fact))]
|
|
243
|
+
|
|
244
|
+
return facts, class_to_param
|
|
245
|
+
|
|
246
|
+
def _validate_signature(self, hints: dict, params: MappingProxyType[str, inspect.Parameter]) -> None:
|
|
247
|
+
"""Validate function signature requirements."""
|
|
248
|
+
|
|
249
|
+
# Validate return type
|
|
250
|
+
if "return" not in hints or hints["return"] is not self.return_type:
|
|
251
|
+
msg = f"Return type hint is required and must be {self.return_type!r}"
|
|
252
|
+
raise CallableSignatureError(msg)
|
|
253
|
+
|
|
254
|
+
# Track parameter types to check for duplicates
|
|
255
|
+
param_types = []
|
|
256
|
+
|
|
257
|
+
# Validate parameters
|
|
258
|
+
for param in params.values():
|
|
259
|
+
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
|
|
260
|
+
msg = "Variable arguments (*args, **kwargs) are not supported"
|
|
261
|
+
raise CallableSignatureError(msg)
|
|
262
|
+
|
|
263
|
+
if param.name not in hints:
|
|
264
|
+
msg = "All parameters must have type hints"
|
|
265
|
+
raise CallableSignatureError(msg)
|
|
266
|
+
|
|
267
|
+
if param.name != "return":
|
|
268
|
+
param_type = hints[param.name]
|
|
269
|
+
if not isinstance(param_type, type) or not issubclass(param_type, Fact):
|
|
270
|
+
raise NotAFactError(param_type)
|
|
271
|
+
param_types.append(param_type)
|
|
272
|
+
|
|
273
|
+
# Check for duplicate parameter types
|
|
274
|
+
seen_types = set()
|
|
275
|
+
for param_type in param_types:
|
|
276
|
+
if param_type in seen_types:
|
|
277
|
+
msg = f"Duplicate parameter type '{param_type.__name__}' is not allowed"
|
|
278
|
+
raise CallableSignatureError(msg)
|
|
279
|
+
seen_types.add(param_type)
|
|
280
|
+
|
|
281
|
+
def _transform_lambda(self, class_to_param: dict[str, str], caller_globals: dict[str, Any]) -> T:
|
|
282
|
+
# Transform and create new lambda
|
|
283
|
+
transformer = AttributeTransformer(class_to_param)
|
|
284
|
+
new_tree = transformer.visit(self.tree)
|
|
285
|
+
lambda_body = ast.unparse(new_tree.body[0].value)
|
|
286
|
+
|
|
287
|
+
# The AST unparsing creates a full lambda expression, but we only want its body. This handles edge cases where
|
|
288
|
+
# the transformed AST might generate different lambda syntax. than the original source code, ensuring we only
|
|
289
|
+
# get the expression part.
|
|
290
|
+
if lambda_body.startswith("lambda"):
|
|
291
|
+
lambda_body = lambda_body[lambda_body.find(":") + 1 :].strip()
|
|
292
|
+
|
|
293
|
+
# TODO: Find a way to avoid using exec or eval here
|
|
294
|
+
lambda_code = f"lambda {', '.join(class_to_param.values())}: {lambda_body}"
|
|
295
|
+
new_func = eval(lambda_code, caller_globals) # noqa: S307 # nosec B307
|
|
296
|
+
return new_func
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright 2025 Latchfield Technologies http://latchfield.com
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import _string # type: ignore
|
|
7
|
+
import re
|
|
8
|
+
from abc import abstractmethod
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from enum import Enum, auto
|
|
11
|
+
from string import Formatter
|
|
12
|
+
from typing import TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
from langchain.prompts import ChatPromptTemplate
|
|
15
|
+
from langchain_openai import ChatOpenAI
|
|
16
|
+
from pydantic import BaseModel, Field
|
|
17
|
+
|
|
18
|
+
from vulcan_core.actions import ASTProcessor
|
|
19
|
+
from vulcan_core.models import ConditionCallable, DeclaresFacts, Fact, FactHandler, Similarity
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING: # pragma: no cover - not used at runtime
|
|
22
|
+
from langchain_core.language_models import BaseChatModel
|
|
23
|
+
from langchain_core.runnables import RunnableSerializable
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True, slots=True)
class Expression(DeclaresFacts):
    """
    Abstract base for deferred logical expressions.

    An expression records the Facts its logic depends on, so that when one of those Facts is updated only the
    affected expressions need to be re-evaluated. The ``&``, ``|`` and ``^`` operators combine two expressions into a
    new CompoundCondition, and ``~`` produces an inverted expression.
    """

    inverted: bool = field(kw_only=True, default=False)

    def _compound(self, other: Expression, operator: Operator) -> Expression:
        # Merge both fact lists, keeping first-seen order and dropping repeats.
        merged_facts = tuple(dict.fromkeys((*self.facts, *other.facts)))
        return CompoundCondition(merged_facts, self, operator, other)

    def __and__(self, other: Expression) -> Expression:
        return self._compound(other, Operator.AND)

    def __or__(self, other: Expression) -> Expression:
        return self._compound(other, Operator.OR)

    def __xor__(self, other: Expression) -> Expression:
        return self._compound(other, Operator.XOR)

    @abstractmethod
    def __call__(self, *args: Fact) -> bool: ...

    @abstractmethod
    def __invert__(self) -> Expression: ...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# TODO: Investigate cached condition and deadline parameters, useful for expensive calls like AI/DB conditions
@dataclass(frozen=True, slots=True)
class Condition(FactHandler[ConditionCallable, bool], Expression):
    """
    A Condition is a container that defers a logical expression evaluated against supplied Facts. The expression can
    be inverted using the `~` operator. Conditions also support the '&', '|', and '^' operators for combinatorial
    logic.

    Attributes:
        facts (tuple[str, ...]): A tuple of strings representing the facts/attributes this condition
            depends upon. Each string should be in the format "ClassName.attribute" without nesting.
        func (Callable[..., bool]): A callable that implements the actual condition logic. It should
            return a boolean value indicating whether the condition is satisfied.
        inverted (bool): Flag indicating whether the condition result should be inverted (keyword-only).
    """

    def __call__(self, *args: Fact) -> bool:
        result = self.func(*args)
        return not result if self.inverted else result

    def __invert__(self) -> Condition:
        # Always rebuilds a plain Condition from ``facts``/``func``; subclasses whose evaluation does not go through
        # ``func`` must override this to keep their own state.
        return Condition(self.facts, self.func, inverted=not self.inverted)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class Operator(Enum):
    """Logical operation used to join the two sides of a CompoundCondition."""

    # Explicit values match what ``auto()`` assigns (1, 2, 3).
    AND = 1
    OR = 2
    XOR = 3
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass(frozen=True, slots=True)
class CompoundCondition(Expression):
    """
    Represents a compound logical condition composed of two sub-expressions, an operator, and a negation flag. This
    class allows deferred evaluation of complex logical expressions built by combining simpler conditions with the
    logical operators `&`, `|`, and `^`.

    CompoundConditions are chain evaluated from left to right. For example, `a | b | c` is equivalent to
    `(a | b) | c`, but ordering can be overridden with parentheses: `a | (b | c)` is equivalent to `(a) | (b | c)`.

    `&` and `|` short-circuit like Python's `and`/`or`: the right-hand side is not evaluated when the left-hand result
    already decides the outcome. This avoids needless calls to expensive sub-conditions such as AI conditions. `^`
    always evaluates both sides.

    This class should not be used directly; prefer the logical operators.
    """

    left: Expression
    operator: Operator
    right: Expression

    # TODO: Add a compile method that generates a lambda function with AST for faster evaluation

    def _pick_args(self, expr: Expression, args) -> list[Fact]:
        """Returns the arg values passed to this CompoundCondition that are needed by the given expression."""
        # NOTE(review): pairs ``args`` positionally with ``self.facts``; assumes callers supply them aligned — confirm.
        return [arg for fact, arg in zip(self.facts, args, strict=False) if fact in expr.facts]

    def _evaluate(self, expr: Expression, args) -> bool:
        """Evaluate ``expr`` with only the arguments it needs."""
        return expr(*self._pick_args(expr, args))

    def __call__(self, *args: Fact) -> bool:
        """
        Evaluate the left side, then the right side if still needed, and combine them with the operator. If the
        CompoundCondition is negated, the result is inverted before being returned.

        Raises:
            NotImplementedError: If ``operator`` is not a supported Operator member.
        """
        left_result = self._evaluate(self.left, args)

        if self.operator == Operator.AND:
            result = left_result and self._evaluate(self.right, args)
        elif self.operator == Operator.OR:
            result = left_result or self._evaluate(self.right, args)
        elif self.operator == Operator.XOR:
            result = left_result ^ self._evaluate(self.right, args)
        else:
            msg = (
                f"Operator {self.operator} not implemented"  # pragma: no cover - Safety check for future enum additions
            )
            raise NotImplementedError(msg)

        return not result if self.inverted else result

    def __invert__(self) -> CompoundCondition:
        return CompoundCondition(self.facts, self.left, self.operator, self.right, inverted=not self.inverted)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class MissingFactError(Exception):
    """Raised when an AI condition does not reference any facts to use as context."""
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class AIDecisionError(Exception):
    """Raised when the model reports that it hit an error with the inquiry during evaluation."""
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
# TODO: Move this to models module?
# Structured response requested from the LLM for AI conditions. Field order and descriptions may be part of the schema
# sent to the model, so they are kept verbatim; documentation lives in comments rather than a class docstring for the
# same reason.
class BooleanDecision(BaseModel):
    # Free-text justification for the answer, or a description of the problem when ``error`` is set.
    rationale: str = Field(..., description="A short explanation for the decision or error.")
    # The boolean result returned by the AI condition (before any inversion).
    answer: bool = Field(..., description="The answer to the inquiry.")
    # When True, AICondition raises AIDecisionError with ``rationale`` instead of using ``answer``.
    error: bool = Field(..., description="'True' if any error was encountered with the inquiry and/or response.")
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class DeferredFormatter(Formatter):
    """
    Formatter that leaves ``Similarity`` lookups unresolved.

    Ordinary replacement fields are formatted normally. When walking a field's attribute/index chain reaches a
    ``Similarity`` object, the field is emitted as the literal ``{field_name}`` placeholder and the object is recorded
    in ``found_lookups`` so it can be resolved later.
    """

    def __init__(self):
        super().__init__()
        # Maps the full field name (e.g. "Class.attr") to the Similarity object found for it.
        self.found_lookups: dict[str, Similarity] = {}

    def get_field(self, field_name, args, kwargs):
        head, tail = _string.formatter_field_name_split(field_name)
        current = self.get_value(head, args, kwargs)

        for is_attribute, key in tail:
            current = getattr(current, key) if is_attribute else current[key]
            if not isinstance(current, Similarity):
                continue
            self.found_lookups[field_name] = current
            # Returning the brace-wrapped name keeps the placeholder in the formatted output.
            return "{" + field_name + "}", field_name

        return current, head
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class LiteralFormatter(Formatter):
    """Formatter that treats each replacement field name as a literal lookup key, with no attribute/index traversal."""

    def get_field(self, field_name, args, kwargs):
        value = self.get_value(field_name, args, kwargs)
        return value, field_name
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
@dataclass(frozen=True, slots=True)
class AICondition(Condition):
    """
    A Condition evaluated by an LLM chain instead of a Python callable.

    Attributes:
        chain: Runnable invoked with ``system_msg`` and ``inquiry`` inputs; expected to return a BooleanDecision.
        system_template: System prompt containing ``{ClassName.attr}`` placeholders for the referenced facts.
        inquiry_template: The user inquiry, also containing ``{ClassName.attr}`` references.
        func: Unused (always None); evaluation goes through ``chain``.
    """

    chain: RunnableSerializable
    system_template: str
    inquiry_template: str
    func: None = field(init=False, default=None)

    def __call__(self, *args: Fact) -> bool:
        """
        Fill the prompt with fact values, ask the model, and return its (optionally inverted) boolean answer.

        Raises:
            AIDecisionError: If the model reports an error with the inquiry.
        """
        # Use just the fact class names as format keys.
        # NOTE(review): ``args`` are paired positionally with the unique class names in first-appearance order —
        # confirm the engine supplies one instance per class in that order.
        keys = {key.split(".")[0]: key for key in self.facts}.keys()
        fact_kwargs = dict(zip(keys, args, strict=False))

        # Format everything except Similarity lookups, which are left as placeholders and recorded by the formatter
        formatter = DeferredFormatter()
        system_msg = formatter.vformat(self.system_template, [], fact_kwargs)
        rag_lookup = formatter.vformat(self.inquiry_template, [], fact_kwargs)
        rag_lookup = rag_lookup.translate(str.maketrans("{}", "<>"))

        # Resolve each deferred lookup and substitute it directly. Re-running a format pass over the whole message
        # would also re-parse the already inserted fact values, so any "{" or "}" in a value would raise.
        for f_name, lookup in formatter.found_lookups.items():
            system_msg = system_msg.replace(f"{{{f_name}}}", format(lookup[rag_lookup], ""))

        # Invoke the LLM and get the result
        inquiry = self.inquiry_template.translate(str.maketrans("{}", "<>"))
        result: BooleanDecision = self.chain.invoke({"system_msg": system_msg, "inquiry": inquiry})
        if result.error:
            raise AIDecisionError(result.rationale)

        return not result.answer if self.inverted else result.answer

    def __invert__(self) -> AICondition:
        # The inherited ``Condition.__invert__`` rebuilds a plain Condition from ``func``, which is always None here,
        # so calling the inverted result would fail. Rebuild this class instead, keeping the chain and templates.
        return type(self)(
            chain=self.chain,
            system_template=self.system_template,
            inquiry_template=self.inquiry_template,
            facts=self.facts,
            inverted=not self.inverted,
        )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# TODO: Investigate how best to register tools for specific conditions
def ai_condition(model: BaseChatModel, inquiry: str) -> AICondition:
    """
    Build an AICondition that asks ``model`` to answer ``inquiry`` with a structured boolean decision.

    Facts are referenced in the inquiry with braces, e.g. ``"Is {User.bio} written in English?"``. Each distinct
    reference becomes a declared fact (in first-appearance order) and a tagged section of the system prompt.

    Args:
        model: Chat model used to evaluate the condition; wrapped with ``with_structured_output(BooleanDecision)``.
        inquiry: Natural-language question containing ``{ClassName.attribute}`` fact references.

    Returns:
        AICondition: The condition wrapping the prompt/model chain.

    Raises:
        MissingFactError: If the inquiry references no facts.
    """
    # TODO: Optimize by precompiling regex and storing translation table globally
    # Find referenced facts, dropping repeats while preserving order so each fact is declared and rendered once.
    facts = tuple(dict.fromkeys(re.findall(r"\{([^}]+)\}", inquiry)))

    # TODO: Determine if this should be kept, especially with LLMs calling tools
    if not facts:
        msg = "An AI condition requires at least one referenced fact."
        raise MissingFactError(msg)

    # TODO: Move these rules to a validation rule set for ai conditions
    system = "<instructions>\nAnswer the <inquiry> by referencing the following information tags:\n\n"

    for fact in facts:
        # The section body is a ``{fact}`` placeholder that AICondition fills in at evaluation time.
        system += f"<{fact}>\n{{{fact}}}\n</{fact}>\n\n"
    system += "</instructions>"

    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", "{system_msg}"),
            ("user", "<inquiry>{inquiry}</inquiry>"),
        ]
    )
    structured_model = model.with_structured_output(BooleanDecision)
    chain = prompt_template | structured_model
    return AICondition(chain=chain, system_template=system, inquiry_template=inquiry, facts=facts)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
# Default chat model for string (AI) conditions. It is created lazily by ``_get_default_model`` on first use, because
# constructing the OpenAI client at import time may require credentials even for users who never use AI conditions.
# Assign any chat model to this name to override the default.
default_model: BaseChatModel | None = None


def _get_default_model() -> BaseChatModel:
    """Return ``default_model``, creating the built-in OpenAI model on first use."""
    global default_model  # noqa: PLW0603
    if default_model is None:
        default_model = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=100)  # type: ignore[call-arg] - pyright can't see the args for some reason
    return default_model


def condition(func: ConditionCallable | str) -> Condition:
    """
    Creates a Condition object from a lambda or function, or an AI condition from a string inquiry. For callables it
    performs limited static analysis of the code to ensure proper usage and discover the facts/attributes accessed by
    the condition. This allows the rule engine to track dependencies between conditions and facts with minimal
    boilerplate code.

    Lambda usage requires Fact access via static class attributes (e.g., User.age). Functions, by contrast, are not
    allowed to access class attributes statically and must only access attributes via parameter instances. Neither
    lambdas nor functions are allowed to access instances outside of their scope.

    Args:
        func (Callable[..., bool] | str): A lambda or function that returns a boolean value, or a natural-language
            inquiry referencing facts as ``{ClassName.attribute}``. For regular functions, parameters must be
            type-hinted with Fact subclasses. For lambdas, no parameters are allowed. Strings are evaluated with
            ``default_model``, which is created on first use if not already set.

    Returns:
        Condition: A Condition object containing:
            - facts: A tuple of fact identifiers in the form "FactClass.attribute"
            - func: The transformed callable that will evaluate the condition

    Raises:
        - ASTProcessingError: If unable to retrieve caller globals or process the AST.
        - CallableSignatureError: If async functions are used or signature validation fails
        - ScopeAccessError: If attributes are accessed from classes not passed as parameters
        - NotAFactError: If a parameter type or accessed class is not a Fact subclass
        - MissingFactError: If a string inquiry references no facts

    Example:
        # Will be transformed to accept instances of User:
        is_user_adult = condition(lambda: User.age >= User.max_age)

        # As with the lambda, decorated functions will be analyzed for which Facts attributes are accessed:
        @condition
        def is_user_adult(user: User) -> bool:
            return user.age >= user.max_age

    Notes:
        - Async functions are not supported
        - Nested attribute access (e.g., a.b.c) is not allowed
    """

    if isinstance(func, str):
        return ai_condition(_get_default_model(), func)

    # ``condition`` itself is passed so the processor can locate this function's caller on the stack.
    processed = ASTProcessor[ConditionCallable](func, condition, bool)
    return Condition(processed.facts, processed.func)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
# TODO: Create a convenience function for creating OnFactChanged conditions
@dataclass(frozen=True, slots=True)
class OnFactChanged(Condition):
    """
    A Condition that evaluates to True unconditionally. It triggers a rule whenever one of its declared Facts is
    updated, which is useful for rules that simply update a Fact in response to a change in another.
    """

    def __call__(self, *args: Fact) -> bool:
        # The supplied facts, ``func`` and ``inverted`` are all ignored.
        del args
        return True
|