modaic 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of modaic might be problematic.
- modaic/__init__.py +25 -0
- modaic/agents/rag_agent.py +33 -0
- modaic/agents/registry.py +84 -0
- modaic/auto_agent.py +228 -0
- modaic/context/__init__.py +34 -0
- modaic/context/base.py +1064 -0
- modaic/context/dtype_mapping.py +25 -0
- modaic/context/table.py +585 -0
- modaic/context/text.py +94 -0
- modaic/databases/__init__.py +35 -0
- modaic/databases/graph_database.py +269 -0
- modaic/databases/sql_database.py +355 -0
- modaic/databases/vector_database/__init__.py +12 -0
- modaic/databases/vector_database/benchmarks/baseline.py +123 -0
- modaic/databases/vector_database/benchmarks/common.py +48 -0
- modaic/databases/vector_database/benchmarks/fork.py +132 -0
- modaic/databases/vector_database/benchmarks/threaded.py +119 -0
- modaic/databases/vector_database/vector_database.py +722 -0
- modaic/databases/vector_database/vendors/milvus.py +408 -0
- modaic/databases/vector_database/vendors/mongodb.py +0 -0
- modaic/databases/vector_database/vendors/pinecone.py +0 -0
- modaic/databases/vector_database/vendors/qdrant.py +1 -0
- modaic/exceptions.py +38 -0
- modaic/hub.py +305 -0
- modaic/indexing.py +127 -0
- modaic/module_utils.py +341 -0
- modaic/observability.py +275 -0
- modaic/precompiled.py +429 -0
- modaic/query_language.py +321 -0
- modaic/storage/__init__.py +3 -0
- modaic/storage/file_store.py +239 -0
- modaic/storage/pickle_store.py +25 -0
- modaic/types.py +287 -0
- modaic/utils.py +21 -0
- modaic-0.1.0.dist-info/METADATA +281 -0
- modaic-0.1.0.dist-info/RECORD +39 -0
- modaic-0.1.0.dist-info/WHEEL +5 -0
- modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
- modaic-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,722 @@
from dataclasses import dataclass, field
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    Generic,
    Iterable,
    List,
    Literal,
    NamedTuple,
    NoReturn,
    Optional,
    Protocol,
    Tuple,
    Type,
    TypeVar,
    overload,
    runtime_checkable,
)

import immutables
import numpy as np
from aenum import AutoNumberEnum
from langchain_core.structured_query import Visitor
from more_itertools import peekable
from PIL import Image
from tqdm.auto import tqdm

from ... import Embedder
from ...context.base import Context, Embeddable
from ...observability import Trackable, track_modaic_obj
from ...query_language import Condition, parse_modaic_filter

DEFAULT_INDEX_NAME = "default"


class SearchResult(NamedTuple):
    id: str
    score: float
    context: Context


# TODO: Add casting logic
class VectorType(AutoNumberEnum):
    _init_ = "supported_libraries"
    # name | supported_libraries
    FLOAT = ["milvus", "qdrant", "mongo", "pinecone"]  # float32
    FLOAT16 = ["milvus", "qdrant"]
    BFLOAT16 = ["milvus"]
    INT8 = ["milvus", "mongo"]
    UINT8 = ["qdrant"]
    BINARY = ["milvus", "mongo"]
    MULTI = ["qdrant"]
    FLOAT_SPARSE = ["milvus", "qdrant", "pinecone"]
    FLOAT16_SPARSE = ["qdrant"]
    INT8_SPARSE = ["qdrant"]


class IndexType(AutoNumberEnum):
    """
    The ANN or ENN algorithm to use for an index. IndexType.DEFAULT is IndexType.HNSW for most vector databases (milvus, qdrant, mongo).
    """

    _init_ = "supported_libraries"
    # name | supported_libraries
    DEFAULT = ["milvus", "qdrant", "mongo", "pinecone"]
    HNSW = ["milvus", "qdrant", "mongo"]
    FLAT = ["milvus", "redis"]
    IVF_FLAT = ["milvus"]
    IVF_SQ8 = ["milvus"]
    IVF_PQ = ["milvus"]
    IVF_RABITQ = ["milvus"]
    GPU_IVF_FLAT = ["milvus"]
    GPU_IVF_PQ = ["milvus"]
    DISKANN = ["milvus"]
    BIN_FLAT = ["milvus"]
    BIN_IVF_FLAT = ["milvus"]
    MINHASH_LSH = ["milvus"]
    SPARSE_INVERTED_INDEX = ["milvus"]
    INVERTED = ["milvus"]
    BITMAP = ["milvus"]
    TRIE = ["milvus"]
    STL_SORT = ["milvus"]


class Metric(AutoNumberEnum):
    _init_ = "supported_libraries"  # mapping of the library that supports the metric and the name the library uses to refer to it
    EUCLIDEAN = {
        "milvus": "L2",
        "qdrant": "Euclid",
        "mongo": "euclidean",
        "pinecone": "euclidean",
    }
    DOT_PRODUCT = {
        "milvus": "IP",
        "qdrant": "Dot",
        "mongo": "dotProduct",
        "pinecone": "dotproduct",
    }
    COSINE = {
        "milvus": "COSINE",
        "qdrant": "Cosine",
        "mongo": "cosine",
        "pinecone": "cosine",
    }
    MANHATTAN = {
        "qdrant": "Manhattan",
        "mongo": "manhattan",
    }
    HAMMING = {"milvus": "HAMMING"}
    JACCARD = {"milvus": "JACCARD"}
    MHJACCARD = {"milvus": "MHJACCARD"}
    BM25 = {"milvus": "BM25"}
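Note: the enums above rely on aenum's `_init_` hook, so each member carries its auto-assigned number plus a `supported_libraries` attribute holding the assigned list or dict. An illustrative sketch of reading that metadata; the `vendor_metric` helper is hypothetical, not part of the package:

```python
# Each member keeps its vendor mapping alongside the auto-assigned enum value.
print(Metric.COSINE.supported_libraries["qdrant"])   # -> "Cosine"
print(VectorType.FLOAT.supported_libraries)          # -> ["milvus", "qdrant", "mongo", "pinecone"]


# Hypothetical helper for translating a Metric into a vendor's own metric name.
def vendor_metric(metric: Metric, vendor: str) -> str:
    if vendor not in metric.supported_libraries:
        raise ValueError(f"{metric.name} is not supported by {vendor}")
    return metric.supported_libraries[vendor]
```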

# TODO Make this support non-vector indexes like full-text search maybe?
@dataclass
class IndexConfig:
    """
    Configuration for a VDB index.

    Args:
        vector_type: The type of vector used by the index.
        index_type: The type of index to use. see IndexType for available options.
        metric: The metric to use for the index. see Metric for available options.
        embedder: The embedder to use for the index. If not provided, will use the VectorDatabase's embedder.
    """

    vector_type: Optional[VectorType] = VectorType.FLOAT
    index_type: Optional[IndexType] = IndexType.DEFAULT
    metric: Optional[Metric] = Metric.COSINE
    embedder: Optional[Embedder] = None


@dataclass
class CollectionConfig:
    payload_class: Type[Context]
    indexes: Dict[str, IndexConfig] = field(default_factory=dict)
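Note: a collection is described by a `CollectionConfig` that pairs a payload `Context` class with named `IndexConfig` entries. An illustrative sketch of building one by hand; the index names and metric choices are assumptions, not package defaults:

```python
# A dense index that falls back to the VectorDatabase's default embedder.
dense_cfg = IndexConfig(
    vector_type=VectorType.FLOAT,
    index_type=IndexType.HNSW,
    metric=Metric.COSINE,
    embedder=None,
)
# A sparse index, as it might be configured for a hybrid collection.
sparse_cfg = IndexConfig(vector_type=VectorType.FLOAT_SPARSE, metric=Metric.DOT_PRODUCT)

collection = CollectionConfig(
    payload_class=Context,  # normally a concrete Context subclass
    indexes={"dense": dense_cfg, "sparse": sparse_cfg},
)
```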

TBackend = TypeVar("TBackend", bound="VectorDBBackend")


class VectorDatabase(Generic[TBackend], Trackable):
    ext: "VDBExtensions[TBackend]"
    collections: Dict[str, CollectionConfig]
    default_payload_class: Optional[Type[Context]] = None
    default_embedder: Optional[Embedder] = None

    def __init__(
        self,
        backend: TBackend,
        embedder: Optional[Embedder] = None,
        payload_class: Optional[Type[Context]] = None,
        **kwargs,
    ):
        """
        Initialize a vanilla vector database. This is a base class for all vector databases. If you need more functionality from a specific vector database, you should use a specific subclass.

        Args:
            backend: The backend that connects the vector database to a specific vendor
            embedder: The embedder to use for the vector database
            payload_class: The default context class for collections
            **kwargs: Additional keyword arguments
        """

        Trackable.__init__(self, **kwargs)
        if isinstance(payload_class, type) and not issubclass(payload_class, Context):
            raise TypeError(f"payload_class must be a subclass of Context, got {payload_class}")

        self.ext = VDBExtensions(backend)
        self.collections = {}
        self.default_payload_class = payload_class
        self.default_embedder = embedder

    def drop_collection(self, collection_name: str):
        self.ext.backend.drop_collection(collection_name)

    # TODO: Signature looks good but some things about how the class will need to change to support this.
    def load_collection(
        self,
        collection_name: str,
        payload_class: Type[Context],
        embedder: Optional[Embedder | Dict[str, Embedder]] = None,
    ):
        """
        Load collection information into the vector database.

        Args:
            collection_name: The name of the collection to load
            payload_class: The context class of the context objects stored in the collection
            embedder: The embedder (or mapping of index name to embedder) to use for the collection
        """
        if not issubclass(payload_class, Context):
            raise TypeError(f"payload_class must be a subclass of Context, got {payload_class}")
        if not self.ext.backend.has_collection(collection_name):
            raise ValueError(f"Collection {collection_name} does not exist in the vector database")

        index_cfg = IndexConfig(
            vector_type=None,
            index_type=None,
            metric=None,
            embedder=embedder or self.default_embedder,
        )
        self.collections[collection_name] = CollectionConfig(
            indexes={DEFAULT_INDEX_NAME: index_cfg},
            payload_class=payload_class,
        )

    def create_collection(
        self,
        collection_name: str,
        payload_class: Type[Context],
        metric: Metric = Metric.COSINE,
        index_type: IndexType = IndexType.DEFAULT,
        vector_type: VectorType = VectorType.FLOAT,
        embedder: Optional[Embedder] = None,
        exists_behavior: Literal["fail", "replace"] = "replace",
    ):
        """
        Create a collection in the vector database.

        Args:
            collection_name: The name of the collection to create
            payload_class: The class of the context objects stored in the collection
            metric: The metric for the collection's default index
            index_type: The index algorithm for the collection's default index
            vector_type: The vector type for the collection's default index
            embedder: The embedder for the collection. Defaults to the VectorDatabase's embedder.
            exists_behavior: The behavior when the collection already exists
        """
        if not issubclass(payload_class, Context):
            raise TypeError(f"payload_class must be a subclass of Context, got {payload_class}")
        collection_exists = self.ext.backend.has_collection(collection_name)

        if collection_exists:
            if exists_behavior == "fail":
                raise ValueError(
                    f"Collection '{collection_name}' already exists and exists_behavior is set to 'fail', if you would like to load the collection instead use load_collection()"
                )
            elif exists_behavior == "replace":
                self.ext.backend.drop_collection(collection_name)

        index_cfg = IndexConfig(
            vector_type=vector_type,
            index_type=index_type,
            metric=metric,
            embedder=embedder or self.default_embedder,
        )
        self.collections[collection_name] = CollectionConfig(
            indexes={DEFAULT_INDEX_NAME: index_cfg},
            payload_class=payload_class,
        )

        self.ext.backend.create_collection(collection_name, payload_class, index_cfg)
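Note: an illustrative sketch of the create-vs-load flow above; `MilvusBackend`, `Text`, and `my_embedder` are assumed names, not confirmed APIs from this release:

```python
# Hypothetical vendor backend and embedder wiring.
backend = MilvusBackend(uri="milvus_demo.db")
vdb = VectorDatabase(backend, embedder=my_embedder)

if not vdb.has_collection("recipes"):
    # First run: create the collection and its default index.
    vdb.create_collection(
        "recipes",
        payload_class=Text,
        metric=Metric.COSINE,
        index_type=IndexType.HNSW,
    )
else:
    # Later runs: attach to the existing collection instead of recreating it.
    vdb.load_collection("recipes", payload_class=Text)
```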

    def list_collections(self) -> List[str]:
        return self.ext.backend.list_collections()

    def benchmark_add_records(
        self,
        collection_name: str,
        func: Callable,
        records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
        batch_size: Optional[int] = None,
        embedme_scope: Literal["auto", "context", "index"] = "auto",
    ):
        func(self, collection_name, records, batch_size, embedme_scope)

    def add_records(
        self,
        collection_name: str,
        records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
        batch_size: Optional[int] = None,
        embedme_scope: Literal["auto", "context", "index"] = "auto",
        tqdm_total: Optional[int] = None,
    ):
        """
        Add items to a collection in the vector database.
        Uses the Context's get_embed_context() method and the embedder to create embeddings.

        Args:
            collection_name: The name of the collection to add records to
            records: The records to add to the collection
            batch_size: Optional batch size for processing records
        """
        if not records:
            return

        # NOTE: Make embedmes compatible with the ext's hybrid search function
        if embedme_scope == "auto":
            if _items_have_multiple_embedmes(records):
                embedme_scope = "index"
            else:
                embedme_scope = "context"

        if embedme_scope == "index":
            embedmes: Dict[str, List[str | Image.Image]] = {
                k: [] for k in self.collections[collection_name].indexes.keys()
            }
        else:
            # CAVEAT: We make embedmes a dict with None as opposed to a list so we don't have to type check it
            embedmes: Dict[None, List[str | Image.Image]] = {None: []}

        serialized_contexts = []
        # TODO: add multi-processing/multi-threading here, just ensure that the backend is thread-safe. Maybe we add a class level parameter to check if the vendor is thread-safe. Embedding will still need to happen on a single thread
        for item in tqdm(
            records,
            desc="Adding records to vector database",
            disable=tqdm_total is None,
            total=tqdm_total or 0,
        ):
            cntxt = _add_ebedmes_and_return_context(embedmes, item)
            serialized_contexts.append(cntxt)

            if batch_size is not None and len(serialized_contexts) == batch_size:
                self._embed_and_add_records(collection_name, embedmes, serialized_contexts)
                if embedme_scope == "index":
                    embedmes = {k: [] for k in embedmes.keys()}
                else:
                    embedmes = {None: []}
                serialized_contexts = []

        if embedmes:
            self._embed_and_add_records(collection_name, embedmes, serialized_contexts)
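Note: `add_records` accepts either `Embeddable` contexts, which supply their own `embedme()`, or `(embedme, context)` tuples. An illustrative sketch, reusing the assumed `Text` class from the previous note:

```python
# Illustrative only; Text is an assumed Context subclass from modaic.context.
docs = [Text(text="apple pie recipe ..."), Text(text="banana bread recipe ...")]

# Form 1: Embeddable contexts; each context supplies its own embedme() text.
vdb.add_records("recipes", docs, batch_size=64)

# Form 2: (embedme, context) tuples; embed one string or image, store a different context as payload.
pairs = [("a short summary of the apple pie recipe", docs[0])]
vdb.add_records("recipes", pairs)
```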

    def has_collection(self, collection_name: str) -> bool:
        """
        Check if a collection exists in the vector database.

        Args:
            collection_name: The name of the collection to check

        Returns:
            True if the collection exists, False otherwise
        """
        return self.ext.backend.has_collection(collection_name)

    def _embed_and_add_records(
        self,
        collection_name: str,
        embedmes: Dict[str, List[str | Image.Image]] | Dict[None, List[str | Image.Image]],
        contexts: List[Context],
    ):
        # TODO: could add functionality for multiple embedmes per context (e.g. you want to embed both an image and a text description of an image)
        all_embeddings = {}
        if collection_name not in self.collections:
            raise ValueError(
                f"Collection {collection_name} not found in VectorDatabase's collections. Please use VectorDatabase.create_collection() to create a collection first. Alternatively, you can use VectorDatabase.load_collection() to add records to an existing collection."
            )
        try:
            # NOTE: get embeddings for each index
            for index_name, index_config in self.collections[collection_name].indexes.items():
                # If dict is {None: embeddings} then we use the same embeddings for all indexes. Otherwise lookup embeddings for each index
                key = None if None in embedmes else index_name
                embeddings = index_config.embedder(embedmes[key])

                # NOTE: Ensure embeddings is a 2D array (DSPy returns 1D for single strings, 2D for lists)
                if embeddings.ndim == 1:
                    embeddings = embeddings.reshape(1, -1)

                all_embeddings[index_name] = embeddings
        except Exception as e:
            raise ValueError(f"Failed to create embeddings for index: {index_name}") from e

        data_to_insert: List[immutables.Map[str, np.ndarray]] = []
        # FIXME Probably should add type checking to ensure context matches schema, not sure how to do this efficiently
        for i, item in enumerate(contexts):
            embedding_map: dict[str, np.ndarray] = {}
            for index_name, embeddings in all_embeddings.items():
                embedding_map[index_name] = embeddings[i]

            # Create a record with embedding and validated metadata
            record = self.ext.backend.create_record(embedding_map, item)

            data_to_insert.append(record)

        self.ext.backend.add_records(collection_name, data_to_insert)
        del data_to_insert

    # TODO: maybe better way of handling telling the integration module which Context class to return
    # TODO: add support for storage contexts. Where the payload is stored in a context and is mapped to the data via id
    # TODO: add support for multiple searches at once (i.e. accept a list of vectors)
    @track_modaic_obj
    def search(
        self,
        collection_name: str,
        query: str | Image.Image | List[str] | List[Image.Image],
        k: int = 10,
        filter: Optional[Condition] = None,
    ) -> List[List[SearchResult]]:
        """
        Retrieve records from the vector database.
        Returns one list of SearchResult tuples per query.
        SearchResult is a NamedTuple with the following fields:
        - id: The id of the record
        - score: The score of the record
        - context: The context object (unhydrated if it's hydratable)

        Args:
            collection_name: The name of the collection to search
            query: The text or image query (or list of queries) to search with
            k: The number of results to return
            filter: Optional filter to apply to the search

        Returns:
            results: Lists of SearchResult tuples matching each query.

        Example:
            ```python
            results = vdb.search("collection 1", "How do I bake an apple pie?", k=10)
            print(results[0][0].context)
            >>> <Context: Text(text="apple pie recipe is 2 cups of flour, 1 cup of sugar, 1 cup of milk, 1 cup of eggs, 1 cup of butter")>
            ```
        """
        if filter is not None:
            filter = parse_modaic_filter(self.ext.backend.mql_translator, filter)
        indexes = self.collections[collection_name].indexes
        if len(indexes) > 1:
            raise ValueError(
                f"Collection {collection_name} has multiple indexes, please use VectorDatabase.ext.hybrid_search with an index_name"
            )
        query = [query] if isinstance(query, (str, Image.Image)) else query
        vectors = indexes[DEFAULT_INDEX_NAME].embedder(query)
        vectors = [vectors] if vectors.ndim == 1 else list(vectors)
        # CAVEAT: Allowing index_name to be None for libraries that don't care. Integration module should handle this behavior on their own.
        return self.ext.backend.search(
            collection_name,
            vectors,
            self.collections[collection_name].payload_class,
            k,
            filter,
        )

    def get_records(self, collection_name: str, record_id: List[str]) -> List[Context]:
        """
        Get records from the vector database.

        Args:
            collection_name: The name of the collection
            record_id: The IDs of the records to retrieve

        Returns:
            The serialized context records.
        """
        return self.ext.backend.get_records(collection_name, self.collections[collection_name].payload_class, record_id)

    def hybrid_search(
        self,
        collection_name: str,
        vectors: List[np.ndarray],
        index_names: List[str],
        k: int = 10,
    ) -> List[Context]:
        """
        Hybrid search the vector database.
        """
        raise NotImplementedError("hybrid_search is not implemented for this vector database")

    def query(self, query: str, k: int = 10, filter: Optional[dict] = None) -> List[Context]:
        """
        Query the vector database.

        Args:
            query: The query string
            k: The number of results to return
            filter: Optional filter to apply to the query

        Returns:
            List of serialized contexts matching the query.
        """
        raise NotImplementedError("query is not implemented for this vector database")

    def set_embedder(self, embedder: Embedder):
        self.default_embedder = embedder

    def upsert_records(self, collection_name: str, records: Iterable[Context]):
        """
        Upsert records into the vector database.
        """
        raise NotImplementedError("upsert_record is not implemented for this vector database")

    def delete_records(self, collection_name: str, context_ids: Iterable[str]):
        """
        Delete records from the vector database.
        """
        raise NotImplementedError("delete_record is not implemented for this vector database")


@runtime_checkable
class VectorDBBackend(Protocol):
    _name: ClassVar[str]
    _client: Any
    mql_translator: Visitor

    def __init__(self, *args, **kwargs) -> Any: ...
    def create_record(self, embedding_map: Dict[str, np.ndarray], context: Context) -> Any: ...
    def add_records(self, collection_name: str, records: List[Any]) -> None: ...
    def drop_collection(self, collection_name: str) -> None: ...
    def create_collection(
        self,
        collection_name: str,
        payload_class: Type[Context],
        index: IndexConfig = IndexConfig(),  # noqa: B008
    ) -> None: ...
    def list_collections(self) -> List[str]: ...
    def has_collection(self, collection_name: str) -> bool: ...
    def search(
        self,
        collection_name: str,
        vectors: List[np.ndarray],
        payload_class: Type[Context],
        k: int,
        filter: Optional[Any],  # Any: the backend's native filtering language
    ) -> List[List[SearchResult]]: ...
    def get_records(
        self, collection_name: str, payload_class: Type[Context], record_ids: List[str]
    ) -> List[Context]: ...
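Note: `VectorDBBackend` is a runtime-checkable Protocol, so any object with these methods can drive `VectorDatabase`. Below is a minimal in-memory sketch that satisfies the protocol for the single default index; it ignores index and metric configuration, ranks by cosine similarity, and its record layout is an assumption, not how the shipped vendor backends work:

```python
import numpy as np
from typing import Dict, List


class InMemoryBackend:
    _name = "in_memory"

    def __init__(self):
        self._client = None
        self.mql_translator = None  # no native filter language in this sketch
        self._collections: Dict[str, List[dict]] = {}

    def create_record(self, embedding_map, context):
        return {"embeddings": embedding_map, "context": context}

    def add_records(self, collection_name, records):
        self._collections[collection_name].extend(records)

    def drop_collection(self, collection_name):
        self._collections.pop(collection_name, None)

    def create_collection(self, collection_name, payload_class, index=None):
        self._collections[collection_name] = []

    def list_collections(self):
        return list(self._collections)

    def has_collection(self, collection_name):
        return collection_name in self._collections

    def search(self, collection_name, vectors, payload_class, k, filter=None):
        results = []
        for q in vectors:
            scored = []
            for i, rec in enumerate(self._collections[collection_name]):
                v = rec["embeddings"][DEFAULT_INDEX_NAME]
                score = float(np.dot(q, v) / (np.linalg.norm(q) * np.linalg.norm(v)))
                scored.append(SearchResult(id=str(i), score=score, context=rec["context"]))
            scored.sort(key=lambda r: r.score, reverse=True)
            results.append(scored[:k])
        return results

    def get_records(self, collection_name, payload_class, record_ids):
        rows = self._collections[collection_name]
        return [rows[int(i)]["context"] for i in record_ids]
```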

COMMON_EXT = {
    "reindex",
}


@runtime_checkable
class SupportsBM25(VectorDBBackend, Protocol):
    def bm25_search(
        self,
        collection_name: str,
        query: str,
        k: int,
    ) -> List[Context]: ...
    def create_bm25_collection(
        self,
        collection_name: str,
        payload_class: Type[Context],
        exists_behavior: Literal["fail", "replace"] = "replace",
    ) -> List[Context]: ...
    def load_bm25_collection(
        self,
        collection_name: str,
        payload_class: Type[Context],
    ) -> List[Context]: ...


@runtime_checkable
class SupportsHybridSearch(VectorDBBackend, Protocol):
    def hybrid_search(
        self,
        collection_name: str,
        vectors: Dict[str, np.ndarray],
        k: int,
    ) -> List[Context]: ...
    def create_hybrid_collection(
        self,
        collection_name: str,
        payload_class: Type[Context],
        indexes: Dict[str, IndexConfig],
        exists_behavior: Literal["fail", "replace"] = "replace",
    ) -> List[Context]: ...
    def load_hybrid_collection(
        self,
        collection_name: str,
        payload_class: Type[Context],
        indexes: Dict[str, IndexConfig],
    ) -> List[Context]: ...

class VDBExtensions(Generic[TBackend]):
    backend: TBackend

    def __init__(self, backend: TBackend):
        self.backend = backend

    @property
    def client(self) -> Any:
        return self.backend._client

    # Use constrained TypeVars so intersection Protocols bind correctly
    TSupportsBM25 = TypeVar("TSupportsBM25", bound=SupportsBM25)
    TSupportsHybridSearch = TypeVar("TSupportsHybridSearch", bound=SupportsHybridSearch)

    @overload
    def hybrid_search(
        self: "VDBExtensions[TSupportsHybridSearch]",
        collection_name: str,
        vectors: Dict[str, np.ndarray],
        k: int,
    ) -> List[Context]: ...

    @overload
    def hybrid_search(
        self: "VDBExtensions[TBackend]",
        collection_name: str,
        vectors: Dict[str, np.ndarray],
        k: int,
    ) -> NoReturn: ...

    def hybrid_search(
        self: "VDBExtensions[TBackend]",
        collection_name: str,
        vectors: Dict[str, np.ndarray],
        k: int,
    ):
        if not isinstance(self.backend, SupportsHybridSearch):
            raise AttributeError(
                f"""{self.backend._name} does not support the function hybrid_search.

                Available functions: {self.available()}
                """
            )
        return self.backend.hybrid_search(collection_name, vectors, k)

    @overload
    def bm25_search(
        self: "VDBExtensions[TSupportsBM25]",
        collection_name: str,
        vectors: List[np.ndarray],
        index_names: List[str],
        k: int,
    ) -> List[Context]: ...

    @overload
    def bm25_search(
        self: "VDBExtensions[TBackend]",
        collection_name: str,
        vectors: List[np.ndarray],
        index_names: List[str],
        k: int,
    ) -> NoReturn: ...

    def bm25_search(
        self: "VDBExtensions[TBackend]",
        collection_name: str,
        vectors: List[np.ndarray],
        index_names: List[str],
        k: int,
    ) -> List[Context]:
        if not isinstance(self.backend, SupportsBM25):
            raise AttributeError(
                f"""{self.backend._name} does not support the function bm25_search.

                Available functions: {self.available()}
                """
            )
        return self.backend.bm25_search(collection_name, vectors, index_names, k)

    @overload
    def create_hybrid_collection(
        self: "VDBExtensions[TSupportsHybridSearch]",
        query: str,
        k: int,
        filter: Optional[dict],
    ) -> List[Context]: ...

    @overload
    def create_hybrid_collection(
        self: "VDBExtensions[TBackend]", query: str, k: int, filter: Optional[dict]
    ) -> NoReturn: ...

    def create_hybrid_collection(
        self: "VDBExtensions[TBackend]", query: str, k: int, filter: Optional[dict]
    ) -> List[Context]:
        if not isinstance(self.backend, SupportsHybridSearch):
            raise AttributeError(
                f"""{self.backend._name} does not support the function create_hybrid_collection.

                Available functions: {self.available()}
                """
            )
        return self.backend.query(query, k, filter)

    def has(self, op: str) -> bool:
        fn = getattr(self, op, None)
        return callable(fn)

    def available(self) -> List[str]:
        return [op for op in COMMON_EXT if self.has(op)]
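Note: `VDBExtensions` gates vendor-specific operations behind `isinstance` checks against the `Supports*` protocols, while the `@overload` pairs narrow the return type to `NoReturn` for backends that lack the capability. An illustrative sketch of the runtime behavior, reusing the hypothetical `InMemoryBackend` and `my_embedder` from the earlier notes:

```python
vdb = VectorDatabase(InMemoryBackend(), embedder=my_embedder)

# Vendor escape hatch: the raw client object exposed by the backend.
raw_client = vdb.ext.client

# Backends that do not implement SupportsBM25 / SupportsHybridSearch fail at runtime,
# and the NoReturn overload flags the call for type checkers.
try:
    vdb.ext.bm25_search("recipes", vectors=[], index_names=[], k=5)
except AttributeError as err:
    print(err)  # "in_memory does not support the function bm25_search. ..."
```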

def _add_ebedmes_and_return_context(
    embedmes: Dict[str | None, List[str | Image.Image]],
    item: Embeddable | Tuple[str | Image.Image, Context],
) -> Context:
    """
    Adds all embedmes to the embedmes dictionary and returns the context.
    """
    # Fast type check for tuple
    if type(item) is tuple:
        embedme = item[0]
        for index in embedmes.keys():
            embedmes[index].append(embedme)
        return item[1]
    elif _has_multiple_embedmes(item):
        # CAVEAT: Context objects that implement Embeddable protocol and take in an index name as a parameter also accept None as the default index.
        for index in embedmes.keys():
            embedmes[index].append(item.embedme(index))
        return item
    else:
        for index in embedmes.keys():
            embedmes[index].append(item.embedme())
        return item


def _has_multiple_embedmes(
    item: Embeddable,
):
    """
    Check if the item has multiple embedmes.
    """
    return item.embedme.__code__.co_argcount == 2


def _items_have_multiple_embedmes(
    records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
):
    """
    Check if the first record has multiple embedmes.
    """
    p = peekable(records)
    first_item = p.peek()
    if isinstance(first_item, Embeddable) and _has_multiple_embedmes(first_item):
        return True
    return False
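Note: the arity check above distinguishes contexts whose `embedme()` takes only `self` from those that also accept an index name. An illustrative sketch; these `Context` subclasses and their fields are assumptions:

```python
class SingleEmbedme(Context):
    def embedme(self):                   # co_argcount == 1 (just self): one embedme per context
        return self.text


class PerIndexEmbedme(Context):
    def embedme(self, index_name=None):  # co_argcount == 2: a different embedme per index
        return self.caption if index_name == "caption" else self.ocr_text


# add_records(embedme_scope="auto") peeks at the first record: PerIndexEmbedme instances
# switch the scope to "index", while SingleEmbedme instances keep it at "context".
```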