modaic-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modaic might be problematic.

Files changed (39)
  1. modaic/__init__.py +25 -0
  2. modaic/agents/rag_agent.py +33 -0
  3. modaic/agents/registry.py +84 -0
  4. modaic/auto_agent.py +228 -0
  5. modaic/context/__init__.py +34 -0
  6. modaic/context/base.py +1064 -0
  7. modaic/context/dtype_mapping.py +25 -0
  8. modaic/context/table.py +585 -0
  9. modaic/context/text.py +94 -0
  10. modaic/databases/__init__.py +35 -0
  11. modaic/databases/graph_database.py +269 -0
  12. modaic/databases/sql_database.py +355 -0
  13. modaic/databases/vector_database/__init__.py +12 -0
  14. modaic/databases/vector_database/benchmarks/baseline.py +123 -0
  15. modaic/databases/vector_database/benchmarks/common.py +48 -0
  16. modaic/databases/vector_database/benchmarks/fork.py +132 -0
  17. modaic/databases/vector_database/benchmarks/threaded.py +119 -0
  18. modaic/databases/vector_database/vector_database.py +722 -0
  19. modaic/databases/vector_database/vendors/milvus.py +408 -0
  20. modaic/databases/vector_database/vendors/mongodb.py +0 -0
  21. modaic/databases/vector_database/vendors/pinecone.py +0 -0
  22. modaic/databases/vector_database/vendors/qdrant.py +1 -0
  23. modaic/exceptions.py +38 -0
  24. modaic/hub.py +305 -0
  25. modaic/indexing.py +127 -0
  26. modaic/module_utils.py +341 -0
  27. modaic/observability.py +275 -0
  28. modaic/precompiled.py +429 -0
  29. modaic/query_language.py +321 -0
  30. modaic/storage/__init__.py +3 -0
  31. modaic/storage/file_store.py +239 -0
  32. modaic/storage/pickle_store.py +25 -0
  33. modaic/types.py +287 -0
  34. modaic/utils.py +21 -0
  35. modaic-0.1.0.dist-info/METADATA +281 -0
  36. modaic-0.1.0.dist-info/RECORD +39 -0
  37. modaic-0.1.0.dist-info/WHEEL +5 -0
  38. modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
  39. modaic-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,722 @@
+ from dataclasses import dataclass, field
+ from typing import (
+     Any,
+     Callable,
+     ClassVar,
+     Dict,
+     Generic,
+     Iterable,
+     List,
+     Literal,
+     NamedTuple,
+     NoReturn,
+     Optional,
+     Protocol,
+     Tuple,
+     Type,
+     TypeVar,
+     overload,
+     runtime_checkable,
+ )
+
+ import immutables
+ import numpy as np
+ from aenum import AutoNumberEnum
+ from langchain_core.structured_query import Visitor
+ from more_itertools import peekable
+ from PIL import Image
+ from tqdm.auto import tqdm
+
+ from ... import Embedder
+ from ...context.base import Context, Embeddable
+ from ...observability import Trackable, track_modaic_obj
+ from ...query_language import Condition, parse_modaic_filter
+
+ DEFAULT_INDEX_NAME = "default"
+
+
+ class SearchResult(NamedTuple):
+     id: str
+     score: float
+     context: Context
+
+
+ # TODO: Add casting logic
+ class VectorType(AutoNumberEnum):
+     _init_ = "supported_libraries"
+     # name | supported_libraries
+     FLOAT = ["milvus", "qdrant", "mongo", "pinecone"]  # float32
+     FLOAT16 = ["milvus", "qdrant"]
+     BFLOAT16 = ["milvus"]
+     INT8 = ["milvus", "mongo"]
+     UINT8 = ["qdrant"]
+     BINARY = ["milvus", "mongo"]
+     MULTI = ["qdrant"]
+     FLOAT_SPARSE = ["milvus", "qdrant", "pinecone"]
+     FLOAT16_SPARSE = ["qdrant"]
+     INT8_SPARSE = ["qdrant"]
+
+
+ class IndexType(AutoNumberEnum):
+     """
+     The ANN or ENN algorithm to use for an index. IndexType.DEFAULT is IndexType.HNSW for most vector databases (milvus, qdrant, mongo).
+     """
+
+     _init_ = "supported_libraries"
+     # name | supported_libraries
+     DEFAULT = ["milvus", "qdrant", "mongo", "pinecone"]
+     HNSW = ["milvus", "qdrant", "mongo"]
+     FLAT = ["milvus", "redis"]
+     IVF_FLAT = ["milvus"]
+     IVF_SQ8 = ["milvus"]
+     IVF_PQ = ["milvus"]
+     IVF_RABITQ = ["milvus"]
+     GPU_IVF_FLAT = ["milvus"]
+     GPU_IVF_PQ = ["milvus"]
+     DISKANN = ["milvus"]
+     BIN_FLAT = ["milvus"]
+     BIN_IVF_FLAT = ["milvus"]
+     MINHASH_LSH = ["milvus"]
+     SPARSE_INVERTED_INDEX = ["milvus"]
+     INVERTED = ["milvus"]
+     BITMAP = ["milvus"]
+     TRIE = ["milvus"]
+     STL_SORT = ["milvus"]
+
+
+ class Metric(AutoNumberEnum):
+     _init_ = "supported_libraries"  # maps each library that supports the metric to the name that library uses for it
+     EUCLIDEAN = {
+         "milvus": "L2",
+         "qdrant": "Euclid",
+         "mongo": "euclidean",
+         "pinecone": "euclidean",
+     }
+     DOT_PRODUCT = {
+         "milvus": "IP",
+         "qdrant": "Dot",
+         "mongo": "dotProduct",
+         "pinecone": "dotproduct",
+     }
+     COSINE = {
+         "milvus": "COSINE",
+         "qdrant": "Cosine",
+         "mongo": "cosine",
+         "pinecone": "cosine",
+     }
+     MANHATTAN = {
+         "qdrant": "Manhattan",
+         "mongo": "manhattan",
+     }
+     HAMMING = {"milvus": "HAMMING"}
+     JACCARD = {"milvus": "JACCARD"}
+     MHJACCARD = {"milvus": "MHJACCARD"}
+     BM25 = {"milvus": "BM25"}
+
+
+ # TODO: Make this support non-vector indexes like full-text search maybe?
+ @dataclass
+ class IndexConfig:
+     """
+     Configuration for a VDB index.
+
+     Args:
+         vector_type: The type of vector used by the index.
+         index_type: The type of index to use. See IndexType for available options.
+         metric: The metric to use for the index. See Metric for available options.
+         embedder: The embedder to use for the index. If not provided, will use the VectorDatabase's embedder.
+     """
+
+     vector_type: Optional[VectorType] = VectorType.FLOAT
+     index_type: Optional[IndexType] = IndexType.DEFAULT
+     metric: Optional[Metric] = Metric.COSINE
+     embedder: Optional[Embedder] = None
+
+
+ @dataclass
+ class CollectionConfig:
+     payload_class: Type[Context]
+     indexes: Dict[str, IndexConfig] = field(default_factory=dict)
+
+
+ TBackend = TypeVar("TBackend", bound="VectorDBBackend")
+
+
+ class VectorDatabase(Generic[TBackend], Trackable):
+     ext: "VDBExtensions[TBackend]"
+     collections: Dict[str, CollectionConfig]
+     default_payload_class: Optional[Type[Context]] = None
+     default_embedder: Optional[Embedder] = None
+
+     def __init__(
+         self,
+         backend: TBackend,
+         embedder: Optional[Embedder] = None,
+         payload_class: Optional[Type[Context]] = None,
+         **kwargs,
+     ):
+         """
+         Initialize a vanilla vector database. This is the base class for all vector databases; if you need functionality specific to a particular vector database, use the corresponding subclass.
+
+         Args:
+             backend: The backend implementation for the vector database
+             embedder: The embedder to use for the vector database
+             payload_class: The default context class for collections
+             **kwargs: Additional keyword arguments
+         """
+
+         Trackable.__init__(self, **kwargs)
+         if isinstance(payload_class, type) and not issubclass(payload_class, Context):
+             raise TypeError(f"payload_class must be a subclass of Context, got {payload_class}")
+
+         self.ext = VDBExtensions(backend)
+         self.collections = {}
+         self.default_payload_class = payload_class
+         self.default_embedder = embedder
+
+     def drop_collection(self, collection_name: str):
+         self.ext.backend.drop_collection(collection_name)
+
+     # TODO: The signature looks good, but some parts of the class will need to change to support this.
+     def load_collection(
+         self,
+         collection_name: str,
+         payload_class: Type[Context],
+         embedder: Optional[Embedder | Dict[str, Embedder]] = None,
+     ):
+         """
+         Load collection information into the vector database.
+
+         Args:
+             collection_name: The name of the collection to load
+             payload_class: The context class of the context objects stored in the collection
+             embedder: The embedder (or mapping of index name to embedder) to use for the collection
+         """
+         if not issubclass(payload_class, Context):
+             raise TypeError(f"payload_class must be a subclass of Context, got {payload_class}")
+         if not self.ext.backend.has_collection(collection_name):
+             raise ValueError(f"Collection {collection_name} does not exist in the vector database")
+
+         index_cfg = IndexConfig(
+             vector_type=None,
+             index_type=None,
+             metric=None,
+             embedder=embedder or self.default_embedder,
+         )
+         self.collections[collection_name] = CollectionConfig(
+             indexes={DEFAULT_INDEX_NAME: index_cfg},
+             payload_class=payload_class,
+         )
+
+     def create_collection(
+         self,
+         collection_name: str,
+         payload_class: Type[Context],
+         metric: Metric = Metric.COSINE,
+         index_type: IndexType = IndexType.DEFAULT,
+         vector_type: VectorType = VectorType.FLOAT,
+         embedder: Optional[Embedder] = None,
+         exists_behavior: Literal["fail", "replace"] = "replace",
+     ):
+         """
+         Create a collection in the vector database.
+
+         Args:
+             collection_name: The name of the collection to create
+             payload_class: The class of the context objects stored in the collection
+             metric: The distance metric to use for the default index
+             index_type: The index type to use for the default index
+             vector_type: The vector type to use for the default index
+             embedder: The embedder to use for the collection. Defaults to the VectorDatabase's embedder.
+             exists_behavior: The behavior when the collection already exists
+         """
+         if not issubclass(payload_class, Context):
+             raise TypeError(f"payload_class must be a subclass of Context, got {payload_class}")
+         collection_exists = self.ext.backend.has_collection(collection_name)
+
+         if collection_exists:
+             if exists_behavior == "fail":
+                 raise ValueError(
+                     f"Collection '{collection_name}' already exists and exists_behavior is set to 'fail'. If you would like to load the collection instead, use load_collection()."
+                 )
+             elif exists_behavior == "replace":
+                 self.ext.backend.drop_collection(collection_name)
+
+         index_cfg = IndexConfig(
+             vector_type=vector_type,
+             index_type=index_type,
+             metric=metric,
+             embedder=embedder or self.default_embedder,
+         )
+         self.collections[collection_name] = CollectionConfig(
+             indexes={DEFAULT_INDEX_NAME: index_cfg},
+             payload_class=payload_class,
+         )
+
+         self.ext.backend.create_collection(collection_name, payload_class, index_cfg)
+
+     def list_collections(self) -> List[str]:
+         return self.ext.backend.list_collections()
+
+     def benchmark_add_records(
+         self,
+         collection_name: str,
+         func: Callable,
+         records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
+         batch_size: Optional[int] = None,
+         embedme_scope: Literal["auto", "context", "index"] = "auto",
+     ):
+         func(self, collection_name, records, batch_size, embedme_scope)
+
+     def add_records(
+         self,
+         collection_name: str,
+         records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
+         batch_size: Optional[int] = None,
+         embedme_scope: Literal["auto", "context", "index"] = "auto",
+         tqdm_total: Optional[int] = None,
+     ):
+         """
+         Add items to a collection in the vector database.
+         Uses each record's embedme and the configured embedder to create embeddings.
+
+         Args:
+             collection_name: The name of the collection to add records to
+             records: The records to add to the collection
+             batch_size: Optional batch size for processing records
+             embedme_scope: Whether embedmes are shared per context or computed per index ("auto" infers this from the records)
+             tqdm_total: Optional total used to display a progress bar
+         """
+         if not records:
+             return
+
+         # NOTE: Make embedmes compatible with the ext's hybrid search function
+         if embedme_scope == "auto":
+             if _items_have_multiple_embedmes(records):
+                 embedme_scope = "index"
+             else:
+                 embedme_scope = "context"
+
+         if embedme_scope == "index":
+             embedmes: Dict[str, List[str | Image.Image]] = {
+                 k: [] for k in self.collections[collection_name].indexes.keys()
+             }
+         else:
+             # CAVEAT: We make embedmes a dict keyed by None rather than a list so we don't have to type check it
+             embedmes: Dict[None, List[str | Image.Image]] = {None: []}
+
+         serialized_contexts = []
+         # TODO: add multi-processing/multi-threading here, just ensure that the backend is thread-safe. Maybe we add a class level parameter to check if the vendor is thread-safe. Embedding will still need to happen on a single thread
+         for item in tqdm(
+             records,
+             desc="Adding records to vector database",
+             disable=tqdm_total is None,
+             total=tqdm_total or 0,
+         ):
+             cntxt = _add_embedmes_and_return_context(embedmes, item)
+             serialized_contexts.append(cntxt)
+
+             if batch_size is not None and len(serialized_contexts) == batch_size:
+                 self._embed_and_add_records(collection_name, embedmes, serialized_contexts)
+                 if embedme_scope == "index":
+                     embedmes = {k: [] for k in embedmes.keys()}
+                 else:
+                     embedmes = {None: []}
+                 serialized_contexts = []
+
+         if serialized_contexts:
+             self._embed_and_add_records(collection_name, embedmes, serialized_contexts)
+
+     def has_collection(self, collection_name: str) -> bool:
+         """
+         Check if a collection exists in the vector database.
+
+         Args:
+             collection_name: The name of the collection to check
+
+         Returns:
+             True if the collection exists, False otherwise
+         """
+         return self.ext.backend.has_collection(collection_name)
+
+     def _embed_and_add_records(
+         self,
+         collection_name: str,
+         embedmes: Dict[str, List[str | Image.Image]] | Dict[None, List[str | Image.Image]],
+         contexts: List[Context],
+     ):
+         # TODO: could add functionality for multiple embedmes per context (e.g. you want to embed both an image and a text description of an image)
+         all_embeddings = {}
+         if collection_name not in self.collections:
+             raise ValueError(
+                 f"Collection {collection_name} not found in VectorDatabase's collections. Please use VectorDatabase.create_collection() to create a collection first. Alternatively, you can use VectorDatabase.load_collection() to add records to an existing collection."
+             )
+         try:
+             # NOTE: get embeddings for each index
+             for index_name, index_config in self.collections[collection_name].indexes.items():
+                 # If the dict is {None: embedmes} we use the same embedmes for all indexes. Otherwise look up the embedmes for each index.
+                 key = None if None in embedmes else index_name
+                 embeddings = index_config.embedder(embedmes[key])
+
+                 # NOTE: Ensure embeddings is a 2D array (DSPy returns 1D for single strings, 2D for lists)
+                 if embeddings.ndim == 1:
+                     embeddings = embeddings.reshape(1, -1)
+
+                 all_embeddings[index_name] = embeddings
+         except Exception as e:
+             raise ValueError(f"Failed to create embeddings for index: {index_name}") from e
+
+         data_to_insert: List[immutables.Map[str, np.ndarray]] = []
+         # FIXME: Probably should add type checking to ensure each context matches the schema; not sure how to do this efficiently
+         for i, item in enumerate(contexts):
+             embedding_map: dict[str, np.ndarray] = {}
+             for index_name, embeddings in all_embeddings.items():
+                 embedding_map[index_name] = embeddings[i]
+
+             # Create a record with embedding and validated metadata
+             record = self.ext.backend.create_record(embedding_map, item)
+
+             data_to_insert.append(record)
+
+         self.ext.backend.add_records(collection_name, data_to_insert)
+         del data_to_insert
+
+     # TODO: maybe better way of handling telling the integration module which Context class to return
+     # TODO: add support for storage contexts, where the payload is stored in a context and is mapped to the data via id
+     # TODO: add support for multiple searches at once (i.e. accept a list of vectors)
+     @track_modaic_obj
+     def search(
+         self,
+         collection_name: str,
+         query: str | Image.Image | List[str] | List[Image.Image],
+         k: int = 10,
+         filter: Optional[Condition] = None,
+     ) -> List[List[SearchResult]]:
+         """
+         Retrieve records from the vector database.
+         Returns a list of lists of SearchResult named tuples (one inner list per query).
+         SearchResult is a NamedTuple with the following fields:
+         - id: The id of the record
+         - score: The similarity score of the record
+         - context: The context object (unhydrated if it's hydratable)
+
+         Args:
+             collection_name: The name of the collection to search
+             query: The query (text or image, or a list of either) to search with
+             k: The number of results to return
+             filter: Optional filter to apply to the search
+
+         Returns:
+             results: List of SearchResult lists matching the search.
+
+         Example:
+             ```python
+             results = vdb.search("collection 1", "How do I bake an apple pie?", k=10)
+             print(results[0][0].context)
+             >>> <Context: Text(text="apple pie recipe is 2 cups of flour, 1 cup of sugar, 1 cup of milk, 1 cup of eggs, 1 cup of butter")>
+             ```
+
+         """
+         if filter is not None:
+             filter = parse_modaic_filter(self.ext.backend.mql_translator, filter)
+         indexes = self.collections[collection_name].indexes
+         if len(indexes) > 1:
+             raise ValueError(
+                 f"Collection {collection_name} has multiple indexes, please use VectorDatabase.ext.hybrid_search with an index_name"
+             )
+         query = [query] if isinstance(query, (str, Image.Image)) else query
+         vectors = indexes[DEFAULT_INDEX_NAME].embedder(query)
+         vectors = [vectors] if vectors.ndim == 1 else list(vectors)
+         # CAVEAT: Allowing index_name to be None for libraries that don't care. Integration modules should handle this behavior on their own.
+         return self.ext.backend.search(
+             collection_name,
+             vectors,
+             self.collections[collection_name].payload_class,
+             k,
+             filter,
+         )
+
+     def get_records(self, collection_name: str, record_id: List[str]) -> List[Context]:
+         """
+         Get records from the vector database.
+
+         Args:
+             collection_name: The name of the collection
+             record_id: The IDs of the records to retrieve
+
+         Returns:
+             The serialized context records.
+         """
+         return self.ext.backend.get_records(
+             collection_name, self.collections[collection_name].payload_class, record_id
+         )
+
+     def hybrid_search(
+         self,
+         collection_name: str,
+         vectors: List[np.ndarray],
+         index_names: List[str],
+         k: int = 10,
+     ) -> List[Context]:
+         """
+         Hybrid search the vector database.
+         """
+         raise NotImplementedError("hybrid_search is not implemented for this vector database")
+
+     def query(self, query: str, k: int = 10, filter: Optional[dict] = None) -> List[Context]:
+         """
+         Query the vector database.
+
+         Args:
+             query: The query string
+             k: The number of results to return
+             filter: Optional filter to apply to the query
+
+         Returns:
+             List of serialized contexts matching the query.
+         """
+         raise NotImplementedError("query is not implemented for this vector database")
+
+     def set_embedder(self, embedder: Embedder):
+         self.default_embedder = embedder
+
+     def upsert_records(self, collection_name: str, records: Iterable[Context]):
+         """
+         Upsert records into the vector database.
+         """
+         raise NotImplementedError("upsert_records is not implemented for this vector database")
+
+     def delete_records(self, collection_name: str, context_ids: Iterable[str]):
+         """
+         Delete records from the vector database.
+         """
+         raise NotImplementedError("delete_records is not implemented for this vector database")
+
+
+ @runtime_checkable
+ class VectorDBBackend(Protocol):
+     _name: ClassVar[str]
+     _client: Any
+     mql_translator: Visitor
+
+     def __init__(self, *args, **kwargs) -> Any: ...
+     def create_record(self, embedding_map: Dict[str, np.ndarray], context: Context) -> Any: ...
+     def add_records(self, collection_name: str, records: List[Any]) -> None: ...
+     def drop_collection(self, collection_name: str) -> None: ...
+     def create_collection(
+         self,
+         collection_name: str,
+         payload_class: Type[Context],
+         index: IndexConfig = IndexConfig(),  # noqa: B008
+     ) -> None: ...
+     def list_collections(self) -> List[str]: ...
+     def has_collection(self, collection_name: str) -> bool: ...
+     def search(
+         self,
+         collection_name: str,
+         vectors: List[np.ndarray],
+         payload_class: Type[Context],
+         k: int,
+         filter: Optional[Any],  # filter expressed in the backend's native filtering language
+     ) -> List[List[SearchResult]]: ...
+     def get_records(
+         self, collection_name: str, payload_class: Type[Context], record_ids: List[str]
+     ) -> List[Context]: ...
+
+
+ COMMON_EXT = {
+     "reindex",
+ }
+
+
+ @runtime_checkable
+ class SupportsBM25(VectorDBBackend, Protocol):
+     def bm25_search(
+         self,
+         collection_name: str,
+         query: str,
+         k: int,
+     ) -> List[Context]: ...
+     def create_bm25_collection(
+         self,
+         collection_name: str,
+         payload_class: Type[Context],
+         exists_behavior: Literal["fail", "replace"] = "replace",
+     ) -> List[Context]: ...
+     def load_bm25_collection(
+         self,
+         collection_name: str,
+         payload_class: Type[Context],
+     ) -> List[Context]: ...
+
+
+ @runtime_checkable
+ class SupportsHybridSearch(VectorDBBackend, Protocol):
+     def hybrid_search(
+         self,
+         collection_name: str,
+         vectors: Dict[str, np.ndarray],
+         k: int,
+     ) -> List[Context]: ...
+     def create_hybrid_collection(
+         self,
+         collection_name: str,
+         payload_class: Type[Context],
+         indexes: Dict[str, IndexConfig],
+         exists_behavior: Literal["fail", "replace"] = "replace",
+     ) -> List[Context]: ...
+     def load_hybrid_collection(
+         self,
+         collection_name: str,
+         payload_class: Type[Context],
+         indexes: Dict[str, IndexConfig],
+     ) -> List[Context]: ...
+
+
+ class VDBExtensions(Generic[TBackend]):
+     backend: TBackend
+
+     def __init__(self, backend: TBackend):
+         self.backend = backend
+
+     @property
+     def client(self) -> Any:
+         return self.backend._client
+
+     # Use constrained TypeVars so intersection Protocols bind correctly
+     TSupportsBM25 = TypeVar("TSupportsBM25", bound=SupportsBM25)
+     TSupportsHybridSearch = TypeVar("TSupportsHybridSearch", bound=SupportsHybridSearch)
+
+     @overload
+     def hybrid_search(
+         self: "VDBExtensions[TSupportsHybridSearch]",
+         collection_name: str,
+         vectors: Dict[str, np.ndarray],
+         k: int,
+     ) -> List[Context]: ...
+
+     @overload
+     def hybrid_search(
+         self: "VDBExtensions[TBackend]",
+         collection_name: str,
+         vectors: Dict[str, np.ndarray],
+         k: int,
+     ) -> NoReturn: ...
+
+     def hybrid_search(
+         self: "VDBExtensions[TBackend]",
+         collection_name: str,
+         vectors: Dict[str, np.ndarray],
+         k: int,
+     ):
+         if not isinstance(self.backend, SupportsHybridSearch):
+             raise AttributeError(
+                 f"""{self.backend._name} does not support the function hybrid_search.
+
+                 Available functions: {self.available()}
+                 """
+             )
+         return self.backend.hybrid_search(collection_name, vectors, k)
+
+     @overload
+     def bm25_search(
+         self: "VDBExtensions[TSupportsBM25]",
+         collection_name: str,
+         query: str,
+         k: int,
+     ) -> List[Context]: ...
+
+     @overload
+     def bm25_search(
+         self: "VDBExtensions[TBackend]",
+         collection_name: str,
+         query: str,
+         k: int,
+     ) -> NoReturn: ...
+
+     def bm25_search(
+         self: "VDBExtensions[TBackend]",
+         collection_name: str,
+         query: str,
+         k: int,
+     ) -> List[Context]:
+         if not isinstance(self.backend, SupportsBM25):
+             raise AttributeError(
+                 f"""{self.backend._name} does not support the function bm25_search.
+
+                 Available functions: {self.available()}
+                 """
+             )
+         return self.backend.bm25_search(collection_name, query, k)
+
+     @overload
+     def create_hybrid_collection(
+         self: "VDBExtensions[TSupportsHybridSearch]",
+         collection_name: str,
+         payload_class: Type[Context],
+         indexes: Dict[str, IndexConfig],
+         exists_behavior: Literal["fail", "replace"] = "replace",
+     ) -> List[Context]: ...
+
+     @overload
+     def create_hybrid_collection(
+         self: "VDBExtensions[TBackend]",
+         collection_name: str,
+         payload_class: Type[Context],
+         indexes: Dict[str, IndexConfig],
+         exists_behavior: Literal["fail", "replace"] = "replace",
+     ) -> NoReturn: ...
+
+     def create_hybrid_collection(
+         self: "VDBExtensions[TBackend]",
+         collection_name: str,
+         payload_class: Type[Context],
+         indexes: Dict[str, IndexConfig],
+         exists_behavior: Literal["fail", "replace"] = "replace",
+     ) -> List[Context]:
+         if not isinstance(self.backend, SupportsHybridSearch):
+             raise AttributeError(
+                 f"""{self.backend._name} does not support the function create_hybrid_collection.
+
+                 Available functions: {self.available()}
+                 """
+             )
+         return self.backend.create_hybrid_collection(collection_name, payload_class, indexes, exists_behavior)
+
+     def has(self, op: str) -> bool:
+         fn = getattr(self, op, None)
+         return callable(fn)
+
+     def available(self) -> List[str]:
+         return [op for op in COMMON_EXT if self.has(op)]
+
+
+ def _add_embedmes_and_return_context(
+     embedmes: Dict[str | None, List[str | Image.Image]],
+     item: Embeddable | Tuple[str | Image.Image, Context],
+ ) -> Context:
+     """
+     Adds all embedmes to the embedmes dictionary and returns the context.
+     """
+     # Fast type check for tuple
+     if type(item) is tuple:
+         embedme = item[0]
+         for index in embedmes.keys():
+             embedmes[index].append(embedme)
+         return item[1]
+     elif _has_multiple_embedmes(item):
+         # CAVEAT: Context objects that implement the Embeddable protocol and take an index name as a parameter also accept None as the default index.
+         for index in embedmes.keys():
+             embedmes[index].append(item.embedme(index))
+         return item
+     else:
+         for index in embedmes.keys():
+             embedmes[index].append(item.embedme())
+         return item
+
+
+ def _has_multiple_embedmes(
+     item: Embeddable,
+ ):
+     """
+     Check if the item has multiple embedmes.
+     """
+     # embedme is a bound method, so co_argcount == 2 means it is defined as embedme(self, index_name)
+     return item.embedme.__code__.co_argcount == 2
+
+
+ def _items_have_multiple_embedmes(
+     records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
+ ):
+     """
+     Check if the first record has multiple embedmes.
+     """
+     p = peekable(records)
+     first_item = p.peek()
+     if isinstance(first_item, Embeddable) and _has_multiple_embedmes(first_item):
+         return True
+     return False
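
Based only on the API surface added in this file, a typical usage flow would look like the sketch below. This is an illustration, not part of the package: the backend class (`MilvusBackend`), its constructor arguments, the `Text` payload class, the embedder, the collection name, and the import path (which mirrors the file path and may differ from the public re-exports in `modaic/databases/__init__.py`) are all assumptions; only `VectorDatabase`, `create_collection`, `add_records`, and `search` come from the diff above.

```python
# Illustrative sketch only -- names marked below are assumptions, not package API.
from modaic.databases.vector_database.vector_database import (  # path assumed from the file layout
    IndexType,
    Metric,
    VectorDatabase,
    VectorType,
)

# MilvusBackend, my_embedder, and Text are hypothetical stand-ins for a concrete
# VectorDBBackend implementation, an Embedder, and a Context subclass.
vdb = VectorDatabase(backend=MilvusBackend(...), embedder=my_embedder)

# Register a collection; the default index is configured from metric/index_type/vector_type.
vdb.create_collection(
    "recipes",
    payload_class=Text,
    metric=Metric.COSINE,
    index_type=IndexType.DEFAULT,
    vector_type=VectorType.FLOAT,
)

# Records may be Embeddable contexts or (embedme, context) tuples.
vdb.add_records(
    "recipes",
    [("apple pie recipe ...", Text(text="apple pie recipe ..."))],
    batch_size=64,
)

# search() embeds the query with the default index's embedder and returns
# one list of SearchResult named tuples per query.
results = vdb.search("recipes", "How do I bake an apple pie?", k=5)
for hit in results[0]:
    print(hit.id, hit.score, hit.context)
```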
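
The `ext` attribute gates vendor-specific operations behind runtime Protocol checks: `VDBExtensions.hybrid_search`, for example, raises `AttributeError` unless the backend satisfies `SupportsHybridSearch`. A rough sketch of how a caller might branch on that capability follows; the `vdb` object, collection name, index name, and vector dimension are carried over from the previous (assumed) sketch.

```python
import numpy as np

from modaic.databases.vector_database.vector_database import SupportsHybridSearch  # path assumed

# Assumed index name and dimension for illustration.
query_vectors = {"text": np.random.rand(768).astype(np.float32)}

if isinstance(vdb.ext.backend, SupportsHybridSearch):
    # Dispatches to the backend's hybrid_search implementation.
    hits = vdb.ext.hybrid_search("recipes", query_vectors, k=5)
else:
    # Backends without the protocol raise AttributeError (listing available extensions),
    # so fall back to the single-index search path.
    hits = [r.context for r in vdb.search("recipes", "apple pie", k=5)[0]]
```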