modaic-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of modaic has been flagged as potentially problematic.
- modaic/__init__.py +25 -0
- modaic/agents/rag_agent.py +33 -0
- modaic/agents/registry.py +84 -0
- modaic/auto_agent.py +228 -0
- modaic/context/__init__.py +34 -0
- modaic/context/base.py +1064 -0
- modaic/context/dtype_mapping.py +25 -0
- modaic/context/table.py +585 -0
- modaic/context/text.py +94 -0
- modaic/databases/__init__.py +35 -0
- modaic/databases/graph_database.py +269 -0
- modaic/databases/sql_database.py +355 -0
- modaic/databases/vector_database/__init__.py +12 -0
- modaic/databases/vector_database/benchmarks/baseline.py +123 -0
- modaic/databases/vector_database/benchmarks/common.py +48 -0
- modaic/databases/vector_database/benchmarks/fork.py +132 -0
- modaic/databases/vector_database/benchmarks/threaded.py +119 -0
- modaic/databases/vector_database/vector_database.py +722 -0
- modaic/databases/vector_database/vendors/milvus.py +408 -0
- modaic/databases/vector_database/vendors/mongodb.py +0 -0
- modaic/databases/vector_database/vendors/pinecone.py +0 -0
- modaic/databases/vector_database/vendors/qdrant.py +1 -0
- modaic/exceptions.py +38 -0
- modaic/hub.py +305 -0
- modaic/indexing.py +127 -0
- modaic/module_utils.py +341 -0
- modaic/observability.py +275 -0
- modaic/precompiled.py +429 -0
- modaic/query_language.py +321 -0
- modaic/storage/__init__.py +3 -0
- modaic/storage/file_store.py +239 -0
- modaic/storage/pickle_store.py +25 -0
- modaic/types.py +287 -0
- modaic/utils.py +21 -0
- modaic-0.1.0.dist-info/METADATA +281 -0
- modaic-0.1.0.dist-info/RECORD +39 -0
- modaic-0.1.0.dist-info/WHEEL +5 -0
- modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
- modaic-0.1.0.dist-info/top_level.txt +1 -0
modaic/databases/vector_database/benchmarks/baseline.py
@@ -0,0 +1,123 @@
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple

import immutables
import numpy as np
from PIL import Image
from tqdm import tqdm

from modaic.context.base import Context, Embeddable

from .common import _add_item_embedme, _items_have_multiple_embedmes

if TYPE_CHECKING:
    from modaic.databases.vector_database import VectorDatabase


def has_collection(vdb: "VectorDatabase", collection_name: str) -> bool:
    """
    Check if a collection exists in the vector database.

    Args:
        collection_name: The name of the collection to check

    Returns:
        True if the collection exists, False otherwise
    """
    return vdb.ext.backend.has_collection(collection_name)


def _embed_and_add_records(
    vdb: "VectorDatabase",
    collection_name: str,
    embedmes: Dict[str, List[str | Image.Image]] | Dict[None, List[str | Image.Image]],
    contexts: List[Context],
):
    # TODO: could add functionality for multiple embedmes per context (e.g. you want to embed both an image and a text description of an image)
    all_embeddings = {}
    if collection_name not in vdb.collections:
        raise ValueError(
            f"Collection {collection_name} not found in VectorDatabase's indexes, Please use VectorDatabase.create_collection() to create a collection first. Alternatively, you can use VectorDatabase.load_collection() to add records to an existing collection."
        )
    try:
        first_index = next(iter(vdb.collections[collection_name].indexes.keys()))
        # NOTE: get embeddings for each index
        for index_name, index_config in vdb.collections[collection_name].indexes.items():
            embeddings = index_config.embedder(embedmes)
            # NOTE: Ensure embeddings is a 2D array (DSPy returns 1D for single strings, 2D for lists)
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)
            # NOTE: If index_name is None use the only index for the collection
            all_embeddings[index_name or first_index] = embeddings
    except Exception as e:
        raise ValueError(f"Failed to create embeddings for index: {index_name}") from e

    data_to_insert: List[immutables.Map[str, np.ndarray]] = []
    # FIXME Probably should add type checking to ensure context matches schema, not sure how to do this efficiently
    for i, item in enumerate(contexts):
        embedding_map: dict[str, np.ndarray] = {}
        for index_name, embedding in all_embeddings.items():
            embedding_map[index_name] = embedding[i]

        # Create a record with embedding and validated metadata
        record = vdb.ext.backend.create_record(embedding_map, item)
        data_to_insert.append(record)

    vdb.ext.backend.add_records(collection_name, data_to_insert)
    del data_to_insert


def add_records(
    vdb: "VectorDatabase",
    collection_name: str,
    records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
    batch_size: Optional[int] = None,
    embedme_scope: Literal["auto", "context", "index"] = "auto",
):
    """
    Add items to a collection in the vector database.
    Uses the Context's get_embed_context() method and the embedder to create embeddings.

    Args:
        collection_name: The name of the collection to add records to
        records: The records to add to the collection
        batch_size: Optional batch size for processing records
    """
    if not records:
        return

    # NOTE: Make embedmes compatible with the ext's hybrid search function
    if embedme_scope == "auto":
        if _items_have_multiple_embedmes(records):
            embedme_scope = "index"
        else:
            embedme_scope = "context"

    if embedme_scope == "index":
        embedmes: Dict[str, List[str | Image.Image]] = {k: [] for k in vdb.collections[collection_name].indexes.keys()}
    else:
        # CAVEAT: We make embedmes a dict with None as opposed to a list so we don't have to type check it
        embedmes: Dict[None, List[str | Image.Image]] = {None: []}

    serialized_contexts = []
    # TODO: add multi-processing/multi-threading here, just ensure that the backend is thread-safe. Maybe we add a class level parameter to check if the vendor is thread-safe. Embedding will still need to happen on a single thread
    with tqdm(total=len(records), desc="Adding records to vector database", position=0) as pbar:
        for _, item in tqdm(
            enumerate(records),
            desc="Adding records to vector database",
            position=0,
            leave=False,
        ):
            _add_item_embedme(embedmes, item)
            serialized_contexts.append(item)

            if batch_size is not None and len(serialized_contexts) == batch_size:
                _embed_and_add_records(vdb, collection_name, embedmes, serialized_contexts)
                if embedme_scope == "index":
                    embedmes = {k: [] for k in embedmes.keys()}
                else:
                    embedmes = {None: []}
                serialized_contexts = []
                pbar.update(batch_size)

    if serialized_contexts:
        _embed_and_add_records(vdb, collection_name, embedmes, serialized_contexts)
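
For orientation, the following standalone sketch mirrors the data flow inside _embed_and_add_records: one embedding matrix per index, normalized to 2D, then zipped row by row into one {index_name: vector} map per context. The fake_embedder and the plain-dict record layout are illustrative assumptions, not modaic APIs; the real code builds records through vdb.ext.backend.create_record().

import numpy as np

def fake_embedder(texts, dim=4):
    # Stand-in for an index's embedder: one vector per input text.
    rng = np.random.default_rng(0)
    out = rng.random((len(texts), dim))
    return out[0] if len(texts) == 1 else out   # mimic the 1D-for-a-single-item case

texts = ["doc one", "doc two", "doc three"]
indexes = {"dense": fake_embedder}              # a collection with a single index

all_embeddings = {}
for index_name, embedder in indexes.items():
    embeddings = np.asarray(embedder(texts))
    if embeddings.ndim == 1:                    # normalize 1D output to a 2D matrix
        embeddings = embeddings.reshape(1, -1)
    all_embeddings[index_name] = embeddings

records = []
for i, text in enumerate(texts):
    # One embedding row per index for this context, plus the context payload itself.
    records.append({"embeddings": {k: v[i] for k, v in all_embeddings.items()}, "context": text})

print(len(records), records[0]["embeddings"]["dense"].shape)   # 3 (4,)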
modaic/databases/vector_database/benchmarks/common.py
@@ -0,0 +1,48 @@
from typing import Dict, Iterable, List, Tuple

from more_itertools import peekable
from PIL import Image

from modaic.context.base import Context, Embeddable


def _has_multiple_embedmes(
    item: Embeddable,
):
    """
    Check if the item has multiple embedmes.
    """
    return item.embedme.__code__.co_argcount == 2


def _items_have_multiple_embedmes(
    records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
):
    """
    Check if the first record has multiple embedmes.
    """
    p = peekable(records)
    first_item = p.peek()
    if isinstance(first_item, Embeddable) and _has_multiple_embedmes(first_item):
        return True
    return False


def _add_item_embedme(
    embedmes: Dict[str | None, List[str | Image.Image]],
    item: Embeddable | Tuple[str | Image.Image, Context],
):
    """
    Adds an item to the embedmes dictionary.
    """
    # Fast type check for tuple
    if type(item) is tuple:
        embedme = item[0]
        for index in embedmes.keys():
            embedmes[index].append(embedme)
    elif _has_multiple_embedmes(item):
        # CAVEAT: Context objects that implement Embeddable protocol and take in an index name as a parameter also accept None as the default index.
        for index in embedmes.keys():
            embedmes[index].append(item.embedme(index))
    else:
        embedmes[None].append(item.embedme())
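
The __code__.co_argcount == 2 check above distinguishes an embedme(self) that yields one embedme for the whole context from an embedme(self, index) that yields one per index. Below is a minimal standalone illustration of that dispatch using toy classes rather than modaic's Embeddable protocol; unlike _add_item_embedme, it fans single embedmes out to every index bucket instead of routing them to a shared None bucket.

class SingleEmbedme:
    def embedme(self):
        return "one text for every index"

class PerIndexEmbedme:
    def embedme(self, index):
        return f"text tailored to index {index!r}"

def has_multiple_embedmes(item):
    # Same trick as _has_multiple_embedmes: embedme(self, index) has co_argcount == 2.
    return item.embedme.__code__.co_argcount == 2

embedmes = {"dense": [], "sparse": []}   # one bucket per index, as in the "index" scope
for item in (SingleEmbedme(), PerIndexEmbedme()):
    if has_multiple_embedmes(item):
        for index in embedmes:
            embedmes[index].append(item.embedme(index))
    else:
        for index in embedmes:
            embedmes[index].append(item.embedme())

print(embedmes["dense"])   # ['one text for every index', "text tailored to index 'dense'"]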
modaic/databases/vector_database/benchmarks/fork.py
@@ -0,0 +1,132 @@
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
from PIL import Image

from modaic.context.base import Context, Embeddable

from .common import _add_item_embedme, _items_have_multiple_embedmes

if TYPE_CHECKING:
    from modaic.databases.vector_database.vector_database import VectorDatabase, VectorDBBackend
MAX_IN_FLIGHT = 8


def add_records(
    self: "VectorDatabase",
    collection_name: str,
    records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
    batch_size: Optional[int] = None,
    embedme_scope: Literal["auto", "context", "index"] = "auto",
):
    """
    Add items to a collection in the vector database.
    Uses the Context's get_embed_context() method and the embedder to create embeddings.

    Args:
        collection_name: The name of the collection to add records to
        records: The records to add to the collection
        batch_size: Optional batch size for processing records
    """
    if not records:
        return

    # NOTE: Make embedmes compatible with the ext's hybrid search function
    if embedme_scope == "auto":
        if _items_have_multiple_embedmes(records):
            embedme_scope = "index"
        else:
            embedme_scope = "context"

    # TODO: add multi-processing/multi-threading here, just ensure that the backend is thread-safe. Maybe we add a class level parameter to check if the vendor is thread-safe. Embedding will still need to happen on a single thread
    def gen_embeded_records():
        if embedme_scope == "index":
            embedmes: Dict[str, List[str | Image.Image]] = {
                k: [] for k in self.collections[collection_name].indexes.keys()
            }
        else:
            # CAVEAT: We make embedmes a dict with None as opposed to a list so we don't have to type check it
            embedmes: Dict[None, List[str | Image.Image]] = {None: []}

        serialized_contexts = []

        for item in records:
            _add_item_embedme(embedmes, item)
            serialized_contexts.append(item)

            if batch_size is not None and len(serialized_contexts) == batch_size:
                yield _embed_records(self, collection_name, embedmes), serialized_contexts
                if embedme_scope == "index":
                    embedmes = {k: [] for k in embedmes.keys()}
                else:
                    embedmes = {None: []}
                serialized_contexts = []

        if serialized_contexts:
            yield _embed_records(self, collection_name, embedmes), serialized_contexts

    mp.set_start_method("fork")

    with ProcessPoolExecutor(max_workers=4) as pool:
        pending = set()
        for records in gen_embeded_records():
            if len(pending) >= MAX_IN_FLIGHT:
                done = next(as_completed(pending))
                pending.remove(done)
                done.result()  # raise any backend error now

            pending.add(pool.submit(backend_add_records, self.ext.backend, collection_name, records))
        for fut in as_completed(pending):
            fut.result()


def _embed_records(
    self: "VectorDatabase",
    collection_name: str,
    embedmes: Dict[str, List[str | Image.Image]] | Dict[None, List[str | Image.Image]],
):
    # TODO: could add functionality for multiple embedmes per context (e.g. you want to embed both an image and a text description of an image)
    all_embeddings = {}
    if collection_name not in self.collections:
        raise ValueError(
            f"Collection {collection_name} not found in VectorDatabase's indexes, Please use VectorDatabase.create_collection() to create a collection first. Alternatively, you can use VectorDatabase.load_collection() to add records to an existing collection."
        )
    try:
        first_index = next(iter(self.collections[collection_name].indexes.keys()))
        # NOTE: get embeddings for each index
        for index_name, index_config in self.collections[collection_name].indexes.items():
            embeddings = index_config.embedder(embedmes)
            # NOTE: Ensure embeddings is a 2D array (DSPy returns 1D for single strings, 2D for lists)
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)
            # NOTE: If index_name is None use the only index for the collection
            all_embeddings[index_name or first_index] = embeddings
    except Exception as e:
        raise ValueError(f"Failed to create embeddings for index: {index_name}") from e
    return all_embeddings


def _add_records(
    self: "VectorDatabase",
    collection_name: str,
    embeddings: Dict[str, np.ndarray],
    contexts: List[Context],
):
    data_to_insert: List[Dict[str, np.ndarray]] = []
    # FIXME Probably should add type checking to ensure context matches schema, not sure how to do this efficiently
    for i, item in enumerate(contexts):
        embedding_map: dict[str, np.ndarray] = {}
        for index_name, embedding in embeddings.items():
            embedding_map[index_name] = embedding[i]

        # Create a record with embedding and validated metadata
        record = self.ext.backend.create_record(embedding_map, item)
        data_to_insert.append(record)

    self.ext.backend.add_records(collection_name, data_to_insert)


def backend_add_records(backend: "VectorDBBackend", collection_name: str, records: List[Dict[str, np.ndarray]]):
    backend.add_records(collection_name, records)
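
The concurrency idea in this fork-based variant is a bounded-in-flight pipeline: the parent process keeps embedding batches while up to MAX_IN_FLIGHT insert jobs sit on a ProcessPoolExecutor, and as_completed() applies backpressure once that limit is reached. The sketch below reproduces only that pattern with placeholder work functions; it is not modaic code. One stdlib caveat worth noting: mp.set_start_method("fork") raises RuntimeError if a start method was already set, and "fork" is unavailable on Windows, so the sketch simply relies on the platform default.

import time
from concurrent.futures import ProcessPoolExecutor, as_completed

MAX_IN_FLIGHT = 8

def produce_batches(n_batches=20):
    for i in range(n_batches):
        time.sleep(0.01)          # stands in for embedding a batch on the main process
        yield list(range(i * 4, i * 4 + 4))

def insert_batch(batch):
    time.sleep(0.05)              # stands in for a backend insert in a worker process
    return len(batch)

if __name__ == "__main__":        # required for process pools on spawn-based platforms
    inserted = 0
    with ProcessPoolExecutor(max_workers=4) as pool:
        pending = set()
        for batch in produce_batches():
            if len(pending) >= MAX_IN_FLIGHT:      # backpressure: wait for one insert to finish
                done = next(as_completed(pending))
                pending.remove(done)
                inserted += done.result()          # surfaces any worker error here
            pending.add(pool.submit(insert_batch, batch))
        for fut in as_completed(pending):          # drain the remaining inserts
            inserted += fut.result()
    print(inserted)                                # 80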
modaic/databases/vector_database/benchmarks/threaded.py
@@ -0,0 +1,119 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple

import immutables
import numpy as np
from PIL import Image

from modaic.context.base import Context, Embeddable

from .common import _add_item_embedme, _items_have_multiple_embedmes

if TYPE_CHECKING:
    from modaic.databases.vector_database.vector_database import VectorDatabase
MAX_IN_FLIGHT = 8


def add_records(
    self: "VectorDatabase",
    collection_name: str,
    records: Iterable[Embeddable | Tuple[str | Image.Image, Context]],
    batch_size: Optional[int] = None,
    embedme_scope: Literal["auto", "context", "index"] = "auto",
):
    """
    Add items to a collection in the vector database.
    Uses the Context's get_embed_context() method and the embedder to create embeddings.

    Args:
        collection_name: The name of the collection to add records to
        records: The records to add to the collection
        batch_size: Optional batch size for processing records
    """
    if not records:
        return

    # NOTE: Make embedmes compatible with the ext's hybrid search function
    if embedme_scope == "auto":
        if _items_have_multiple_embedmes(records):
            embedme_scope = "index"
        else:
            embedme_scope = "context"

    # TODO: add multi-processing/multi-threading here, just ensure that the backend is thread-safe. Maybe we add a class level parameter to check if the vendor is thread-safe. Embedding will still need to happen on a single thread
    def gen_embeded_records():
        if embedme_scope == "index":
            embedmes: Dict[str, List[str | Image.Image]] = {
                k: [] for k in self.collections[collection_name].indexes.keys()
            }
        else:
            # CAVEAT: We make embedmes a dict with None as opposed to a list so we don't have to type check it
            embedmes: Dict[None, List[str | Image.Image]] = {None: []}

        serialized_contexts = []

        for item in records:
            _add_item_embedme(embedmes, item)
            serialized_contexts.append(item)

            if batch_size is not None and len(serialized_contexts) == batch_size:
                yield _embed_and_create_records(self, collection_name, embedmes, serialized_contexts)
                if embedme_scope == "index":
                    embedmes = {k: [] for k in embedmes.keys()}
                else:
                    embedmes = {None: []}
                serialized_contexts = []

        if serialized_contexts:
            yield _embed_and_create_records(self, collection_name, embedmes, serialized_contexts)

    with ThreadPoolExecutor(max_workers=8) as pool:
        pending = set()
        for records in gen_embeded_records():
            if len(pending) >= MAX_IN_FLIGHT:
                done = next(as_completed(pending))
                pending.remove(done)
                done.result()  # raise any backend error now

            pending.add(pool.submit(self.ext.backend.add_records, collection_name, records))
        for fut in as_completed(pending):
            fut.result()


def _embed_and_create_records(
    self: "VectorDatabase",
    collection_name: str,
    embedmes: Dict[str, List[str | Image.Image]] | Dict[None, List[str | Image.Image]],
    contexts: List[Context],
):
    # TODO: could add functionality for multiple embedmes per context (e.g. you want to embed both an image and a text description of an image)
    all_embeddings = {}
    if collection_name not in self.collections:
        raise ValueError(
            f"Collection {collection_name} not found in VectorDatabase's indexes, Please use VectorDatabase.create_collection() to create a collection first. Alternatively, you can use VectorDatabase.load_collection() to add records to an existing collection."
        )
    try:
        first_index = next(iter(self.collections[collection_name].indexes.keys()))
        # NOTE: get embeddings for each index
        for index_name, index_config in self.collections[collection_name].indexes.items():
            embeddings = index_config.embedder(embedmes)
            # NOTE: Ensure embeddings is a 2D array (DSPy returns 1D for single strings, 2D for lists)
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)
            # NOTE: If index_name is None use the only index for the collection
            all_embeddings[index_name or first_index] = embeddings
    except Exception as e:
        raise ValueError(f"Failed to create embeddings for index: {index_name}") from e

    data_to_insert: List[immutables.Map[str, np.ndarray]] = []
    # FIXME Probably should add type checking to ensure context matches schema, not sure how to do this efficiently
    for i, item in enumerate(contexts):
        embedding_map: dict[str, np.ndarray] = {}
        for index_name, embedding in all_embeddings.items():
            embedding_map[index_name] = embedding[i]

        # Create a record with embedding and validated metadata
        record = self.ext.backend.create_record(embedding_map, item)
        data_to_insert.append(record)

    return data_to_insert
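
Since these modules live under benchmarks/, a toy harness makes the intended comparison concrete: the serial baseline pays embed-plus-insert per batch, while the threaded variant overlaps I/O-bound backend inserts (which release the GIL while waiting) with the next embedding step on the main thread. Everything below is simulated with sleeps; it measures the pattern, not modaic itself.

import time
from concurrent.futures import ThreadPoolExecutor, as_completed

MAX_IN_FLIGHT = 8
N_BATCHES = 16

def embed(i):
    time.sleep(0.02)      # embedding stays on the main thread in both variants
    return [i]

def insert(batch):
    time.sleep(0.05)      # I/O-bound insert, e.g. a network call to the vector store

def serial():
    for i in range(N_BATCHES):
        insert(embed(i))

def threaded():
    with ThreadPoolExecutor(max_workers=8) as pool:
        pending = set()
        for i in range(N_BATCHES):
            if len(pending) >= MAX_IN_FLIGHT:   # same backpressure as the real module
                done = next(as_completed(pending))
                pending.remove(done)
                done.result()
            pending.add(pool.submit(insert, embed(i)))
        for fut in as_completed(pending):
            fut.result()

for fn in (serial, threaded):
    t0 = time.perf_counter()
    fn()
    print(fn.__name__, round(time.perf_counter() - t0, 2), "s")
# Expect roughly 16 * (0.02 + 0.05) ≈ 1.1 s serial vs. ~16 * 0.02 + 0.05 ≈ 0.4 s threaded.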