modaic 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modaic might be problematic. Click here for more details.

Files changed (39)
  1. modaic/__init__.py +25 -0
  2. modaic/agents/rag_agent.py +33 -0
  3. modaic/agents/registry.py +84 -0
  4. modaic/auto_agent.py +228 -0
  5. modaic/context/__init__.py +34 -0
  6. modaic/context/base.py +1064 -0
  7. modaic/context/dtype_mapping.py +25 -0
  8. modaic/context/table.py +585 -0
  9. modaic/context/text.py +94 -0
  10. modaic/databases/__init__.py +35 -0
  11. modaic/databases/graph_database.py +269 -0
  12. modaic/databases/sql_database.py +355 -0
  13. modaic/databases/vector_database/__init__.py +12 -0
  14. modaic/databases/vector_database/benchmarks/baseline.py +123 -0
  15. modaic/databases/vector_database/benchmarks/common.py +48 -0
  16. modaic/databases/vector_database/benchmarks/fork.py +132 -0
  17. modaic/databases/vector_database/benchmarks/threaded.py +119 -0
  18. modaic/databases/vector_database/vector_database.py +722 -0
  19. modaic/databases/vector_database/vendors/milvus.py +408 -0
  20. modaic/databases/vector_database/vendors/mongodb.py +0 -0
  21. modaic/databases/vector_database/vendors/pinecone.py +0 -0
  22. modaic/databases/vector_database/vendors/qdrant.py +1 -0
  23. modaic/exceptions.py +38 -0
  24. modaic/hub.py +305 -0
  25. modaic/indexing.py +127 -0
  26. modaic/module_utils.py +341 -0
  27. modaic/observability.py +275 -0
  28. modaic/precompiled.py +429 -0
  29. modaic/query_language.py +321 -0
  30. modaic/storage/__init__.py +3 -0
  31. modaic/storage/file_store.py +239 -0
  32. modaic/storage/pickle_store.py +25 -0
  33. modaic/types.py +287 -0
  34. modaic/utils.py +21 -0
  35. modaic-0.1.0.dist-info/METADATA +281 -0
  36. modaic-0.1.0.dist-info/RECORD +39 -0
  37. modaic-0.1.0.dist-info/WHEEL +5 -0
  38. modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
  39. modaic-0.1.0.dist-info/top_level.txt +1 -0
modaic/context/base.py ADDED
@@ -0,0 +1,1064 @@
1
+ import copy
2
+ import typing as t
3
+ import uuid
4
+ from functools import lru_cache, wraps
5
+ from types import UnionType
6
+ from typing import Any, Literal
7
+
8
+ from PIL import Image
9
+ from pydantic import (
10
+ BaseModel,
11
+ ConfigDict,
12
+ PrivateAttr,
13
+ SerializationInfo,
14
+ SerializerFunctionWrapHandler,
15
+ ValidationError,
16
+ ValidatorFunctionWrapHandler,
17
+ field_validator,
18
+ model_serializer,
19
+ model_validator,
20
+ )
21
+ from pydantic._internal._model_construction import ModelMetaclass
22
+ from pydantic.fields import ModelPrivateAttr
23
+ from pydantic.main import IncEx
24
+ from pydantic.v1 import Field as V1Field
25
+ from pydantic_core import CoreSchema, SchemaSerializer
26
+
27
+ from ..query_language import Prop
28
+ from ..storage.file_store import FileStore
29
+ from ..types import Field, Schema
30
+
31
+ if t.TYPE_CHECKING:
32
+ import gqlalchemy
33
+
34
+ from modaic.databases.graph_database import GraphDatabase
35
+ from modaic.storage.file_store import FileStore
36
+
37
+
38
# Attribute names never copied onto the dynamically generated GQLAlchemy classes.
# `id` is passed separately as `modaic_id`; the others are Context's private/class-level
# bookkeeping attributes. Kept as a list so callers can extend it with `+ [...]`.
GQLALCHEMY_EXCLUDED_FIELDS = [
    "id", "_gqlalchemy_id",
    "_type_registry", "_labels",
    "_gqlalchemy_class_registry", "_type",
]
46
+
47
+
48
class ModaicHydrationError(Exception):
    """Raised when code needs a hydrated attribute of a Context that is still None (not yet hydrated)."""
52
+
53
+
54
class ModelHydratedAttr(ModelPrivateAttr):
    """
    Marker private attribute for values that must be hydrated before use.

    It is an ordinary pydantic private attribute whose default is None. Context
    subclasses look for instances of this class among their private attributes
    to build `__hydrated_attributes__`.
    """

    def __init__(self):
        # No default factory: the attribute simply starts out as None.
        super().__init__(default_factory=None, default=None)
57
+
58
+
59
def HydratedAttr():  # noqa: N802, ANN201
    """
    Create a hydrated attribute.

    Hydrated attributes are None by default and are meant to be filled in later by a
    `hydrate()` method. Declare them as private attributes on a Context subclass,
    e.g. `_image: PIL.Image.Image = HydratedAttr()`.
    """
    marker = ModelHydratedAttr()
    return marker
64
+
65
+
66
def requires_hydration(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
    """
    Decorator that ensures all hydrated attributes are set before calling the method.

    Args:
        func: The method being wrapped. Its first positional argument must be the
            Context instance.

    Returns:
        The wrapped method. Before delegating to `func` it checks every name in the
        instance class's `__hydrated_attributes__` and raises ModaicHydrationError if
        any of those attributes is None.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        self = args[0]

        # Classes that never went through Context.__pydantic_init_subclass__ (e.g. the
        # Context base class itself) have no __hydrated_attributes__. Treat them as having
        # nothing to hydrate, the same way Context.is_hydrated does, instead of raising
        # AttributeError.
        for attr in getattr(self.__class__, "__hydrated_attributes__", ()):
            if getattr(self, attr) is None:
                raise ModaicHydrationError(
                    f"Attribute {attr} is not hydrated. Please call `self.hydrate()` to hydrate the attribute."
                )
        return func(*args, **kwargs)

    return wrapper
89
+
90
+
91
def _get_unhidden_serializer(cls: type[BaseModel]) -> SchemaSerializer:
    """
    Build a serializer for `cls` that keeps hidden fields.

    The model's core schema is deep-copied. Every `model` node whose serialization
    function is `Context.hidden_serializer` has its `serialization` entry removed, and
    a SchemaSerializer is built from the result. Because this also applies to nested
    Context models, hidden fields of Context objects nested inside other Context
    objects are dumped too.
    """
    schema_copy = copy.deepcopy(cls.__pydantic_core_schema__)

    # Iterative depth-first traversal over the nested dict/list structure.
    pending: list = [schema_copy]
    while pending:
        current = pending.pop()
        if isinstance(current, list):
            pending.extend(current)
            continue
        if not isinstance(current, dict):
            continue
        is_hidden_model = (
            current.get("type") == "model"
            and current.get("serialization", {}).get("function", None) is Context.hidden_serializer
        )
        if is_hidden_model:
            del current["serialization"]
        pending.extend(current.values())

    return SchemaSerializer(schema_copy)
113
+
114
+
115
class ContextMeta(ModelMetaclass):
    """Metaclass that lets `ContextClass.field_name` produce a query-language `Prop`."""

    def __getattr__(cls, name: str) -> t.Any:  # noqa: N805
        """
        Resolve unknown class attributes.

        Pydantic's own metaclass lookup (private attributes, etc.) runs first. If it fails
        and `name` is one of the class's pydantic fields, `Prop(name)` is returned.
        Otherwise AttributeError is raised.
        """
        try:
            return ModelMetaclass.__getattr__(cls, name)
        except AttributeError:
            pass  # not something pydantic's metaclass knows about

        # Read the class __dict__ directly to avoid descriptors and re-entering this method.
        own_fields = type.__getattribute__(cls, "__dict__").get("__pydantic_fields__") or {}
        if name not in own_fields:
            raise AttributeError(name)
        return Prop(name)
134
+
135
+
136
class Context(BaseModel, metaclass=ContextMeta):
    """
    Base class for all Context objects.

    Attributes:
        id: The id of the serialized context. Defaults to a random UUID4 string.
        parent: The id of the Context this object was chunked from (set by `chunk_with`), or None.
        metadata: The metadata of the context object.

    `id`, `parent` and `metadata` are declared with `hidden=True`. `hidden_serializer`
    drops fields whose `json_schema_extra` marks them hidden (presumably what
    `modaic.types.Field(..., hidden=True)` records) unless `include_hidden=True` is
    passed to `model_dump` / `model_dump_json`.

    Example:
        In this example, `CaptionedImage` stores a caption and its embedding. The image itself
        cannot be serialized, so the `_image` private attribute is declared with `HydratedAttr()`
        to mark it as requiring hydration, and `hydrate` loads it from a FileStore.
        ```python
        import PIL.Image

        from modaic.context import Context
        from modaic.context.base import HydratedAttr
        from modaic.storage.file_store import FileStore
        from modaic.types import String, Float16Vector

        class CaptionedImage(Context):
            caption: String[100]
            caption_embedding: Float16Vector[384]
            _image: PIL.Image.Image = HydratedAttr()

            def hydrate(self, file_store: FileStore):
                image_path = file_store.get_files(self.id)["image"]
                self._image = PIL.Image.open(image_path)
        ```
    """

    id: str = Field(default_factory=lambda: str(uuid.uuid4()), hidden=True)
    parent: t.Optional[str] = Field(default=None, hidden=True)
    metadata: dict = Field(default_factory=dict, hidden=True)

    # Internal GQLAlchemy id; None until the object is saved to / loaded from a graph database.
    _gqlalchemy_id: t.Optional[int] = PrivateAttr(default=None)
    # Chunks produced by `chunk_with`.
    _chunks: t.List["Context"] = PrivateAttr(default_factory=list)

    # CAVEAT: All Context subclasses share the same instance of _type_registry. This is intentional.
    _type_registry: t.ClassVar[t.Dict[str, t.Type["Context"]]] = {}
    _labels: t.ClassVar[frozenset[str]] = frozenset()
    _gqlalchemy_class_registry: t.ClassVar[t.Dict[str, t.Type["gqlalchemy.models.GraphObject"]]] = {}

    def __init_subclass__(cls, **kwargs: t.Any) -> None:
        """Allow class-header keywords without raising TypeError.

        Args:
            **kwargs: Arbitrary keywords from subclass declarations (e.g., type="Label").
                They are handled in `__pydantic_init_subclass__`.
        """
        super().__init_subclass__()

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        """
        Register a newly defined Context subclass.

        This method does the following:
        - Sets `_type` from the `type=` class keyword, defaulting to the class name.
        - Derives `_labels` from the single Context parent.
        - Records the class in the shared `_type_registry`.
        - Collects the names of `HydratedAttr()` private attributes into
          `__hydrated_attributes__`.
        - Builds the serializer used for `include_hidden=True` dumps.

        Raises:
            AssertionError: If the type is 'Node' or 'Relationship', if the class has
                more than one Context base, or if the type is already registered.
        """
        cls._type = kwargs.get("type", cls.__name__)

        assert cls._type != "Node" and cls._type != "Relationship", (
            f"Class {cls.__name__} cannot use name 'Node' or 'Relationship' as type. Please use a different name. You can use a custom type by using the 'type' keyword. Example: `class {cls.__name__}(Context, type=<custom_type_name>)`"
        )

        # TODO: revisit this. Should we allow multiple parents?
        # Get parent class labels
        parent_labels = [b._labels for b in cls.__bases__ if issubclass(b, Context)]
        assert len(parent_labels) == 1, (
            f"Context class {cls.__name__} cannot have multiple Context classes as parents. Should it? Submit an issue to tell us about your use case. https://github.com/modaic-ai/modaic/issues"
        )
        cls._labels = frozenset({cls._type}) | parent_labels[0]
        assert cls._type not in cls._type_registry, (
            f"Cannot have multiple Context/Relation classes with type = '{cls._type}'"
        )
        cls._type_registry[cls._type] = cls
        # _best_subclass memoizes registry lookups. A new registration can change its
        # answer, so drop any stale results.
        Context._best_subclass.cache_clear()

        cls.__hydrated_attributes__ = {
            name for name, attr in cls.__private_attributes__.items() if isinstance(attr, ModelHydratedAttr)
        }

        cls.__modaic_serializer__ = _get_unhidden_serializer(cls)

    @classmethod
    def _modaic_serializer(cls) -> SchemaSerializer:
        """
        Return the include-hidden serializer for `cls`.

        Subclasses get one in `__pydantic_init_subclass__`. For classes that never went
        through it (the Context base class itself), it is built on first use and cached.
        """
        serializer = cls.__dict__.get("__modaic_serializer__")
        if serializer is None:
            serializer = _get_unhidden_serializer(cls)
            cls.__modaic_serializer__ = serializer
        return serializer

    def __str__(self) -> str:
        """
        Returns a string representation of the Context instance, including all field values.

        Returns:
            str: `<type>(<json-mode dump including hidden fields>)`.
        """
        values = self.model_dump(mode="json", include_hidden=True)
        # The Context base class itself is never registered, so it has no `_type`.
        type_name = getattr(self.__class__, "_type", self.__class__.__name__)
        return f"{type_name}({values})"

    def __repr__(self):
        return self.__str__()

    def to_gqlalchemy(self, db: "GraphDatabase") -> "gqlalchemy.Node":
        """
        Convert the Context object to a GQLAlchemy Node.

        The first call for a given `_type` creates a `gqlalchemy.Node` subclass labelled
        `_type` and caches it in `_gqlalchemy_class_registry`. Its annotations and defaults
        come from `get_annotations` / `get_defaults` (minus GQLALCHEMY_EXCLUDED_FIELDS), plus
        a unique `modaic_id: str` field.

        !!! warning
            This method is not thread safe. We are actively working on a solution to make it thread safe.

        Args:
            db: The GraphDatabase whose client is attached to the unique `modaic_id` field.

        Returns:
            A Node carrying this object's labels, `modaic_id=self.id`, every field except `id`
            (hidden fields included) and, if already known, the GQLAlchemy `_id`.

        Raises:
            ImportError: If GQLAlchemy is not installed.
            AssertionError: If `db` is not a GraphDatabase instance.
        """
        try:
            import gqlalchemy

            from modaic.databases.graph_database import GraphDatabase
        except ImportError:
            raise ImportError(
                "GQLAlchemy is not installed. Please install the graph extension for modaic with `uv add modaic[graph]`"
            ) from None
        assert isinstance(db, GraphDatabase), (
            f"Expected db to be a modaic.databases.GraphDatabase instance. Got {type(db)} instead."
        )
        cls = self.__class__

        # Dynamically create a GQLAlchemy Node class for the Context if it doesn't exist
        if cls._type not in cls._gqlalchemy_class_registry:
            field_annotations = get_annotations(
                cls,
                exclude=GQLALCHEMY_EXCLUDED_FIELDS,
            )
            field_defaults = get_defaults(cls, exclude=GQLALCHEMY_EXCLUDED_FIELDS)
            gqlalchemy_class = type(
                f"{cls.__name__}Node",
                (gqlalchemy.Node,),
                {
                    "__annotations__": {**field_annotations, "modaic_id": str},
                    "modaic_id": V1Field(unique=True, db=db._client),
                    **field_defaults,
                },
                label=cls._type,
            )
            cls._gqlalchemy_class_registry[cls._type] = gqlalchemy_class

        gqlalchemy_class = cls._gqlalchemy_class_registry[cls._type]
        # Only pass `_id` when the object already exists in the database.
        id_kwargs = {} if self._gqlalchemy_id is None else {"_id": self._gqlalchemy_id}
        return gqlalchemy_class(
            _labels=set(self._labels),
            modaic_id=self.id,
            **id_kwargs,
            **self.model_dump(exclude={"id"}, include_hidden=True),
        )

    @classmethod
    def from_gqlalchemy(cls, gqlalchemy_node: "gqlalchemy.Node") -> "Context":
        """
        Convert a GQLAlchemy Node into a `Context` instance.

        If cls is the Context class itself, it returns an instance of the registered
        subclass that best matches the node's labels (see `_best_subclass`).

        Args:
            gqlalchemy_node: The GQLAlchemy Node to convert.

        Returns:
            The converted Context or Context subclass instance.

        Raises:
            ValueError: If the node lacks this class's label, fails validation, or no
                registered class matches its labels.
        """
        if cls is not Context:
            if cls._type not in gqlalchemy_node._labels:
                raise ValueError(
                    f"Cannot convert GQLAlchemy Node {gqlalchemy_node} to {cls.__name__} because it does not have the label '{cls._type}'"
                )

            try:
                kwargs = {**gqlalchemy_node._properties}
                # The graph stores the Context id under `modaic_id`.
                kwargs["id"] = kwargs.pop("modaic_id")
                context_obj = cls(**kwargs)
                context_obj._gqlalchemy_id = gqlalchemy_node._id
                return context_obj
            except ValidationError as e:
                raise ValueError(
                    f"Failed to convert GQLAlchemy Node {gqlalchemy_node} to {cls.__name__} because it does not have the required fields.\nError: {e}"
                ) from e

        # If cls is Context, we need to find the best subclass of Context that matches the labels of the GQLAlchemy Node.
        best_subclass = Context._best_subclass(frozenset(gqlalchemy_node._labels))
        return best_subclass.from_gqlalchemy(gqlalchemy_node)

    def save(self, db: "GraphDatabase"):
        """
        Save the Context object to the graph database.

        Field values (except `id`) and the GQLAlchemy id are copied back from the saved node.

        !!! warning
            This method is not thread safe. We are actively working on a solution to make it thread safe.
        """
        try:
            from modaic.databases.graph_database import GraphDatabase
        except ImportError:
            raise ImportError(
                "GQLAlchemy is not installed. Please install the graph extension for modaic with `uv add modaic[graph]`"
            ) from None

        assert isinstance(db, GraphDatabase), (
            f"Expected db to be a modaic.databases.GraphDatabase instance. Got {type(db)} instead."
        )

        result = db.save_node(self)

        for k in self.model_dump(exclude={"id"}, include_hidden=True):
            setattr(self, k, getattr(result, k))
        self._gqlalchemy_id = result._id

    def load(self, database: "GraphDatabase"):
        """
        Loads a node from Memgraph.
        If the node._id is not None it fetches the node from Memgraph with that
        internal id.
        If the node has unique fields it fetches the node from Memgraph with
        those unique fields set.
        Otherwise it tries to find any node in Memgraph that has all properties
        set to exactly the same values.
        If no node is found or no properties are set it raises a GQLAlchemyError.

        NOTE(review): not implemented yet; always raises NotImplementedError.
        """
        raise NotImplementedError("Not implemented")

    @staticmethod
    @lru_cache
    def _best_subclass(labels: t.FrozenSet[str]) -> t.Type["Context"]:
        """
        Return the registered class matching one of `labels` with the longest `__mro__`.

        Results are memoized. `__pydantic_init_subclass__` clears the cache whenever a new
        class is registered.

        Raises:
            ValueError: If no label matches a registered class.
        """
        best_subclass = None
        for label in labels:
            if current_subclass := Context._type_registry.get(label):
                # check if the current subclass has more parents than the best subclass
                if best_subclass is None or len(current_subclass.__mro__) > len(best_subclass.__mro__):
                    best_subclass = current_subclass

        if best_subclass is None:
            raise ValueError(f"Cannot find a matching Context class for labels: {labels}")
        return best_subclass

    # TODO: Make iterable-friendly
    def chunk_with(
        self,
        chunk_fn: t.Callable[["Context"], t.Iterable["Context"]],
        kwargs: t.Optional[t.Dict] = None,
    ) -> None:
        """
        Chunk the context object with `chunk_fn(self, **kwargs)`.

        The resulting chunks are stored (available via `chunks`). Each chunk's `parent`
        is set to this object's id.
        """
        if kwargs is None:
            kwargs = {}
        self._chunks = list(chunk_fn(self, **kwargs))
        for chunk in self._chunks:
            chunk.parent = self.id

    def apply_to_chunks(self, apply_fn: t.Callable[["Context"], None], **kwargs):
        """
        Applies apply_fn to each chunk in chunks.

        Args:
            apply_fn: The function to apply to each chunk. Function should take in a Context object and mutate it.
            **kwargs: Additional keyword arguments to pass to apply_fn.
        """
        for chunk in self.chunks:
            apply_fn(chunk, **kwargs)

    @property
    def chunks(self) -> t.List["Context"]:
        """
        Returns the chunks of the context object (empty until `chunk_with` is called).
        """
        return self._chunks

    @property
    def is_hydrated(self) -> bool:
        """
        Returns True if none of the hydrated attributes is None.

        Classes without `__hydrated_attributes__` are always considered hydrated.
        """
        if not hasattr(self, "__hydrated_attributes__"):
            return True
        return all(getattr(self, attr) is not None for attr in self.__hydrated_attributes__)

    @classmethod
    def schema(cls) -> Schema:
        """
        Return (and cache per class) the modaic Schema derived from `model_json_schema()`.
        """
        # Look only at this class's own __dict__. `hasattr` would also find a parent's
        # cached `_schema` and return the wrong schema for subclasses.
        cached = cls.__dict__.get("_schema")
        if cached is None:
            cached = Schema.from_json_schema(cls.model_json_schema())
            cls._schema = cached
        return cached

    @model_serializer(mode="wrap")
    def hidden_serializer(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> dict[str, Any]:
        """
        Drop hidden fields from the dump unless the serialization context sets `include_hidden`.
        """
        dump = handler(self)
        if info.context is None or not info.context.get("include_hidden"):
            for name, field in self.__class__.model_fields.items():
                extra = getattr(field, "json_schema_extra", None)
                # Pydantic also allows json_schema_extra to be a callable. Only a dict can carry "hidden".
                if isinstance(extra, dict) and extra.get("hidden"):
                    dump.pop(name, None)
        return dump

    def model_dump(
        self,
        *,
        mode: str | Literal["json", "python"] = "python",
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        context: Any | None = None,
        by_alias: bool | None = None,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        round_trip: bool = False,
        warnings: bool | Literal["none", "warn", "error"] = True,
        fallback: t.Callable[[Any], Any] | None = None,
        serialize_as_any: bool = False,
        include_hidden: bool = False,
    ) -> Any:
        """
        Override of pydantic's BaseModel.model_dump to allow for showing hidden fields.

        Args:
            include_hidden: Whether to show hidden fields. If True, the class's include-hidden
                serializer is used, which also keeps hidden fields of nested Context objects.
            Other arguments: as for pydantic's BaseModel.model_dump. They are forwarded
                unchanged in both modes.

        Returns:
            The dictionary representation of the model.
        """
        if include_hidden:
            return self._modaic_serializer().to_python(
                self,
                mode=mode,
                include=include,
                exclude=exclude,
                context=context,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
                round_trip=round_trip,
                warnings=warnings,
                fallback=fallback,
                serialize_as_any=serialize_as_any,
            )

        return super().model_dump(
            mode=mode,
            include=include,
            exclude=exclude,
            context=context,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            round_trip=round_trip,
            warnings=warnings,
            fallback=fallback,
            serialize_as_any=serialize_as_any,
        )

    def model_dump_json(
        self,
        *,
        indent: int | None = None,
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        context: Any | None = None,
        by_alias: bool | None = None,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        round_trip: bool = False,
        warnings: bool | Literal["none", "warn", "error"] = True,
        fallback: t.Callable[[Any], Any] | None = None,
        serialize_as_any: bool = False,
        include_hidden: bool = False,
    ) -> bytes | str:
        """
        Override of pydantic's BaseModel.model_dump_json to allow for showing hidden fields.

        With `include_hidden=True` the result comes straight from SchemaSerializer.to_json
        and is therefore `bytes`. Otherwise it is pydantic's `str`. `serialize_as_any` is
        forwarded unchanged in both modes, consistent with `model_dump`.
        """
        if include_hidden:
            return self._modaic_serializer().to_json(
                self,
                indent=indent,
                include=include,
                exclude=exclude,
                context=context,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
                round_trip=round_trip,
                warnings=warnings,
                fallback=fallback,
                serialize_as_any=serialize_as_any,
            )
        return super().model_dump_json(
            indent=indent,
            include=include,
            exclude=exclude,
            context=context,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            round_trip=round_trip,
            warnings=warnings,
            fallback=fallback,
            serialize_as_any=serialize_as_any,
        )
531
+
532
+
533
class RelationMeta(ContextMeta):
    """
    Metaclass that picks the default pydantic `extra` behaviour for Relation classes.

    `Relation` itself always gets `extra="allow"`, so inline relations can carry arbitrary
    properties such as `_type`. Subclasses default to `extra="ignore"` unless their
    `model_config` sets `extra` explicitly.
    """

    def __new__(cls, name, bases, dct, **kwargs):  # noqa: ANN001, ANN003, ANN204
        # BUG: Relation itself is always forced to extra="allow"; users cannot override that.
        if name == "Relation":
            dct["model_config"] = ConfigDict(extra="allow")
        elif "model_config" not in dct:
            dct["model_config"] = ConfigDict(extra="ignore")
        elif dct["model_config"].get("extra", None) is None:
            dct["model_config"]["extra"] = "ignore"

        # Forward class keywords (e.g. `class Likes(Relation, type="LIKES")`, which
        # Context.__pydantic_init_subclass__ consumes) and pydantic's internal metaclass
        # keywords. Without **kwargs these raised TypeError.
        return super().__new__(cls, name, bases, dct, **kwargs)
545
+
546
+
547
+ class Relation(Context, metaclass=RelationMeta):
548
+ """
549
+ Base class for all Relation objects.
550
+ """
551
+
552
+ _start_node: t.Optional[Context] = PrivateAttr(default=None)
553
+ _end_node: t.Optional[Context] = PrivateAttr(default=None)
554
+
555
+ start_node: t.Optional[int] = None
556
+ end_node: t.Optional[int] = None
557
+
558
+ @t.overload
559
+ def __init__(self, start_node: Context | int, end_node: Context | int, **data: Any) -> "Relation": ...
560
+
561
+ @model_validator(mode="wrap")
562
+ @classmethod
563
+ def truncate(cls, data: Any, handler: ValidatorFunctionWrapHandler) -> "Relation":
564
+ """
565
+ Truncates the start_node and end_node to their gqlalchemy ids.
566
+ """
567
+ ids = {}
568
+ objs = {}
569
+ for name in ["start_node", "end_node"]:
570
+ node = data[name]
571
+ if isinstance(node, Context):
572
+ ids[name] = node._gqlalchemy_id
573
+ objs[name] = node
574
+ else:
575
+ ids[name] = node
576
+ objs[name] = None
577
+ data["start_node"] = ids["start_node"]
578
+ data["end_node"] = ids["end_node"]
579
+ self = handler(data)
580
+ self._start_node = objs["start_node"]
581
+ self._end_node = objs["end_node"]
582
+ return self
583
+
584
+ def get_start_node_obj(self, db: "GraphDatabase") -> Context:
585
+ """
586
+ Get the start node object of the relation as a Context object.
587
+ Args:
588
+ db: The GraphDatabase instance to use to fetch the start node.
589
+
590
+ Returns:
591
+ The start node object as a Context object.
592
+ """
593
+ if self._start_node:
594
+ return self._start_node
595
+ else:
596
+ return Context.from_gqlalchemy(
597
+ next(db.execute_and_fetch(f"MATCH (n) WHERE id(n) = {self.start_node} RETURN n"))
598
+ )
599
+
600
+ def get_end_node_obj(self, db: "GraphDatabase") -> Context:
601
+ """
602
+ Get the end node object of the relation as a Context object.
603
+ Args:
604
+ db: The GraphDatabase instance to use to fetch the end node.
605
+
606
+ Returns:
607
+ The end node object as a Context object.
608
+ """
609
+ if self._end_node:
610
+ return self.end_node
611
+ else:
612
+ return Context.from_gqlalchemy(
613
+ next(db.execute_and_fetch(f"MATCH (n) WHERE id(n) = {self.end_node} RETURN n"))
614
+ )
615
+
616
+ @field_validator("start_node", "end_node")
617
+ @classmethod
618
+ def check_node(cls, v: Any) -> Context | int:
619
+ assert isinstance(v, (Context, int)), f"start_node/end_node must be a Context or int, got {type(v)}: {v}"
620
+ assert not isinstance(v, Relation), f"start_node/end_node cannot be a Relation object: {v}"
621
+ return v
622
+
623
+ @model_validator(mode="after")
624
+ def post_init(self) -> "Relation":
625
+ # Sets type for inline declaration of Relation objects
626
+ if type(self) is Relation:
627
+ assert "_type" in self.model_dump(), "Inline declaration of Relation objects must specify the '_type' field"
628
+ self._type = self.model_dump()["_type"]
629
+ return self
630
+
631
+ # other >> self
632
+ def __rrshift__(self, other: Context | int):
633
+ # left_node >> self >> right_node
634
+ self.start_node = other
635
+ return self
636
+
637
+ # self >> other
638
+ def __rshift__(self, other: Context | int):
639
+ # left_node >> self >> right_node
640
+ self.end_node = other
641
+ return self
642
+
643
+ # other << self
644
+ def __rlshift__(self, other: Context | int):
645
+ # left_node << self << right_node
646
+ self.end_node = other
647
+ return self
648
+
649
+ # self << other
650
+ def __lshift__(self, other: Context | int):
651
+ # left_node << self << right_node
652
+ self.start_node = other
653
+ return self
654
+
655
+ def __str__(self):
656
+ """
657
+ Returns a string representation of the Relation object, including all fields and their values.
658
+
659
+ Returns:
660
+ str: String representation of the Relation object with all fields and their values.
661
+ """
662
+ fields_repr = ", ".join(f"{k}={repr(v)}" for k, v in self.model_dump(include_hidden=True).items())
663
+ return f"{self.__class__._type}({fields_repr})"
664
+
665
+ def __repr__(self):
666
+ return self.__str__()
667
+
668
+ def to_gqlalchemy(self, db: "GraphDatabase") -> "gqlalchemy.Relationship":
669
+ """
670
+ Convert the Context object to a GQLAlchemy object.
671
+
672
+ <Warning>Saves the start_node and end_node to the database if they are not already saved.</Warning>
673
+
674
+ <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
675
+ Args:
676
+ db: The GraphDatabase instance to use to save the start_node and end_node if they are not already saved.
677
+
678
+ Returns:
679
+ The GQLAlchemy Relationship object.
680
+
681
+ Raises:
682
+ AssertionError: If db is not a modaic.databases.GraphDatabase instance.
683
+ ImportError: If GQLAlchemy is not installed.
684
+
685
+ """
686
+ try:
687
+ import gqlalchemy
688
+
689
+ from modaic.databases.graph_database import GraphDatabase
690
+ except ImportError:
691
+ raise ImportError(
692
+ "GQLAlchemy is not installed. Please install the graph extension for modaic with `uv add modaic[graph]`"
693
+ ) from None
694
+
695
+ assert isinstance(db, GraphDatabase), (
696
+ f"Expected db to be a modaic.databases.GraphDatabase instance. Got {type(db)} instead."
697
+ )
698
+
699
+ cls = self.__class__
700
+
701
+ # Dynamically create a GQLAlchemy Node class for the Context if it doesn't exist
702
+ if self._type not in cls._gqlalchemy_class_registry:
703
+ ad_hoc_annotations = get_ad_hoc_annotations(self) if cls is Relation else {}
704
+ field_annotations = get_annotations(
705
+ cls,
706
+ exclude=GQLALCHEMY_EXCLUDED_FIELDS + ["start_node", "end_node"],
707
+ )
708
+ field_defaults = get_defaults(
709
+ cls,
710
+ exclude=GQLALCHEMY_EXCLUDED_FIELDS + ["start_node", "end_node"],
711
+ )
712
+ gqlalchemy_class = type(
713
+ f"{cls.__name__}Rel",
714
+ (gqlalchemy.Relationship,),
715
+ {
716
+ "__annotations__": {
717
+ **ad_hoc_annotations,
718
+ **field_annotations,
719
+ "modaic_id": str,
720
+ },
721
+ "modaic_id": V1Field(unique=True, db=db._client),
722
+ **field_defaults,
723
+ },
724
+ type=self._type,
725
+ )
726
+ cls._gqlalchemy_class_registry[self._type] = gqlalchemy_class
727
+
728
+ gqlalchemy_class = cls._gqlalchemy_class_registry[self._type]
729
+
730
+ if self.start_node is not None and self.start_node_gql_id is None:
731
+ self.start_node.save(db)
732
+ if self.end_node is not None and self.end_node_gql_id is None:
733
+ self.end_node.save(db)
734
+
735
+ if self._gqlalchemy_id is None:
736
+ return gqlalchemy.Relationship.parse_obj(
737
+ {
738
+ "_type": self._type,
739
+ "modaic_id": self.id,
740
+ "_start_node_id": self.start_node_gql_id,
741
+ "_end_node_id": self.end_node_gql_id,
742
+ **self.model_dump(
743
+ exclude={"id", "start_node", "end_node", "_type"},
744
+ include_hidden=True,
745
+ ),
746
+ }
747
+ )
748
+ else:
749
+ return gqlalchemy.Relationship.parse_obj(
750
+ {
751
+ "_type": self._type,
752
+ "modaic_id": self.id,
753
+ "_id": self._gqlalchemy_id,
754
+ "_start_node_id": self.start_node_gql_id,
755
+ "_end_node_id": self.end_node_gql_id,
756
+ **self.model_dump(
757
+ exclude={"id", "start_node", "end_node", "_type"},
758
+ include_hidden=True,
759
+ ),
760
+ }
761
+ )
762
+
763
+ @classmethod
764
+ def from_gqlalchemy(cls, gqlalchemy_rel: "gqlalchemy.Relationship") -> "Relation":
765
+ """
766
+ Convert a GQLAlchemy `Relationship` into a `Relation` instance. If `cls` is the `Relation` class itself, it will try to return an instance of a subclass of `Relation` that matches the type of the GQLAlchemy Relationship. If none are found it will fallback to an instance of `Relation` since the `Relation` class allows definiing inline.
767
+ If `cls` is instead a subclass of `Relation`, it will return an instance of that subclass and fail if the properties do not align.
768
+ Args:
769
+ gqlalchemy_obj: The GQLAlchemy Relationship to convert.
770
+
771
+ Raises:
772
+ ValueError: If the GQLAlchemy Relationship does not have the required fields.
773
+ AssertionError: If the GQLAlchemy Relationship does not have the required type.
774
+
775
+ Returns:
776
+ The converted Relation or Relation subclass instance.
777
+ """
778
+ if cls is not Relation:
779
+ assert cls._type == gqlalchemy_rel._type, (
780
+ f"Cannot convert GQLAlchemy Relationship {gqlalchemy_rel} to {cls.__name__} because it does not have {cls.__name__}'s type: '{cls._type}'"
781
+ )
782
+ try:
783
+ kwargs = {**gqlalchemy_rel._properties}
784
+ kwargs["id"] = kwargs.pop("modaic_id")
785
+ kwargs["start_node"] = gqlalchemy_rel._start_node_id
786
+ kwargs["end_node"] = gqlalchemy_rel._end_node_id
787
+ new_relation = cls(**kwargs)
788
+ new_relation._gqlalchemy_id = gqlalchemy_rel._id
789
+ return new_relation
790
+ except ValidationError as e:
791
+ raise ValueError(
792
+ f"Failed to convert GQLAlchemy Relationship {gqlalchemy_rel} to {cls.__name__} because it does not have the required fields.\nError: {e}"
793
+ ) from e
794
+
795
+ # If cls is Relation, we need to find the subclass of Relation that matches the type of the GQLAlchemy Relationship.
796
+ # CAVEAT: Relation is a subclass of Context, so we can just use the same Context._type_registry.
797
+ if subclass := Context._type_registry.get(gqlalchemy_rel._type):
798
+ assert issubclass(subclass, Relation), (
799
+ f"Found Relation subclass with matching type, but cannot convert GQLAlchemy Relationship {gqlalchemy_rel} to {subclass.__name__} because it is not a subclass of Relation"
800
+ )
801
+ return subclass.from_gqlalchemy(gqlalchemy_rel)
802
+ # If no subclass is found, we can just create a new Relation object with the properties of the GQLAlchemy Relationship.
803
+ else:
804
+ kwargs = {**gqlalchemy_rel._properties}
805
+ kwargs["id"] = kwargs.pop("modaic_id")
806
+ kwargs["start_node"] = gqlalchemy_rel._start_node_id
807
+ kwargs["end_node"] = gqlalchemy_rel._end_node_id
808
+ kwargs["_type"] = gqlalchemy_rel._type
809
+ new_relation = cls(**kwargs)
810
+ new_relation._gqlalchemy_id = gqlalchemy_rel._id
811
+ return new_relation
812
+
813
def save(self, db: "GraphDatabase"):
    """
    Persist this Relation to ``db`` and copy the stored values back onto ``self``.

    After ``db.save_relationship(self)`` returns, every field except ``id``, ``start_node`` and
    ``end_node`` (hidden fields included) is overwritten with the same-named attribute of the saved
    result, and ``_gqlalchemy_id`` is set to the result's ``_id``.

    !!! warning
        This method is not thread safe. We are actively working on a solution to make it thread safe.

    Args:
        db: The modaic ``GraphDatabase`` to save to.

    Raises:
        ImportError: If ``modaic.databases.graph_database`` cannot be imported (graph extension missing).
        AssertionError: If ``db`` is not a ``GraphDatabase`` (an ``assert``, so skipped under ``python -O``).
    """
    try:
        from modaic.databases.graph_database import GraphDatabase
    except ImportError:
        raise ImportError(
            "GQLAlchemy is not installed. Please install the graph extension for modaic with `uv add modaic[graph]`"
        ) from None

    assert isinstance(db, GraphDatabase), (
        f"Expected db to be a modaic.databases.GraphDatabase instance. Got {type(db)} instead."
    )

    saved = db.save_relationship(self)
    # Computed once up front, so the setattr calls below cannot affect which names are synced.
    synced_fields = self.model_dump(exclude={"id", "start_node", "end_node"}, include_hidden=True)
    for field_name in synced_fields:
        setattr(self, field_name, getattr(saved, field_name))
    self._gqlalchemy_id = saved._id
835
+
836
def load(self, db: "GraphDatabase"):
    """
    Load this relationship from a GraphDatabase.

    !!! warning
        Not implemented yet: calling this method always raises ``NotImplementedError``.

    Planned behavior: if the relationship's internal id is set, fetch it by that id; otherwise,
    if it has unique fields, fetch it by those; otherwise look for a relationship whose properties
    all match exactly, raising an error if none is found.

    Args:
        db: The graph database to load from (currently unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("Not implemented")
848
+
849
+
850
+ def _cast_type_if_base_model(field_type: t.Type) -> t.Type:
851
+ """
852
+ If field_type is a typing construct, reconstruct it from origin/args.
853
+ If it's a Pydantic BaseModel subclass, map it to `dict`.
854
+ Otherwise return the type itself.
855
+ """
856
+ origin = t.get_origin(field_type)
857
+
858
+ # Non-typing constructs
859
+ if origin is None:
860
+ # Only call issubclass on real classes
861
+ if isinstance(field_type, type) and issubclass(field_type, BaseModel):
862
+ return dict
863
+ return field_type
864
+
865
+ args = t.get_args(field_type)
866
+
867
+ # Annotated[T, m1, m2, ...] # noqa: ERA001
868
+ if origin is t.Annotated:
869
+ base, *meta = args
870
+ # Annotated allows multiple args; pass a tuple to __class_getitem__
871
+ return t.Annotated.__class_getitem__((_cast_type_if_base_model(base), *meta))
872
+
873
+ # Unions: typing.Union[...] or PEP 604 (A | B)
874
+ if origin in (t.Union, UnionType):
875
+ return t.Union[tuple(_cast_type_if_base_model(a) for a in args)]
876
+
877
+ # Literal / Final / ClassVar accept tuple args via typing protocol
878
+ if origin in (t.Literal, t.Final, t.ClassVar):
879
+ return origin.__getitem__([_cast_type_if_base_model(a) for a in args])
880
+
881
+ # Builtin generics (PEP 585): list[T], dict[K, V], set[T], tuple[...]
882
+ if origin in (list, set, frozenset):
883
+ (inner_type,) = args
884
+ return origin[_cast_type_if_base_model(inner_type)]
885
+ if origin is dict:
886
+ key_type, value_type = args
887
+ return dict[_cast_type_if_base_model(key_type), _cast_type_if_base_model(value_type)]
888
+ if origin is tuple:
889
+ # tuple[int, ...] vs tuple[int, str]
890
+ if len(args) == 2 and args[1] is Ellipsis:
891
+ return tuple[_cast_type_if_base_model(args[0]), ...]
892
+ return tuple[tuple([_cast_type_if_base_model(a) for a in args])] # tuple[(A, B, C)]
893
+
894
+ # ABC generics (e.g., Mapping, Sequence, Iterable, etc.) usually accept tuple args
895
+ try:
896
+ return origin.__class_getitem__([_cast_type_if_base_model(a) for a in args])
897
+ except Exception:
898
+ # Last resort: try simple unpack for 1–2 arity generics
899
+ if len(args) == 1:
900
+ return origin[_cast_type_if_base_model(args[0])]
901
+ elif len(args) == 2:
902
+ return origin[
903
+ _cast_type_if_base_model(args[0]),
904
+ _cast_type_if_base_model(args[1]),
905
+ ]
906
+ raise
907
+
908
+
909
def get_annotations(cls: t.Type, exclude: t.Optional[t.List[str]] = None) -> t.Dict[str, t.Type]:
    """
    Collect the annotations that a Context subclass declares or inherits from other Context classes.

    Only classes in ``cls``'s MRO that are ``Context`` subclasses (``Context`` included) contribute.
    Each annotation is passed through ``_cast_type_if_base_model`` so nested Pydantic models become
    ``dict``. When several classes annotate the same name, the class earliest in the MRO wins,
    matching Python's attribute lookup order.

    Args:
        cls: The class to inspect. Classes that are not Context subclasses yield ``{}``.
        exclude: Annotation names to leave out.

    Returns:
        A mapping from annotation name to (cast) type, with base-class entries inserted first.
    """
    if exclude is None:
        exclude = []
    if not issubclass(cls, Context):
        return {}

    annotations: t.Dict[str, t.Type] = {}
    # Walk the MRO from its most basic class towards `cls`, so that classes earlier in the MRO
    # (more derived, or listed first among multiple bases) overwrite entries from later ones.
    for klass in reversed(cls.__mro__):
        if not issubclass(klass, Context):
            continue
        for name, annotation in klass.__annotations__.items():
            if name not in exclude:
                annotations[name] = _cast_type_if_base_model(annotation)
    return annotations
923
+
924
+
925
def _cast_if_base_model(field_default: t.Any) -> t.Any:
    """Dump a Pydantic model instance to a dict; return any other value unchanged."""
    is_model = isinstance(field_default, BaseModel)
    return field_default.model_dump() if is_model else field_default
929
+
930
+
931
def get_defaults(cls: t.Type[Context], exclude: t.Optional[t.List[str]] = None) -> t.Dict[str, t.Any]:
    """
    Build ``V1Field`` objects carrying the defaults of a Context subclass's optional fields.

    Required fields and fields named in ``exclude`` are skipped. Defaults (or the values produced
    by a default factory) that are Pydantic model instances are dumped to dicts. A field's
    ``json_schema_extra`` is forwarded as extra ``V1Field`` keyword arguments when it is a dict.

    Args:
        cls: The Context subclass whose ``model_fields`` are read.
        exclude: Field names to skip.

    Returns:
        A mapping from field name to the object returned by ``V1Field(...)``
        (presumably pydantic's v1-compatible ``Field`` — confirm at the import site).
    """
    excluded = set(exclude or ())
    defaults: t.Dict[str, t.Any] = {}
    for name, v2_field in cls.model_fields.items():
        if name in excluded or v2_field.is_required():
            continue

        kwargs: t.Dict[str, t.Any] = {}
        extra = getattr(v2_field, "json_schema_extra", None)
        # Pydantic v2 also allows json_schema_extra to be a callable that mutates the JSON schema.
        # That form cannot be splatted into keyword arguments, so only dict extras are forwarded.
        if isinstance(extra, dict):
            kwargs.update(extra)

        factory = v2_field.default_factory
        if factory is not None:
            # Bind `factory` as a default argument so each lambda keeps its own factory.
            kwargs["default_factory"] = lambda f=factory: _cast_if_base_model(f())
        else:
            kwargs["default"] = _cast_if_base_model(v2_field.default)

        defaults[name] = V1Field(**kwargs)

    return defaults
952
+
953
+
954
def get_ad_hoc_annotations(rel: Relation) -> t.Dict[str, t.Type]:
    """
    Infer annotations for the fields of a Relation that were declared inline.

    A Relation can be created ad hoc, e.g. ``Relation(x="test", _type="TEST_REL")``. In that case
    ``x`` has no class-level annotation. This function dumps the relation (hidden fields included,
    minus ``GQLALCHEMY_EXCLUDED_FIELDS``, ``start_node`` and ``end_node``) and derives a type from
    each value: ``None`` -> ``t.Any``, a Pydantic model -> ``dict``, anything else -> ``type(value)``.

    Args:
        rel: The Relation object to inspect.

    Returns:
        A mapping from field name to inferred type.
    """
    dumped = rel.model_dump(
        exclude=GQLALCHEMY_EXCLUDED_FIELDS + ["start_node", "end_node"],
        include_hidden=True,
    )

    def _infer(value: t.Any) -> t.Any:
        # Map a dumped value to the annotation used for it.
        if value is None:
            return t.Any
        if isinstance(value, BaseModel):
            return dict
        return type(value)

    return {field_name: _infer(value) for field_name, value in dumped.items()}
977
+
978
+
979
@t.runtime_checkable
class Hydratable(t.Protocol):
    """
    Protocol for objects that can be hydrated from a ``FileStore``.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, Hydratable)`` only checks that
    ``hydrate`` and ``from_file`` attributes exist. It does not check their signatures.
    """

    def hydrate(self, file_store: FileStore) -> None:
        """
        Hydrate this object using ``file_store``.

        NOTE(review): the exact contract (presumably loading file-backed content into the object)
        is defined by implementers.
        """

    @classmethod
    def from_file(
        cls, file: str, file_store: FileStore, type: str, params: t.Optional[dict] = None
    ) -> "Hydratable":
        """
        Load a Hydratable instance from a file.

        Args:
            file: The file to load.
            file_store: The file store to use.
            type: The type of file to expect.
            params: Extra parameters to pass to the constructor.
        """
996
+
997
+
998
if t.TYPE_CHECKING:

    class HydratableContext(Hydratable, Context):
        """Static-typing-only class: a Context that also implements the Hydratable protocol."""
1002
+
1003
+
1004
def is_hydratable(obj: t.Any) -> bool:
    """Tell whether ``obj`` satisfies the Hydratable protocol and is also a Context instance."""
    return all(isinstance(obj, kind) for kind in (Hydratable, Context))
1006
+
1007
+
1008
@t.runtime_checkable
class Embeddable(t.Protocol):
    """
    Structural type for objects that can be embedded.

    Implementers provide ``embedme``, which returns either a string or an image to embed. It may
    be called with no arguments, or with an index name that selects which index to embed for.
    """

    @t.overload
    def embedme(self) -> str | Image.Image:
        """Return the content to embed."""

    @t.overload
    def embedme(self, index: t.Optional[str] = None) -> str | Image.Image:
        """Return the content to embed for the index named ``index``."""
1020
+
1021
+
1022
@t.runtime_checkable
class MultiEmbeddable(t.Protocol):
    """
    A protocol for objects that can be embedded and have multiple embeddings. These objects define
    the embedme function, which can return either a string or an image.
    The embedme function can either take no args, or take an index name as an argument, which will
    be used to select the index to embed for.

    NOTE(review): at runtime ``isinstance`` only checks that an ``embedme`` attribute exists, so this
    protocol currently matches exactly the same objects as ``Embeddable``.
    """

    # A single `@t.overload` with no sibling overloads is rejected by type checkers, so this is a
    # plain protocol method.
    def embedme(self, index: t.Optional[str] = None) -> str | Image.Image:
        """Return the content to embed, optionally for the index named ``index``."""
1031
+
1032
+
1033
def is_embeddable(obj: t.Any) -> bool:
    """Tell whether ``obj`` satisfies the Embeddable protocol and is also a Context instance."""
    return all(isinstance(obj, kind) for kind in (Embeddable, Context))
1035
+
1036
+
1037
def is_multi_embeddable(obj: t.Any) -> bool:
    """Tell whether ``obj`` satisfies the MultiEmbeddable protocol and is also a Context instance."""
    return all(isinstance(obj, kind) for kind in (MultiEmbeddable, Context))
1039
+
1040
+
1041
if t.TYPE_CHECKING:

    # Static-typing-only: a Context that also implements Embeddable (mirrors HydratableContext).
    # Subclassing `t.Protocol` together with the non-protocol `Context` is rejected by type
    # checkers, and the previous bases never mentioned Embeddable at all.
    class EmbeddableContext(Embeddable, Context):
        """Static-typing-only class: a Context that also implements the Embeddable protocol."""
1045
+
1046
+
1047
def _update_exclude(exclude: "IncEx", hidden: t.Set[str]):
    """
    Add the ``hidden`` field names to a pydantic-style ``exclude`` spec, in place.

    Args:
        exclude: Either a set of field names or a dict mapping field names to exclusion specs.
            It is mutated.
        hidden: Field names to exclude.

    Returns:
        The same ``exclude`` object, now also excluding every name in ``hidden``. The previous
        implementation returned the result of ``.update()``, which is always ``None``.
    """
    if isinstance(exclude, set):
        exclude.update(hidden)
    else:  # NOTE: if not a set, it's a dict
        # In pydantic's dict form, `True` means "exclude this field entirely".
        exclude.update(dict.fromkeys(hidden, True))
    return exclude
1052
+
1053
+
1054
def _dump_hidden_recursive(obj: t.Any):
    """
    Recursively serialize ``obj``, including the hidden fields of Context objects.

    Context instances are dumped with ``model_dump(include_hidden=True)``. Other Pydantic models
    are dumped with a plain ``model_dump()``. Lists and dicts are walked element by element, and
    any other value is returned unchanged.
    """
    # Context is checked first: it appears to be a BaseModel subclass (it exposes
    # model_fields/model_dump), so a BaseModel check first would shadow this branch.
    if isinstance(obj, Context):
        return obj.model_dump(include_hidden=True)
    if isinstance(obj, BaseModel):
        # Plain pydantic models' model_dump() has no `include_hidden` keyword (TypeError otherwise).
        return obj.model_dump()
    if isinstance(obj, list):
        return [_dump_hidden_recursive(item) for item in obj]
    if isinstance(obj, dict):
        return {k: _dump_hidden_recursive(v) for k, v in obj.items()}
    return obj