modaic 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modaic might be problematic.
- modaic/__init__.py +25 -0
- modaic/agents/rag_agent.py +33 -0
- modaic/agents/registry.py +84 -0
- modaic/auto_agent.py +228 -0
- modaic/context/__init__.py +34 -0
- modaic/context/base.py +1064 -0
- modaic/context/dtype_mapping.py +25 -0
- modaic/context/table.py +585 -0
- modaic/context/text.py +94 -0
- modaic/databases/__init__.py +35 -0
- modaic/databases/graph_database.py +269 -0
- modaic/databases/sql_database.py +355 -0
- modaic/databases/vector_database/__init__.py +12 -0
- modaic/databases/vector_database/benchmarks/baseline.py +123 -0
- modaic/databases/vector_database/benchmarks/common.py +48 -0
- modaic/databases/vector_database/benchmarks/fork.py +132 -0
- modaic/databases/vector_database/benchmarks/threaded.py +119 -0
- modaic/databases/vector_database/vector_database.py +722 -0
- modaic/databases/vector_database/vendors/milvus.py +408 -0
- modaic/databases/vector_database/vendors/mongodb.py +0 -0
- modaic/databases/vector_database/vendors/pinecone.py +0 -0
- modaic/databases/vector_database/vendors/qdrant.py +1 -0
- modaic/exceptions.py +38 -0
- modaic/hub.py +305 -0
- modaic/indexing.py +127 -0
- modaic/module_utils.py +341 -0
- modaic/observability.py +275 -0
- modaic/precompiled.py +429 -0
- modaic/query_language.py +321 -0
- modaic/storage/__init__.py +3 -0
- modaic/storage/file_store.py +239 -0
- modaic/storage/pickle_store.py +25 -0
- modaic/types.py +287 -0
- modaic/utils.py +21 -0
- modaic-0.1.0.dist-info/METADATA +281 -0
- modaic-0.1.0.dist-info/RECORD +39 -0
- modaic-0.1.0.dist-info/WHEEL +5 -0
- modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
- modaic-0.1.0.dist-info/top_level.txt +1 -0
modaic/context/base.py
ADDED
@@ -0,0 +1,1064 @@
+import copy
+import typing as t
+import uuid
+from functools import lru_cache, wraps
+from types import UnionType
+from typing import Any, Literal
+
+from PIL import Image
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    PrivateAttr,
+    SerializationInfo,
+    SerializerFunctionWrapHandler,
+    ValidationError,
+    ValidatorFunctionWrapHandler,
+    field_validator,
+    model_serializer,
+    model_validator,
+)
+from pydantic._internal._model_construction import ModelMetaclass
+from pydantic.fields import ModelPrivateAttr
+from pydantic.main import IncEx
+from pydantic.v1 import Field as V1Field
+from pydantic_core import CoreSchema, SchemaSerializer
+
+from ..query_language import Prop
+from ..storage.file_store import FileStore
+from ..types import Field, Schema
+
+if t.TYPE_CHECKING:
+    import gqlalchemy
+
+    from modaic.databases.graph_database import GraphDatabase
+    from modaic.storage.file_store import FileStore
+
+
+GQLALCHEMY_EXCLUDED_FIELDS = [
+    "id",
+    "_gqlalchemy_id",
+    "_type_registry",
+    "_labels",
+    "_gqlalchemy_class_registry",
+    "_type",
+]
+
+
+class ModaicHydrationError(Exception):
+    """Error raised when a function tries to use a Context param that is not hydrated."""
+
+    pass
+
+
+class ModelHydratedAttr(ModelPrivateAttr):
+    def __init__(self):
+        super().__init__(default=None, default_factory=None)
+
+
+def HydratedAttr():  # noqa: N802, ANN201
+    """
+    Creates a hydrated field. Hydrated fields are fields that are None by default and are hydrated by Context.hydrate()
+    """
+    return ModelHydratedAttr()
+
+
+def requires_hydration(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
+    """
+    Decorator that ensures all hydrated attributes are set before calling the function.
+
+    Args:
+        func: The method being wrapped.
+
+    Returns:
+        The wrapped method that raises if any hydrated attribute is None.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        self = args[0]
+
+        for attr in self.__class__.__hydrated_attributes__:
+            if getattr(self, attr) is None:
+                raise ModaicHydrationError(
+                    f"Attribute {attr} is not hydrated. Please call `self.hydrate()` to hydrate the attribute."
+                )
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def _get_unhidden_serializer(cls: type[BaseModel]) -> SchemaSerializer:
+    """
+    Creates a new serializer from cls.__pydantic_core_schema__ with the hidden fields included.
+    This is necessary to recursively dump hidden Context objects with hidden fields inside of other context objects.
+    """
+    core = copy.deepcopy(cls.__pydantic_core_schema__)
+
+    def walk(node: dict | list):
+        if isinstance(node, dict):
+            if (
+                node.get("type") == "model"
+                and node.get("serialization", {}).get("function", None) is Context.hidden_serializer
+            ):
+                del node["serialization"]
+            for v in node.values():
+                walk(v)
+        elif isinstance(node, list):
+            for v in node:
+                walk(v)
+
+    walk(core)
+    return SchemaSerializer(core)
+
+
+class ContextMeta(ModelMetaclass):
+    def __getattr__(cls, name: str) -> t.Any:  # noqa: N805
+        """
+        Enables the creation of Prop classes via ContextClass.property_name. Does this in a safe way that doesn't conflict with pydantic's own metaclass.
+        """
+        # 1) Let Pydantic's own metaclass handle private attrs etc.
+        try:
+            return ModelMetaclass.__getattr__(cls, name)
+        except AttributeError:
+            pass  # not a private attr; continue
+
+        # 2) Safely look up fields without triggering descriptors or our __getattr__ again
+        d = type.__getattribute__(cls, "__dict__")
+        fields = d.get("__pydantic_fields__")
+        if fields and name in fields:
+            return Prop(name)  # FieldInfo (or whatever Pydantic stores)
+
+        # 3) Not a field either
+        raise AttributeError(name)
+
+
+class Context(BaseModel, metaclass=ContextMeta):
+    """
+    Base class for all Context objects.
+
+    Attributes:
+        id: The id of the serialized context.
+        source: The source of the context object.
+        metadata: The metadata of the context object.
+
+    Example:
+        In this example, `CaptionedImage` stores the caption, the caption embedding, and the image itself. Since we can't serialize the image, we use `HydratedAttr()` to mark the `_image` field as requiring hydration.
+        ```python
+        from modaic.context import Context
+        from modaic.types import String, Vector, Float16Vector
+
+        class CaptionedImage(Context):
+            caption: String[100]
+            caption_embedding: Float16Vector[384]
+            _image: PIL.Image.Image = HydratedAttr()
+
+            def hydrate(self, file_store: FileStore):
+                image_path = file_store.get_files(self.id)["image"]
+                self._image = PIL.Image.open(image_path)
+
+        ```
+    """
+
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()), hidden=True)
+    parent: t.Optional[str] = Field(default=None, hidden=True)
+    metadata: dict = Field(default_factory=dict, hidden=True)
+
+    _gqlalchemy_id: t.Optional[int] = PrivateAttr(default=None)
+    _chunks: t.List["Context"] = PrivateAttr(default_factory=list)
+
+    # CAVEAT: All Context subclasses share the same instance of _type_registry. This is intentional.
+    _type_registry: t.ClassVar[t.Dict[str, t.Type["Context"]]] = {}
+    _labels: t.ClassVar[frozenset[str]] = frozenset()
+    _gqlalchemy_class_registry: t.ClassVar[t.Dict[str, t.Type["gqlalchemy.models.GraphObject"]]] = {}
+
+    def __init_subclass__(cls, **kwargs: t.Any) -> None:
+        """Allow class-header keywords without raising TypeError.
+
+        Args:
+            **kwargs: Arbitrary keywords from subclass declarations (e.g., type="Label").
+        """
+        super().__init_subclass__()
+
+    @classmethod
+    def __pydantic_init_subclass__(cls, **kwargs):
+        if "type" in kwargs:
+            cls._type = kwargs["type"]
+        else:
+            cls._type = cls.__name__
+
+        assert cls._type != "Node" and cls._type != "Relationship", (
+            f"Class {cls.__name__} cannot use name 'Node' or 'Relationship' as type. Please use a different name. You can use a custom type by using the 'type' keyword. Example: `class {cls.__name__}(Context, type=<custom_type_name>)`"
+        )
+
+        # TODO: revisit this. Should we allow multiple parents?
+        # Get parent class labels
+        parent_labels = [b._labels for b in cls.__bases__ if issubclass(b, Context)]
+        assert len(parent_labels) == 1, (
+            f"Context class {cls.__name__} cannot have multiple Context classes as parents. Should it? Submit an issue to tell us about your use case. https://github.com/modaic-ai/modaic/issues"
+        )
+        cls._labels = frozenset({cls._type}) | parent_labels[0]
+        assert cls._type not in cls._type_registry, (
+            f"Cannot have multiple Context/Relation classes with type = '{cls._type}'"
+        )
+        cls._type_registry[cls._type] = cls
+
+        cls.__hydrated_attributes__ = set()
+        for name in (private_attrs := cls.__private_attributes__):
+            if isinstance(private_attrs[name], ModelHydratedAttr):
+                cls.__hydrated_attributes__.add(name)
+
+        cls.__modaic_serializer__ = _get_unhidden_serializer(cls)
+
+    def __str__(self) -> str:
+        """
+        Returns a string representation of the Context instance, including all field values.
+
+        Returns:
+            str: String representation with all field values.
+        """
+        values = self.model_dump(mode="json", include_hidden=True)
+        return f"{self.__class__._type}({values})"
+
+    def __repr__(self):
+        return self.__str__()
+
+    def to_gqlalchemy(self, db: "GraphDatabase") -> "gqlalchemy.Node":
+        """
+        Convert the Context object to a GQLAlchemy object.
+        !!! warning
+            This method is not thread safe. We are actively working on a solution to make it thread safe.
+        """
+        try:
+            import gqlalchemy
+
+            from modaic.databases.graph_database import GraphDatabase
+        except ImportError:
+            raise ImportError(
+                "GQLAlchemy is not installed. Please install the graph extension for modaic with `uv add modaic[graph]`"
+            ) from None
+        assert isinstance(db, GraphDatabase), (
+            f"Expected db to be a modaic.databases.GraphDatabase instance. Got {type(db)} instead."
+        )
+        cls = self.__class__
+
+        # Dynamically create a GQLAlchemy Node class for the Context if it doesn't exist
+        if cls._type not in cls._gqlalchemy_class_registry:
+            field_annotations = get_annotations(
+                cls,
+                exclude=GQLALCHEMY_EXCLUDED_FIELDS,
+            )
+            field_defaults = get_defaults(cls, exclude=GQLALCHEMY_EXCLUDED_FIELDS)
+            gqlalchemy_class = type(
+                f"{cls.__name__}Node",
+                (gqlalchemy.Node,),
+                {
+                    "__annotations__": {**field_annotations, "modaic_id": str},
+                    "modaic_id": V1Field(unique=True, db=db._client),
+                    **field_defaults,
+                },
+                label=cls._type,
+            )
+            cls._gqlalchemy_class_registry[cls._type] = gqlalchemy_class
+        # Return a new GQLAlchemy Node object
+        gqlalchemy_class = cls._gqlalchemy_class_registry[cls._type]
+        if self._gqlalchemy_id is None:
+            return gqlalchemy_class(
+                _labels=set(self._labels),
+                modaic_id=self.id,
+                **self.model_dump(exclude={"id"}, include_hidden=True),
+            )
+        else:
+            return gqlalchemy_class(
+                _labels=set(self._labels),
+                modaic_id=self.id,
+                _id=self._gqlalchemy_id,
+                **self.model_dump(exclude={"id"}, include_hidden=True),
+            )
+
+    @classmethod
+    def from_gqlalchemy(cls, gqlalchemy_node: "gqlalchemy.Node") -> "Context":
+        """
+        Convert a GQLAlchemy Node into a `Context` instance. If cls is the Context class itself, it will return the best subclass of Context that matches the labels of the GQLAlchemy Node.
+        Args:
+            gqlalchemy_node: The GQLAlchemy Node to convert.
+
+        Returns:
+            The converted Context or Context subclass instance.
+
+        """
+        if cls is not Context:
+            if cls._type not in gqlalchemy_node._labels:
+                raise ValueError(
+                    f"Cannot convert GQLAlchemy Node {gqlalchemy_node} to {cls.__name__} because it does not have the label '{cls._type}'"
+                )
+
+            try:
+                kwargs = {**gqlalchemy_node._properties}
+                modaic_id = kwargs.pop("modaic_id")
+                kwargs["id"] = modaic_id
+                context_obj = cls(**kwargs)
+                context_obj._gqlalchemy_id = gqlalchemy_node._id
+                return context_obj
+            except ValidationError as e:
+                raise ValueError(
+                    f"Failed to convert GQLAlchemy Node {gqlalchemy_node} to {cls.__name__} because it does not have the required fields.\nError: {e}"
+                ) from e
+
+        # If cls is Context, we need to find the best subclass of Context that matches the labels of the GQLAlchemy Node.
+        best_subclass = Context._best_subclass(frozenset(gqlalchemy_node._labels))
+        return best_subclass.from_gqlalchemy(gqlalchemy_node)
+
+    def save(self, db: "GraphDatabase"):
+        """
+        Save the Context object to the graph database.
+
+        !!! warning
+            This method is not thread safe. We are actively working on a solution to make it thread safe.
+        """
+        try:
+            from modaic.databases.graph_database import GraphDatabase
+        except ImportError:
+            raise ImportError(
+                "GQLAlchemy is not installed. Please install the graph extension for modaic with `uv add modaic[graph]`"
+            ) from None
+
+        assert isinstance(db, GraphDatabase), (
+            f"Expected db to be a modaic.databases.GraphDatabase instance. Got {type(db)} instead."
+        )
+
+        result = db.save_node(self)
+
+        for k in self.model_dump(exclude={"id"}, include_hidden=True):
+            setattr(self, k, getattr(result, k))
+        self._gqlalchemy_id = result._id
+
+    def load(self, database: "GraphDatabase"):
+        """
+        Loads a node from Memgraph.
+        If the node._id is not None it fetches the node from Memgraph with that
+        internal id.
+        If the node has unique fields it fetches the node from Memgraph with
+        those unique fields set.
+        Otherwise it tries to find any node in Memgraph that has all properties
+        set to exactly the same values.
+        If no node is found or no properties are set it raises a GQLAlchemyError.
+        """
+        raise NotImplementedError("Not implemented")
+
+    @staticmethod
+    @lru_cache
+    def _best_subclass(labels: t.FrozenSet[str]) -> t.Type["Context"]:
+        best_subclass = None
+        for label in labels:
+            if current_subclass := Context._type_registry.get(label):
+                # check if the current subclass has more parents than the best subclass
+                if best_subclass is None or len(current_subclass.__mro__) > len(best_subclass.__mro__):
+                    best_subclass = current_subclass
+
+        if best_subclass is None:
+            raise ValueError(f"Cannot find a matching Context class for labels: {labels}")
+        return best_subclass
+
+    # TODO: Make iterable-friendly
+    def chunk_with(
+        self,
+        chunk_fn: t.Callable[["Context"], t.Iterable["Context"]],
+        kwargs: t.Optional[t.Dict] = None,
+    ) -> None:
+        """
+        Chunks the context object into a list of context objects.
+        """
+        if kwargs is None:
+            kwargs = {}
+        self._chunks = list(chunk_fn(self, **kwargs))
+        for chunk in self._chunks:
+            chunk.parent = self.id
+
+    def apply_to_chunks(self, apply_fn: t.Callable[["Context"], None], **kwargs):
+        """
+        Applies apply_fn to each chunk in chunks.
+
+        Args:
+            apply_fn: The function to apply to each chunk. Function should take in a Context object and mutate it.
+            **kwargs: Additional keyword arguments to pass to apply_fn.
+        """
+        for chunk in self.chunks:
+            apply_fn(chunk, **kwargs)
+
+    @property
+    def chunks(self) -> t.List["Context"]:
+        """
+        Returns the chunks of the context object.
+        """
+        return self._chunks
+
+    @property
+    def is_hydrated(self) -> bool:
+        """
+        Returns True if the context object is hydrated.
+        """
+        if not hasattr(self, "__hydrated_attributes__"):
+            return True
+        return all(getattr(self, attr) is not None for attr in self.__hydrated_attributes__)
+
+    @classmethod
+    def schema(cls) -> Schema:
+        if hasattr(cls, "_schema"):
+            return cls._schema
+        cls._schema = Schema.from_json_schema(cls.model_json_schema())
+        return cls._schema
+
+    @model_serializer(mode="wrap")
+    def hidden_serializer(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> dict[str, Any]:
+        dump = handler(self)
+        if info.context is None or not info.context.get("include_hidden"):
+            for name, field in self.__class__.model_fields.items():
+                if (extra := getattr(field, "json_schema_extra", None)) and extra.get("hidden"):
+                    dump.pop(name, None)
+        return dump
+
+    def model_dump(
+        self,
+        *,
+        mode: str | Literal["json", "python"] = "python",
+        include: IncEx | None = None,
+        exclude: IncEx | None = None,
+        context: Any | None = None,
+        by_alias: bool | None = None,
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+        exclude_none: bool = False,
+        round_trip: bool = False,
+        warnings: bool | Literal["none", "warn", "error"] = True,
+        fallback: t.Callable[[Any], Any] | None = None,
+        serialize_as_any: bool = False,
+        include_hidden: bool = False,
+    ) -> Any:
+        """
+        Override of pydantic's BaseModel.model_dump to allow for showing hidden fields
+
+        Args:
+            include_hidden: Whether to show hidden fields.
+
+        Returns:
+            The dictionary representation of the model.
+        """
+        if include_hidden:
+            return self.__modaic_serializer__.to_python(
+                self,
+                mode=mode,
+                include=include,
+                exclude=exclude,
+                context=context,
+                by_alias=by_alias,
+                exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
+                exclude_none=exclude_none,
+                round_trip=round_trip,
+                warnings=warnings,
+                fallback=fallback,
+                serialize_as_any=False,
+            )
+
+        else:
+            return super().model_dump(
+                mode=mode,
+                include=include,
+                exclude=exclude,
+                context=context,
+                by_alias=by_alias,
+                exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
+                exclude_none=exclude_none,
+                round_trip=round_trip,
+                warnings=warnings,
+                fallback=fallback,
+                serialize_as_any=serialize_as_any,
+            )
+
+    def model_dump_json(
+        self,
+        *,
+        indent: int | None = None,
+        include: IncEx | None = None,
+        exclude: IncEx | None = None,
+        context: Any | None = None,
+        by_alias: bool | None = None,
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+        exclude_none: bool = False,
+        round_trip: bool = False,
+        warnings: bool | Literal["none", "warn", "error"] = True,
+        fallback: t.Callable[[Any], Any] | None = None,
+        serialize_as_any: bool = False,
+        include_hidden: bool = False,
+    ) -> bytes | str:
+        """
+        Override of pydantic's BaseModel.model_dump_json to allow for showing hidden fields
+        """
+        if include_hidden:
+            return self.__modaic_serializer__.to_json(
+                self,
+                indent=indent,
+                include=include,
+                exclude=exclude,
+                context=context,
+                by_alias=by_alias,
+                exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
+                exclude_none=exclude_none,
+                round_trip=round_trip,
+                warnings=warnings,
+                fallback=fallback,
+                serialize_as_any=True,
+            )
+        else:
+            return super().model_dump_json(
+                indent=indent,
+                include=include,
+                exclude=exclude,
+                context=context,
+                by_alias=by_alias,
+                exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
+                exclude_none=exclude_none,
+                round_trip=round_trip,
+                warnings=warnings,
+                fallback=fallback,
+                serialize_as_any=serialize_as_any,
+            )
+
+
+class RelationMeta(ContextMeta):
+    def __new__(cls, name, bases, dct):  # noqa: ANN002, ANN001, ANN003
+        # Make Relation class allow extra fields but subclasses default to ignore (pydantic default)
+        # BUG: Doesn't allow users to define their own "extra" behavior
+        if name == "Relation":
+            dct["model_config"] = ConfigDict(extra="allow")
+        elif "model_config" not in dct:
+            dct["model_config"] = ConfigDict(extra="ignore")
+        elif dct["model_config"].get("extra", None) is None:
+            dct["model_config"]["extra"] = "ignore"
+
+        return super().__new__(cls, name, bases, dct)
+
+
+class Relation(Context, metaclass=RelationMeta):
+    """
+    Base class for all Relation objects.
+    """
+
+    _start_node: t.Optional[Context] = PrivateAttr(default=None)
+    _end_node: t.Optional[Context] = PrivateAttr(default=None)
+
+    start_node: t.Optional[int] = None
+    end_node: t.Optional[int] = None
+
+    @t.overload
+    def __init__(self, start_node: Context | int, end_node: Context | int, **data: Any) -> "Relation": ...
+
+    @model_validator(mode="wrap")
+    @classmethod
+    def truncate(cls, data: Any, handler: ValidatorFunctionWrapHandler) -> "Relation":
+        """
+        Truncates the start_node and end_node to their gqlalchemy ids.
+        """
+        ids = {}
+        objs = {}
+        for name in ["start_node", "end_node"]:
+            node = data[name]
+            if isinstance(node, Context):
+                ids[name] = node._gqlalchemy_id
+                objs[name] = node
+            else:
+                ids[name] = node
+                objs[name] = None
+        data["start_node"] = ids["start_node"]
+        data["end_node"] = ids["end_node"]
+        self = handler(data)
+        self._start_node = objs["start_node"]
+        self._end_node = objs["end_node"]
+        return self
+
+    def get_start_node_obj(self, db: "GraphDatabase") -> Context:
+        """
+        Get the start node object of the relation as a Context object.
+        Args:
+            db: The GraphDatabase instance to use to fetch the start node.
+
+        Returns:
+            The start node object as a Context object.
+        """
+        if self._start_node:
+            return self._start_node
+        else:
+            return Context.from_gqlalchemy(
+                next(db.execute_and_fetch(f"MATCH (n) WHERE id(n) = {self.start_node} RETURN n"))
+            )
+
+    def get_end_node_obj(self, db: "GraphDatabase") -> Context:
+        """
+        Get the end node object of the relation as a Context object.
+        Args:
+            db: The GraphDatabase instance to use to fetch the end node.
+
+        Returns:
+            The end node object as a Context object.
+        """
+        if self._end_node:
+            return self._end_node
+        else:
+            return Context.from_gqlalchemy(
+                next(db.execute_and_fetch(f"MATCH (n) WHERE id(n) = {self.end_node} RETURN n"))
+            )
+
+    @field_validator("start_node", "end_node")
+    @classmethod
+    def check_node(cls, v: Any) -> Context | int:
+        assert isinstance(v, (Context, int)), f"start_node/end_node must be a Context or int, got {type(v)}: {v}"
+        assert not isinstance(v, Relation), f"start_node/end_node cannot be a Relation object: {v}"
+        return v
+
+    @model_validator(mode="after")
+    def post_init(self) -> "Relation":
+        # Sets type for inline declaration of Relation objects
+        if type(self) is Relation:
+            assert "_type" in self.model_dump(), "Inline declaration of Relation objects must specify the '_type' field"
+            self._type = self.model_dump()["_type"]
+        return self
+
+    # other >> self
+    def __rrshift__(self, other: Context | int):
+        # left_node >> self >> right_node
+        self.start_node = other
+        return self
+
+    # self >> other
+    def __rshift__(self, other: Context | int):
+        # left_node >> self >> right_node
+        self.end_node = other
+        return self
+
+    # other << self
+    def __rlshift__(self, other: Context | int):
+        # left_node << self << right_node
+        self.end_node = other
+        return self
+
+    # self << other
+    def __lshift__(self, other: Context | int):
+        # left_node << self << right_node
+        self.start_node = other
+        return self
+
+    def __str__(self):
+        """
+        Returns a string representation of the Relation object, including all fields and their values.
+
+        Returns:
+            str: String representation of the Relation object with all fields and their values.
+        """
+        fields_repr = ", ".join(f"{k}={repr(v)}" for k, v in self.model_dump(include_hidden=True).items())
+        return f"{self.__class__._type}({fields_repr})"
+
+    def __repr__(self):
+        return self.__str__()
+
+    def to_gqlalchemy(self, db: "GraphDatabase") -> "gqlalchemy.Relationship":
+        """
+        Convert the Context object to a GQLAlchemy object.
+
+        <Warning>Saves the start_node and end_node to the database if they are not already saved.</Warning>
+
+        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
+        Args:
+            db: The GraphDatabase instance to use to save the start_node and end_node if they are not already saved.
+
+        Returns:
+            The GQLAlchemy Relationship object.
+
+        Raises:
+            AssertionError: If db is not a modaic.databases.GraphDatabase instance.
+            ImportError: If GQLAlchemy is not installed.
+
+        """
+        try:
+            import gqlalchemy
+
+            from modaic.databases.graph_database import GraphDatabase
+        except ImportError:
+            raise ImportError(
+                "GQLAlchemy is not installed. Please install the graph extension for modaic with `uv add modaic[graph]`"
+            ) from None
+
+        assert isinstance(db, GraphDatabase), (
+            f"Expected db to be a modaic.databases.GraphDatabase instance. Got {type(db)} instead."
+        )
+
+        cls = self.__class__
+
+        # Dynamically create a GQLAlchemy Relationship class for the Relation if it doesn't exist
+        if self._type not in cls._gqlalchemy_class_registry:
+            ad_hoc_annotations = get_ad_hoc_annotations(self) if cls is Relation else {}
+            field_annotations = get_annotations(
+                cls,
+                exclude=GQLALCHEMY_EXCLUDED_FIELDS + ["start_node", "end_node"],
+            )
+            field_defaults = get_defaults(
+                cls,
+                exclude=GQLALCHEMY_EXCLUDED_FIELDS + ["start_node", "end_node"],
+            )
+            gqlalchemy_class = type(
+                f"{cls.__name__}Rel",
+                (gqlalchemy.Relationship,),
+                {
+                    "__annotations__": {
+                        **ad_hoc_annotations,
+                        **field_annotations,
+                        "modaic_id": str,
+                    },
+                    "modaic_id": V1Field(unique=True, db=db._client),
+                    **field_defaults,
+                },
+                type=self._type,
+            )
+            cls._gqlalchemy_class_registry[self._type] = gqlalchemy_class
+
+        gqlalchemy_class = cls._gqlalchemy_class_registry[self._type]
+
+        if self.start_node is not None and self.start_node_gql_id is None:
+            self.start_node.save(db)
+        if self.end_node is not None and self.end_node_gql_id is None:
+            self.end_node.save(db)
+
+        if self._gqlalchemy_id is None:
+            return gqlalchemy.Relationship.parse_obj(
+                {
+                    "_type": self._type,
+                    "modaic_id": self.id,
+                    "_start_node_id": self.start_node_gql_id,
+                    "_end_node_id": self.end_node_gql_id,
+                    **self.model_dump(
+                        exclude={"id", "start_node", "end_node", "_type"},
+                        include_hidden=True,
+                    ),
+                }
+            )
+        else:
+            return gqlalchemy.Relationship.parse_obj(
+                {
+                    "_type": self._type,
+                    "modaic_id": self.id,
+                    "_id": self._gqlalchemy_id,
+                    "_start_node_id": self.start_node_gql_id,
+                    "_end_node_id": self.end_node_gql_id,
+                    **self.model_dump(
+                        exclude={"id", "start_node", "end_node", "_type"},
+                        include_hidden=True,
+                    ),
+                }
+            )
+
+    @classmethod
+    def from_gqlalchemy(cls, gqlalchemy_rel: "gqlalchemy.Relationship") -> "Relation":
+        """
+        Convert a GQLAlchemy `Relationship` into a `Relation` instance. If `cls` is the `Relation` class itself, it will try to return an instance of a subclass of `Relation` that matches the type of the GQLAlchemy Relationship. If none are found it will fall back to an instance of `Relation` since the `Relation` class allows defining relations inline.
+        If `cls` is instead a subclass of `Relation`, it will return an instance of that subclass and fail if the properties do not align.
+        Args:
+            gqlalchemy_obj: The GQLAlchemy Relationship to convert.
+
+        Raises:
+            ValueError: If the GQLAlchemy Relationship does not have the required fields.
+            AssertionError: If the GQLAlchemy Relationship does not have the required type.
+
+        Returns:
+            The converted Relation or Relation subclass instance.
+        """
+        if cls is not Relation:
+            assert cls._type == gqlalchemy_rel._type, (
+                f"Cannot convert GQLAlchemy Relationship {gqlalchemy_rel} to {cls.__name__} because it does not have {cls.__name__}'s type: '{cls._type}'"
+            )
+            try:
+                kwargs = {**gqlalchemy_rel._properties}
+                kwargs["id"] = kwargs.pop("modaic_id")
+                kwargs["start_node"] = gqlalchemy_rel._start_node_id
+                kwargs["end_node"] = gqlalchemy_rel._end_node_id
+                new_relation = cls(**kwargs)
+                new_relation._gqlalchemy_id = gqlalchemy_rel._id
+                return new_relation
+            except ValidationError as e:
+                raise ValueError(
+                    f"Failed to convert GQLAlchemy Relationship {gqlalchemy_rel} to {cls.__name__} because it does not have the required fields.\nError: {e}"
+                ) from e
+
+        # If cls is Relation, we need to find the subclass of Relation that matches the type of the GQLAlchemy Relationship.
+        # CAVEAT: Relation is a subclass of Context, so we can just use the same Context._type_registry.
+        if subclass := Context._type_registry.get(gqlalchemy_rel._type):
+            assert issubclass(subclass, Relation), (
+                f"Found Relation subclass with matching type, but cannot convert GQLAlchemy Relationship {gqlalchemy_rel} to {subclass.__name__} because it is not a subclass of Relation"
+            )
+            return subclass.from_gqlalchemy(gqlalchemy_rel)
+        # If no subclass is found, we can just create a new Relation object with the properties of the GQLAlchemy Relationship.
+        else:
+            kwargs = {**gqlalchemy_rel._properties}
+            kwargs["id"] = kwargs.pop("modaic_id")
+            kwargs["start_node"] = gqlalchemy_rel._start_node_id
+            kwargs["end_node"] = gqlalchemy_rel._end_node_id
+            kwargs["_type"] = gqlalchemy_rel._type
+            new_relation = cls(**kwargs)
+            new_relation._gqlalchemy_id = gqlalchemy_rel._id
+            return new_relation
+
+    def save(self, db: "GraphDatabase"):
+        """
+        Save the Relation object to the GraphDatabase.
+
+        !!! warning
+            This method is not thread safe. We are actively working on a solution to make it thread safe.
+        """
+
+        try:
+            from modaic.databases.graph_database import GraphDatabase
+        except ImportError:
+            raise ImportError(
+                "GQLAlchemy is not installed. Please install the graph extension for modaic with `uv add modaic[graph]`"
+            ) from None
+
+        assert isinstance(db, GraphDatabase), (
+            f"Expected db to be a modaic.databases.GraphDatabase instance. Got {type(db)} instead."
+        )
+        result = db.save_relationship(self)
+        for k in self.model_dump(exclude={"id", "start_node", "end_node"}, include_hidden=True):
+            setattr(self, k, getattr(result, k))
+        self._gqlalchemy_id = result._id
+
+    def load(self, db: "GraphDatabase"):
+        """
+        Loads a relationship from GraphDatabase.
+        If the relationship._id is not None it fetches the relationship from GraphDatabase with that
+        internal id.
+        If the relationship has unique fields it fetches the relationship from GraphDatabase with
+        those unique fields set.
+        Otherwise it tries to find any relationship in GraphDatabase that has all properties
+        set to exactly the same values.
+        If no relationship is found or no properties are set it raises a GQLAlchemyError.
+        """
+        raise NotImplementedError("Not implemented")
+
+
+def _cast_type_if_base_model(field_type: t.Type) -> t.Type:
+    """
+    If field_type is a typing construct, reconstruct it from origin/args.
+    If it's a Pydantic BaseModel subclass, map it to `dict`.
+    Otherwise return the type itself.
+    """
+    origin = t.get_origin(field_type)
+
+    # Non-typing constructs
+    if origin is None:
+        # Only call issubclass on real classes
+        if isinstance(field_type, type) and issubclass(field_type, BaseModel):
+            return dict
+        return field_type
+
+    args = t.get_args(field_type)
+
+    # Annotated[T, m1, m2, ...]  # noqa: ERA001
+    if origin is t.Annotated:
+        base, *meta = args
+        # Annotated allows multiple args; pass a tuple to __class_getitem__
+        return t.Annotated.__class_getitem__((_cast_type_if_base_model(base), *meta))
+
+    # Unions: typing.Union[...] or PEP 604 (A | B)
+    if origin in (t.Union, UnionType):
+        return t.Union[tuple(_cast_type_if_base_model(a) for a in args)]
+
+    # Literal / Final / ClassVar accept tuple args via typing protocol
+    if origin in (t.Literal, t.Final, t.ClassVar):
+        return origin.__getitem__([_cast_type_if_base_model(a) for a in args])
+
+    # Builtin generics (PEP 585): list[T], dict[K, V], set[T], tuple[...]
+    if origin in (list, set, frozenset):
+        (inner_type,) = args
+        return origin[_cast_type_if_base_model(inner_type)]
+    if origin is dict:
+        key_type, value_type = args
+        return dict[_cast_type_if_base_model(key_type), _cast_type_if_base_model(value_type)]
+    if origin is tuple:
+        # tuple[int, ...] vs tuple[int, str]
+        if len(args) == 2 and args[1] is Ellipsis:
+            return tuple[_cast_type_if_base_model(args[0]), ...]
+        return tuple[tuple([_cast_type_if_base_model(a) for a in args])]  # tuple[(A, B, C)]
+
+    # ABC generics (e.g., Mapping, Sequence, Iterable, etc.) usually accept tuple args
+    try:
+        return origin.__class_getitem__([_cast_type_if_base_model(a) for a in args])
+    except Exception:
+        # Last resort: try simple unpack for 1–2 arity generics
+        if len(args) == 1:
+            return origin[_cast_type_if_base_model(args[0])]
+        elif len(args) == 2:
+            return origin[
+                _cast_type_if_base_model(args[0]),
+                _cast_type_if_base_model(args[1]),
+            ]
+        raise
+
+
+def get_annotations(cls: t.Type, exclude: t.Optional[t.List[str]] = None) -> t.Dict[str, t.Type]:
+    if exclude is None:
+        exclude = []
+    if not issubclass(cls, Context):
+        return {}
+    elif cls is Context:
+        res = {k: _cast_type_if_base_model(v) for k, v in cls.__annotations__.items() if k not in exclude}
+        return res
+    else:
+        annotations = {}
+        for base in cls.__bases__:
+            annotations.update(get_annotations(base, exclude))
+        annotations.update({k: _cast_type_if_base_model(v) for k, v in cls.__annotations__.items() if k not in exclude})
+        return annotations
+
+
+def _cast_if_base_model(field_default: t.Any) -> t.Any:
+    if isinstance(field_default, BaseModel):
+        return field_default.model_dump()
+    return field_default
+
+
+def get_defaults(cls: t.Type[Context], exclude: t.Optional[t.List[str]] = None) -> t.Dict[str, t.Any]:
+    if exclude is None:
+        exclude = []
+    defaults: t.Dict[str, t.Any] = {}
+    for name, v2_field in cls.model_fields.items():
+        if name in exclude or v2_field.is_required():
+            continue
+        kwargs = {}
+        if extra_kwargs := getattr(v2_field, "json_schema_extra", None):
+            kwargs.update(extra_kwargs)
+
+        factory = v2_field.default_factory
+        if factory is not None:
+            kwargs["default_factory"] = lambda f=factory: _cast_if_base_model(f())
+        else:
+            kwargs["default"] = _cast_if_base_model(v2_field.default)
+
+        v1_field = V1Field(**kwargs)
+        defaults[name] = v1_field
+
+    return defaults
+
+
+def get_ad_hoc_annotations(rel: Relation) -> t.Dict[str, t.Type]:
+    """
+    Gets "adhoc" annotations for a Relation object. Specifically, for when Relations are created inline
+    (i.e. when you do `Relation(x="test", _type="TEST_REL")`).
+    This is for those fields that were declared inline.
+    Args:
+        rel: The Relation object to get the adhoc annotations for.
+
+    Returns:
+        A dictionary of the adhoc annotations.
+    """
+    annotations = {}
+    for name, val in rel.model_dump(
+        exclude=GQLALCHEMY_EXCLUDED_FIELDS + ["start_node", "end_node"],
+        include_hidden=True,
+    ).items():
+        if val is None:
+            annotations[name] = t.Any
+        elif isinstance(val, BaseModel):
+            annotations[name] = dict
+        else:
+            annotations[name] = type(val)
+    return annotations
+
+
+@t.runtime_checkable
+class Hydratable(t.Protocol):
+    def hydrate(self, file_store: FileStore) -> None:
+        pass
+
+    @classmethod
+    def from_file(cls, file: str, file_store: FileStore, type: str, params: dict = None) -> "Hydratable":
+        """
+        Load a Hydratable instance from a file.
+
+        Args:
+            file: The file to load.
+            file_store: The file store to use.
+            type: The type of file to expect.
+            params: Extra parameters to pass to the constructor.
+        """
+        pass
+
+
+if t.TYPE_CHECKING:
+    # @runtime_checkable
+    class HydratableContext(Hydratable, Context):
+        pass
+
+
+def is_hydratable(obj: t.Any) -> bool:
+    return isinstance(obj, Hydratable) and isinstance(obj, Context)
+
+
+@t.runtime_checkable
+class Embeddable(t.Protocol):
+    """
+    A protocol for objects that can be embedded. These objects define the embedme function which can either return a string or an image.
+    The embedme function can either take no args, or take an index name as an argument, which will be used to select the index to embed for.
+    """
+
+    @t.overload
+    def embedme(self) -> str | Image.Image: ...
+
+    @t.overload
+    def embedme(self, index: t.Optional[str] = None) -> str | Image.Image: ...
+
+
+@t.runtime_checkable
+class MultiEmbeddable(t.Protocol):
+    """
+    A protocol for objects that can be embedded and have multiple embeddings. These objects define the embedme function which can either return a string or an image.
+    The embedme function can either take no args, or take an index name as an argument, which will be used to select the index to embed for.
+    """
+
+    @t.overload
+    def embedme(self, index: t.Optional[str] = None) -> str | Image.Image: ...
+
+
+def is_embeddable(obj: t.Any) -> bool:
+    return isinstance(obj, Embeddable) and isinstance(obj, Context)
+
+
+def is_multi_embeddable(obj: t.Any) -> bool:
+    return isinstance(obj, MultiEmbeddable) and isinstance(obj, Context)
+
+
+if t.TYPE_CHECKING:
+    # @runtime_checkable
+    class EmbeddableContext(t.Protocol, Context):
+        pass
+
+
+def _update_exclude(exclude: IncEx, hidden: t.Set[str]):
+    if isinstance(exclude, set):
+        return exclude.update(hidden)
+    else:  # NOTE: if not a set, it's a dict
+        return exclude.update({k: True for k in hidden})
+
+
+def _dump_hidden_recursive(obj: t.Any):
+    if isinstance(obj, BaseModel):
+        return obj.model_dump(include_hidden=True)
+    elif isinstance(obj, Context):
+        return obj.model_dump(include_hidden=True)
+    elif isinstance(obj, list):
+        return [_dump_hidden_recursive(item) for item in obj]
+    elif isinstance(obj, dict):
+        return {k: _dump_hidden_recursive(v) for k, v in obj.items()}
+    else:
+        return obj
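The hydration pattern in this file (`HydratedAttr`, `requires_hydration`, and a user-defined `Context.hydrate`) is easiest to see end to end. The sketch below is assembled from the `Context` docstring and the helpers defined above; the `String`/`Float16Vector` field types and the `file_store.get_files(...)` call are taken from that docstring example rather than verified against the rest of the released wheel, so treat it as illustrative.

```python
import PIL.Image

from modaic.context.base import Context, HydratedAttr, requires_hydration
from modaic.storage.file_store import FileStore
from modaic.types import Float16Vector, String


class CaptionedImage(Context):
    caption: String[100]
    caption_embedding: Float16Vector[384]
    # Hydrated attributes default to None and are filled in by hydrate(), so the
    # image itself never has to be serialized with the rest of the context.
    _image: PIL.Image.Image = HydratedAttr()

    def hydrate(self, file_store: FileStore) -> None:
        # Pull the binary payload back from the file store, keyed by this context's id.
        image_path = file_store.get_files(self.id)["image"]
        self._image = PIL.Image.open(image_path)

    @requires_hydration
    def width(self) -> int:
        # Raises ModaicHydrationError until hydrate() has been called.
        return self._image.width
```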