vulcan-core 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vulcan_core/__init__.py +45 -0
- vulcan_core/actions.py +31 -0
- vulcan_core/ast_utils.py +506 -0
- vulcan_core/conditions.py +432 -0
- vulcan_core/engine.py +287 -0
- vulcan_core/models.py +271 -0
- vulcan_core/reporting.py +595 -0
- vulcan_core/util.py +127 -0
- vulcan_core-1.2.1.dist-info/METADATA +88 -0
- vulcan_core-1.2.1.dist-info/RECORD +11 -0
- vulcan_core-1.2.1.dist-info/WHEEL +4 -0
vulcan_core/engine.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright 2025 Latchfield Technologies http://latchfield.com
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from functools import cached_property, partial
|
|
9
|
+
from types import MappingProxyType
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
11
|
+
from uuid import UUID, uuid4
|
|
12
|
+
|
|
13
|
+
from vulcan_core.ast_utils import NotAFactError
|
|
14
|
+
from vulcan_core.models import DeclaresFacts, Fact
|
|
15
|
+
from vulcan_core.reporting import Auditor
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING: # pragma: no cover - not used at runtime
|
|
18
|
+
from collections.abc import Mapping
|
|
19
|
+
|
|
20
|
+
from vulcan_core.actions import Action
|
|
21
|
+
from vulcan_core.conditions import Expression
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class InternalStateError(RuntimeError):
    """Signals that the RuleEngine's working memory could not be brought into a valid state."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RecursionLimitError(RuntimeError):
    """Signals that cascading rule evaluation exceeded the engine's configured iteration limit."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(frozen=True)
class Rule:
    """
    Immutable pairing of a condition with the action to run for each of its outcomes.

    Attributes:
        - id (UUID): Unique identifier generated with `uuid4` at construction time; it is not a constructor argument.
        - name (str | None): Optional human-readable name of the rule.
        - when (Expression): The condition that decides which action (if any) runs.
        - then (Action): The action run when the condition is satisfied.
        - inverse (Action | None): Optional action run when the condition is not satisfied.
    """

    # Excluded from __init__ so every rule receives a fresh identifier
    id: UUID = field(init=False, default_factory=uuid4)
    name: str | None
    when: Expression
    then: Action
    inverse: Action | None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# TODO: Look into support for langchain operators and lang graph integration
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass(kw_only=True)
|
|
58
|
+
class RuleEngine:
|
|
59
|
+
"""
|
|
60
|
+
RuleEngine is a class that manages the evaluation of rules based on a set of facts. It allows for the addition of rules,
|
|
61
|
+
updating of facts, and cascading evaluation of rules.
|
|
62
|
+
|
|
63
|
+
Attributes:
|
|
64
|
+
enabled (bool): Indicates whether the rule engine is enabled.
|
|
65
|
+
recusion_limit (int): The maximum number of recursive evaluations allowed.
|
|
66
|
+
facts (dict[type[Fact], Fact]): A dictionary to store facts with their types as keys.
|
|
67
|
+
rules (dict[str, list[Rule]]): A dictionary to store rules associated with fact strings.
|
|
68
|
+
|
|
69
|
+
Methods:
|
|
70
|
+
rule(self, *, name: str | None = None, when: LogicEvaluator, then: BaseAction, inverse: BaseAction | None = None): Adds a rule to the rule engine.
|
|
71
|
+
update_facts(self, fact: tuple[Fact | partial[Fact], ...] | partial[Fact] | Fact) -> Iterator[str]: Updates the facts in the working memory.
|
|
72
|
+
evaluate(self, trace: bool = False): Evaluates the rules based on the current facts in working memory.
|
|
73
|
+
yaml_report(self): Returns the YAML report of the last evaluation (if tracing was enabled).
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
enabled: bool = False
|
|
77
|
+
recusion_limit: int = 10
|
|
78
|
+
_facts: dict[str, Fact] = field(default_factory=dict, init=False)
|
|
79
|
+
_rules: dict[str, list[Rule]] = field(default_factory=dict, init=False)
|
|
80
|
+
_audit: Auditor = field(default_factory=Auditor, init=False)
|
|
81
|
+
|
|
82
|
+
@cached_property
|
|
83
|
+
def facts(self) -> Mapping[str, Fact]:
|
|
84
|
+
return MappingProxyType(self._facts)
|
|
85
|
+
|
|
86
|
+
@cached_property
|
|
87
|
+
def rules(self) -> Mapping[str, list[Rule]]:
|
|
88
|
+
return MappingProxyType(self._rules)
|
|
89
|
+
|
|
90
|
+
def __getitem__[T: Fact](self, key: type[T]) -> T:
|
|
91
|
+
"""
|
|
92
|
+
Retrieves a fact from the working memory.
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
key (type[Fact]): The type of the fact to retrieve.
|
|
96
|
+
|
|
97
|
+
Returns:
|
|
98
|
+
T: The fact instance of the specified type.
|
|
99
|
+
"""
|
|
100
|
+
return self._facts[key.__name__] # type: ignore
|
|
101
|
+
|
|
102
|
+
def fact(self, fact: Fact | partial[Fact]):
|
|
103
|
+
"""
|
|
104
|
+
Updates the working memory with a new fact or merges a partial fact.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
fact (Union[Fact, partial[Fact]]): The fact instance or partial fact to update the working memory with.
|
|
108
|
+
|
|
109
|
+
Raises:
|
|
110
|
+
InternalStateError: If a partial fact cannot be instantiated due to missing required fields
|
|
111
|
+
"""
|
|
112
|
+
# TODO: Figure out how to track only fact attributes that have changed, and fire on affected rules
|
|
113
|
+
|
|
114
|
+
if isinstance(fact, partial):
|
|
115
|
+
fact_name = fact.func.__name__
|
|
116
|
+
fact_class = fact.func
|
|
117
|
+
if not issubclass(fact_class, Fact): # type: ignore
|
|
118
|
+
raise NotAFactError(fact_class)
|
|
119
|
+
|
|
120
|
+
if fact_name in self._facts:
|
|
121
|
+
self._facts[fact_name] |= fact
|
|
122
|
+
else:
|
|
123
|
+
try:
|
|
124
|
+
self._facts[fact_name] = fact()
|
|
125
|
+
except TypeError as err:
|
|
126
|
+
msg = f"Fact '{fact_name}' is missing and lacks sufficient defaults to create from partial: {fact}"
|
|
127
|
+
raise InternalStateError(msg) from err
|
|
128
|
+
else:
|
|
129
|
+
fact_class = type(fact)
|
|
130
|
+
if not issubclass(fact_class, Fact):
|
|
131
|
+
raise NotAFactError(fact_class)
|
|
132
|
+
|
|
133
|
+
self._facts[type(fact).__name__] = fact
|
|
134
|
+
|
|
135
|
+
def rule[T: Fact](
|
|
136
|
+
self, *, name: str | None = None, when: Expression, then: Action, inverse: Action | None = None
|
|
137
|
+
) -> None:
|
|
138
|
+
"""
|
|
139
|
+
Convenience method for adding a rule to the rule engine.
|
|
140
|
+
|
|
141
|
+
Args:
|
|
142
|
+
name (Optional[str]): The name of the rule. Defaults to None.
|
|
143
|
+
when (Expression): The condition that triggers the rule.
|
|
144
|
+
then (Action): The action to be executed when the condition is met.
|
|
145
|
+
inverse (Optional[Action]): The action to be executed when the condition is not met. Defaults to None.
|
|
146
|
+
|
|
147
|
+
Returns:
|
|
148
|
+
None
|
|
149
|
+
"""
|
|
150
|
+
rule = Rule(name, when, then, inverse)
|
|
151
|
+
|
|
152
|
+
# TODO: Add automatic inverse option?
|
|
153
|
+
|
|
154
|
+
# Update the facts to rule mapping
|
|
155
|
+
for fact_str in when.facts:
|
|
156
|
+
if fact_str in self._rules:
|
|
157
|
+
self._rules[fact_str].append(rule)
|
|
158
|
+
else:
|
|
159
|
+
self._rules[fact_str] = [rule]
|
|
160
|
+
|
|
161
|
+
def _update_facts(self, fact: tuple[Fact | partial[Fact], ...] | partial[Fact] | Fact) -> list[str]:
|
|
162
|
+
"""
|
|
163
|
+
Updates the fact in the facts dictionary. If the provided fact is an instance of Fact, it updates the dictionary
|
|
164
|
+
with the type of the fact as the key. If the provided fact is a partial function, it updates the dictionary with
|
|
165
|
+
the function of the partial as the key.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
fact (tuple[Fact | partial[Fact], ...] | partial[Fact] | Fact): The fact(s) to be updated, either as an instance of Fact, a partial function, or a tuple of either.
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
Iterator[str]: An iterator over the fact strings of the updated facts.
|
|
172
|
+
"""
|
|
173
|
+
facts = fact if isinstance(fact, tuple) else (fact,)
|
|
174
|
+
updated = []
|
|
175
|
+
|
|
176
|
+
for f in facts:
|
|
177
|
+
self.fact(f)
|
|
178
|
+
|
|
179
|
+
# Track which attributes were updated
|
|
180
|
+
if isinstance(f, partial):
|
|
181
|
+
fact_name = f.func.__name__
|
|
182
|
+
attrs = f.keywords
|
|
183
|
+
else:
|
|
184
|
+
fact_name = f.__class__.__name__
|
|
185
|
+
attrs = vars(f)
|
|
186
|
+
|
|
187
|
+
updated.extend([f"{fact_name}.{attr}" for attr in attrs])
|
|
188
|
+
|
|
189
|
+
return updated
|
|
190
|
+
|
|
191
|
+
def _resolve_facts(self, declared: DeclaresFacts, facts: dict[str, Fact]) -> list[Fact]:
|
|
192
|
+
# Deduplicate the fact strings and retrieve unique fact instances
|
|
193
|
+
keys = {key.split(".")[0]: key for key in declared.facts}.values()
|
|
194
|
+
return [facts[key.split(".")[0]] for key in keys]
|
|
195
|
+
|
|
196
|
+
def evaluate(self, fact: Fact | partial[Fact] | None = None, *, audit: bool = False):
|
|
197
|
+
"""
|
|
198
|
+
Cascading evaluation of rules based on the facts in working memory.
|
|
199
|
+
|
|
200
|
+
If provided a fact, will update and evaluate immediately. Otherwise all rules will be evaluated.
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
fact: Optional fact to update and evaluate immediately
|
|
204
|
+
audit: Enables tracing for explanbility report generation
|
|
205
|
+
"""
|
|
206
|
+
evaluated_rules: set[UUID] = set()
|
|
207
|
+
consequence: set[str] = set()
|
|
208
|
+
|
|
209
|
+
# TODO: Create an internal consistency check to determine if all referenced Facts are present?
|
|
210
|
+
|
|
211
|
+
# TODO: detect cycles in graph before executing
|
|
212
|
+
# Move to a separate lifecycle step?
|
|
213
|
+
# Provide option for handling
|
|
214
|
+
|
|
215
|
+
# TODO: Check whether fact attributes have actually changed, and only fire rules that are affected
|
|
216
|
+
if fact:
|
|
217
|
+
scope = self._update_facts(fact)
|
|
218
|
+
else:
|
|
219
|
+
# By default, evaluate all facts
|
|
220
|
+
fact_list = self._facts.values()
|
|
221
|
+
scope = {f"{fact.__class__.__name__}.{attr}" for fact in fact_list for attr in vars(fact)}
|
|
222
|
+
|
|
223
|
+
if audit:
|
|
224
|
+
self._audit.evaluation_reset()
|
|
225
|
+
|
|
226
|
+
# Iterate over the rules until the recusion limit is reached or no new rules are fired
|
|
227
|
+
for iteration in range(self.recusion_limit + 1):
|
|
228
|
+
if iteration == self.recusion_limit:
|
|
229
|
+
msg = f"Recursion limit of {self.recusion_limit} reached"
|
|
230
|
+
raise RecursionLimitError(msg)
|
|
231
|
+
|
|
232
|
+
# Ensure that rules do not interfere with one another in the same iteration
|
|
233
|
+
facts_snapshot = self._facts.copy()
|
|
234
|
+
|
|
235
|
+
if audit:
|
|
236
|
+
self._audit.iteration_start()
|
|
237
|
+
|
|
238
|
+
# Evaluate matching rules
|
|
239
|
+
for fact_str, rules in self._rules.items():
|
|
240
|
+
if fact_str in scope:
|
|
241
|
+
for rule in rules:
|
|
242
|
+
# Skip the rule if it was already evaluated in this iteration (due to matching on another Fact)
|
|
243
|
+
if rule.id in evaluated_rules:
|
|
244
|
+
continue
|
|
245
|
+
evaluated_rules.add(rule.id)
|
|
246
|
+
|
|
247
|
+
# Skip if not all facts required by the rule are present
|
|
248
|
+
try:
|
|
249
|
+
resolved_facts = self._resolve_facts(rule.when, facts_snapshot)
|
|
250
|
+
except KeyError as e:
|
|
251
|
+
logger.debug("Rule %s (%s) skipped due to missing fact: %s", rule.name, rule.id, str(e))
|
|
252
|
+
continue
|
|
253
|
+
|
|
254
|
+
if audit:
|
|
255
|
+
self._audit.rule_start()
|
|
256
|
+
|
|
257
|
+
# Evaluate the rule and prepare the aciton
|
|
258
|
+
action = None
|
|
259
|
+
condition_result = rule.when(*resolved_facts)
|
|
260
|
+
if condition_result:
|
|
261
|
+
action = rule.then
|
|
262
|
+
elif rule.inverse:
|
|
263
|
+
action = rule.inverse
|
|
264
|
+
|
|
265
|
+
# Evaluate the action and update the consequences
|
|
266
|
+
action_result = None
|
|
267
|
+
if action:
|
|
268
|
+
action_result = action(*self._resolve_facts(action, facts_snapshot))
|
|
269
|
+
facts = self._update_facts(action_result)
|
|
270
|
+
consequence.update(facts)
|
|
271
|
+
|
|
272
|
+
if audit:
|
|
273
|
+
self._audit.rule_end(rule, action_result, facts_snapshot, condition_result=condition_result)
|
|
274
|
+
|
|
275
|
+
if audit:
|
|
276
|
+
self._audit.iteration_end()
|
|
277
|
+
|
|
278
|
+
# Check for next iteration
|
|
279
|
+
if consequence:
|
|
280
|
+
scope = consequence
|
|
281
|
+
consequence = set()
|
|
282
|
+
evaluated_rules.clear()
|
|
283
|
+
else:
|
|
284
|
+
break
|
|
285
|
+
|
|
286
|
+
def yaml_report(self) -> str:
|
|
287
|
+
return self._audit.generate_yaml_report()
|
vulcan_core/models.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# Copyright 2025 Latchfield Technologies http://latchfield.com
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from collections.abc import Callable, Iterator, Mapping
|
|
8
|
+
from copy import copy
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from enum import StrEnum, auto
|
|
11
|
+
from typing import (
|
|
12
|
+
TYPE_CHECKING,
|
|
13
|
+
Any,
|
|
14
|
+
Protocol,
|
|
15
|
+
Self,
|
|
16
|
+
dataclass_transform,
|
|
17
|
+
runtime_checkable,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from langchain.schema import Document
|
|
21
|
+
|
|
22
|
+
from vulcan_core.util import is_private
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING: # pragma: no cover - not used at runtime
|
|
25
|
+
from functools import partial
|
|
26
|
+
|
|
27
|
+
from langchain_core.vectorstores import VectorStoreRetriever
|
|
28
|
+
|
|
29
|
+
type ActionReturn = tuple[partial[Fact] | Fact, ...] | partial[Fact] | Fact
|
|
30
|
+
type ActionCallable = Callable[..., ActionReturn]
|
|
31
|
+
type ConditionCallable = Callable[..., bool | None]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# TODO: Consolidate with AttrDict, and/or figure out how to extend from Mapping
class ImmutableAttrAsDict:
    """
    ImmutableAttrAsDict is an abstract base class that provides read-only, dictionary-like access to its attributes.

    Keys are the public attribute names annotated on the class and on its base classes (base classes first, in
    declaration order); values are read with `getattr`. Private names (per `is_private`) are never exposed.
    """

    def __getitem__(self, key: str) -> Any:
        """Returns the attribute named `key`, deferring to `__missing__` (if defined) for invalid keys."""
        try:
            return getattr(self, self.validate(key))
        except KeyError:
            if hasattr(self, "__missing__"):
                return self.__missing__(key)  # type: ignore
            else:
                raise

    def __contains__(self, key: str) -> bool:
        """Returns True if `key` is a public annotated attribute that is currently set."""
        try:
            return hasattr(self, self.validate(key))
        except KeyError:
            # Unknown and private keys are simply not members, as with a regular dict, rather than an error
            return False

    def __iter__(self) -> Iterator[str]:
        return iter(self._public_keys())

    def __len__(self) -> int:
        return len(self._public_keys())

    def _public_keys(self) -> list[str]:
        """Returns the public annotated attribute names of this object's class and its bases, base classes first."""
        names: dict[str, None] = {}
        # Walk the MRO so attributes annotated on base classes are included; a class's own `__annotations__` only
        # holds the names it declares itself.
        for klass in reversed(type(self).__mro__):
            for name in getattr(klass, "__annotations__", {}):
                if not is_private(name):
                    names[name] = None
        return list(names)

    def validate(self, key: str) -> str:
        """
        Returns `key` unchanged if it names a public annotated attribute.

        Raises:
            KeyError: If `key` is private or is not an annotated attribute.
        """
        if is_private(key):
            msg = f"Access denied to private attribute: {key}"
            raise KeyError(msg)

        if key not in self._public_keys():
            raise KeyError(key)

        return key

    def __init__(self):
        if type(self) is ImmutableAttrAsDict:
            msg = f"{ImmutableAttrAsDict.__name__} is an abstract class that can not be directly instantiated."
            raise TypeError(msg)

    def __reversed__(self) -> Iterator[str]:
        return reversed(self._public_keys())

    def __or__(self, other: dict) -> dict:
        """Returns a new dict of this object's items updated with `other`."""
        return dict(self) | other

    def keys(self) -> list[str]:
        return self._public_keys()

    def values(self) -> list[Any]:
        return [getattr(self, key) for key in self._public_keys()]

    def items(self) -> list[tuple[str, Any]]:
        return [(key, getattr(self, key)) for key in self._public_keys()]

    def get(self, key: str, default: Any = None):
        """Returns the attribute named `key`, or `default` if it is unset, unknown, or private (like `dict.get`)."""
        try:
            return getattr(self, self.validate(key), default)
        except KeyError:
            return default
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass_transform(kw_only_default=True, frozen_default=True)
|
|
93
|
+
class FactMetaclass(type):
|
|
94
|
+
"""
|
|
95
|
+
FactMetaclass is a metaclass that modifies the creation of new classes to automatically
|
|
96
|
+
apply the `dataclass` decorator with `kw_only=True` and `frozen=True` options.
|
|
97
|
+
"""
|
|
98
|
+
|
|
99
|
+
def __new__(cls, name: str, bases: tuple[type], class_dict: dict[str, Any], **kwargs: Any):
|
|
100
|
+
self = super().__new__(cls, name, bases, class_dict, **kwargs)
|
|
101
|
+
return dataclass(kw_only=True, frozen=True)(self)
|
|
102
|
+
|
|
103
|
+
def _is_dataclass_instance(cls) -> bool:
|
|
104
|
+
"""Determine if this is a dataclass instance by looking for __dataclass_fields__"""
|
|
105
|
+
return "__dataclass_fields__" not in super().__getattribute__("__dict__")
|
|
106
|
+
|
|
107
|
+
# TODO: Implement a context manager to allow access to the default class values
|
|
108
|
+
# BUG: This causes pylance to not report missing attributes, we need a different way to handle f strings... maybe
|
|
109
|
+
# the __format__ method?
|
|
110
|
+
def __getattribute__(cls, name):
|
|
111
|
+
"""
|
|
112
|
+
Returns a {templated} representation of the Fact's public attributes for deferred use in fstrings. This is
|
|
113
|
+
useful in rule clauses so that IDE autocomplete can be used in fstrings while deferring evaluation of
|
|
114
|
+
the content."""
|
|
115
|
+
if name.startswith("_") or cls._is_dataclass_instance():
|
|
116
|
+
return super().__getattribute__(name)
|
|
117
|
+
else:
|
|
118
|
+
return f"{{{cls.__name__}.{name}}}"
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class Fact(ImmutableAttrAsDict, metaclass=FactMetaclass):
    """
    An abstract class that must be used to define rule engine fact schemas and instantiate data into working memory. Facts
    may be combined with partial facts of the same type using the `|` operator. This is useful for Actions that only
    need to update a portion of working memory.

    Example: `new_fact = Inventory(apples=1) | partial(Inventory, oranges=2)`
    """

    def __or__(self, other: partial[Self] | Self) -> Self:
        """
        If the right hand operand is a Fact, it is returned as-is. However, if it is a partial Fact, a copy of the
        lefthand operand is created with the partial Fact's keywords applied.

        Any other operand type returns NotImplemented, so Python raises its standard `TypeError` for `|`.

        Raises:
            TypeError: If the operands are different Fact types, or the partial has positional arguments or keywords
                that are not fields of this Fact.
        """
        # Imported locally: at module level `partial` is only imported under TYPE_CHECKING
        from dataclasses import fields
        from functools import partial

        if isinstance(other, Fact):
            if type(self) is not type(other):
                msg = f"Union operator disallowed for types {type(self).__name__} and {type(other).__name__}"
                raise TypeError(msg)

            return other  # type: ignore

        if not isinstance(other, partial):
            # Previously this fell through to `other.func` and failed with an AttributeError
            return NotImplemented

        if type(self) is not other.func:
            msg = f"Union operator disallowed for types {type(self).__name__} and {other.func}"
            raise TypeError(msg)

        # Facts are keyword-only, so positional arguments cannot be mapped onto fields (they were silently dropped)
        if other.args:
            msg = f"Partial {type(self).__name__} must not have positional arguments: {other.args}"
            raise TypeError(msg)

        # Reject keywords that are not fields, mirroring the TypeError the constructor raises for unknown arguments,
        # instead of silently attaching stray attributes to the copy
        field_names = {f.name for f in fields(self)}
        unknown = [kw for kw in other.keywords if kw not in field_names]
        if unknown:
            msg = f"{type(self).__name__} has no field(s) named: {', '.join(unknown)}"
            raise TypeError(msg)

        new_fact = copy(self)
        for kw, value in other.keywords.items():
            # The dataclass is frozen, so bypass its __setattr__ on the fresh copy
            object.__setattr__(new_fact, kw, value)
        return new_fact  # type: ignore
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@dataclass(frozen=True)
class DeclaresFacts(ABC):
    """Base for objects that declare the facts they reference, as "FactName.attribute" strings."""

    # TODO: Differentiate between facts consumed vs. produced for better tracking/diagnostics.
    #       This will probably be needed to detect cycles in the graph.
    facts: tuple[str, ...]
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@dataclass(frozen=True)
|
|
160
|
+
class FactHandler[T: Callable, R: Any](ABC):
|
|
161
|
+
func: T
|
|
162
|
+
|
|
163
|
+
@abstractmethod
|
|
164
|
+
def _evaluate(self, *args: Fact) -> R: ...
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@runtime_checkable
class HasSource(Protocol):
    """Structural type matching any object that carries a `__source__` string attribute."""

    __source__: str
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class ChunkingStrategy(StrEnum):
|
|
173
|
+
SENTENCE = auto()
|
|
174
|
+
PARAGRAPH = auto()
|
|
175
|
+
PAGE = auto()
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@dataclass(kw_only=True, slots=True)
|
|
179
|
+
class Similarity(Mapping[str, list[tuple[str, float]]]):
|
|
180
|
+
# TODO: Figure out how to cache vectors / and results?
|
|
181
|
+
|
|
182
|
+
@abstractmethod
|
|
183
|
+
def __getitem__(self, key: str) -> list[str]:
|
|
184
|
+
"""Vectorizes key and performs similarity search returning a list of matching."""
|
|
185
|
+
raise NotImplementedError
|
|
186
|
+
|
|
187
|
+
@abstractmethod
|
|
188
|
+
def __contains__(self, key: str) -> bool:
|
|
189
|
+
"""Vectorizes key and performs similarity search returning a boolean if there is at least one match."""
|
|
190
|
+
raise NotImplementedError
|
|
191
|
+
|
|
192
|
+
@abstractmethod
|
|
193
|
+
def __iadd__(self, value: str) -> Self:
|
|
194
|
+
raise NotImplementedError
|
|
195
|
+
|
|
196
|
+
@abstractmethod
|
|
197
|
+
def __iter__(self) -> str:
|
|
198
|
+
raise NotImplementedError
|
|
199
|
+
|
|
200
|
+
@abstractmethod
|
|
201
|
+
def __len__(self) -> int:
|
|
202
|
+
raise NotImplementedError
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class ProxyInitializationError(Exception):
|
|
206
|
+
"""Raised when a Proxy class is used without the proxy being initialized."""
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@dataclass(kw_only=True, slots=True)
|
|
210
|
+
class ProxyLazyLookup(Similarity):
|
|
211
|
+
_proxy: Similarity | None = None
|
|
212
|
+
|
|
213
|
+
@property
|
|
214
|
+
def proxy(self) -> Similarity:
|
|
215
|
+
if self._proxy:
|
|
216
|
+
return self
|
|
217
|
+
else:
|
|
218
|
+
msg = "The `proxy` attribute must be set before the class instance can be used."
|
|
219
|
+
raise ProxyInitializationError(msg)
|
|
220
|
+
|
|
221
|
+
@proxy.setter
|
|
222
|
+
def proxy(self, value: Similarity) -> None:
|
|
223
|
+
if not self._proxy:
|
|
224
|
+
self._proxy = value
|
|
225
|
+
else:
|
|
226
|
+
msg = "The `proxy` attribute can only be initialized once."
|
|
227
|
+
raise ProxyInitializationError(msg)
|
|
228
|
+
|
|
229
|
+
def __getitem__(self, key: str) -> list[str]:
|
|
230
|
+
return self.proxy[key]
|
|
231
|
+
|
|
232
|
+
def __contains__(self, key: str) -> bool:
|
|
233
|
+
raise NotImplementedError
|
|
234
|
+
|
|
235
|
+
def __iadd__(self, value: str) -> Self:
|
|
236
|
+
self.proxy += value
|
|
237
|
+
return self
|
|
238
|
+
|
|
239
|
+
def __iter__(self) -> str:
|
|
240
|
+
raise NotImplementedError
|
|
241
|
+
|
|
242
|
+
def __len__(self) -> int:
|
|
243
|
+
raise NotImplementedError
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
@dataclass(kw_only=True, slots=True)
class RetrieverAdapter(Similarity):
    """
    Similarity implementation backed by a LangChain `VectorStoreRetriever`.

    Lookups run the retriever and return the page content of the documents it yields. `+=` adds the string to the
    retriever's store as a new document. The vector store and embedding model are whatever the supplied retriever
    is configured with.
    """

    store: VectorStoreRetriever

    def __getitem__(self, key: str) -> list[str]:
        """Runs the retriever for `key` and returns the page content of each retrieved document."""
        return [doc.page_content for doc in self.store.invoke(key)]

    def __contains__(self, key: str) -> bool:
        """Vectorizes key and performs similarity search returning a boolean if there is at least one match."""
        raise NotImplementedError

    def __iadd__(self, value: str) -> Self:
        """Adds `value` to the retriever's store as a new document."""
        self.store.add_documents([Document(page_content=value)])
        return self

    def __iter__(self) -> Iterator[str]:
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError

    def __str__(self) -> str:
        return f"RetrieverAdapter(search_type={self.store.search_type})"
|