modaic 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modaic might be problematic. Click here for more details.

Files changed (39) hide show
  1. modaic/__init__.py +25 -0
  2. modaic/agents/rag_agent.py +33 -0
  3. modaic/agents/registry.py +84 -0
  4. modaic/auto_agent.py +228 -0
  5. modaic/context/__init__.py +34 -0
  6. modaic/context/base.py +1064 -0
  7. modaic/context/dtype_mapping.py +25 -0
  8. modaic/context/table.py +585 -0
  9. modaic/context/text.py +94 -0
  10. modaic/databases/__init__.py +35 -0
  11. modaic/databases/graph_database.py +269 -0
  12. modaic/databases/sql_database.py +355 -0
  13. modaic/databases/vector_database/__init__.py +12 -0
  14. modaic/databases/vector_database/benchmarks/baseline.py +123 -0
  15. modaic/databases/vector_database/benchmarks/common.py +48 -0
  16. modaic/databases/vector_database/benchmarks/fork.py +132 -0
  17. modaic/databases/vector_database/benchmarks/threaded.py +119 -0
  18. modaic/databases/vector_database/vector_database.py +722 -0
  19. modaic/databases/vector_database/vendors/milvus.py +408 -0
  20. modaic/databases/vector_database/vendors/mongodb.py +0 -0
  21. modaic/databases/vector_database/vendors/pinecone.py +0 -0
  22. modaic/databases/vector_database/vendors/qdrant.py +1 -0
  23. modaic/exceptions.py +38 -0
  24. modaic/hub.py +305 -0
  25. modaic/indexing.py +127 -0
  26. modaic/module_utils.py +341 -0
  27. modaic/observability.py +275 -0
  28. modaic/precompiled.py +429 -0
  29. modaic/query_language.py +321 -0
  30. modaic/storage/__init__.py +3 -0
  31. modaic/storage/file_store.py +239 -0
  32. modaic/storage/pickle_store.py +25 -0
  33. modaic/types.py +287 -0
  34. modaic/utils.py +21 -0
  35. modaic-0.1.0.dist-info/METADATA +281 -0
  36. modaic-0.1.0.dist-info/RECORD +39 -0
  37. modaic-0.1.0.dist-info/WHEEL +5 -0
  38. modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
  39. modaic-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,408 @@
1
+ from collections.abc import Mapping
2
+ from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, Union
3
+
4
+ import numpy as np
5
+ from langchain_community.query_constructors.milvus import MilvusTranslator as MilvusTranslator_
6
+ from langchain_core.structured_query import Comparator, Comparison, Visitor
7
+ from pymilvus import DataType, MilvusClient
8
+ from pymilvus.orm.collection import CollectionSchema
9
+
10
+ from ....context.base import Context
11
+ from ....exceptions import BackendCompatibilityError
12
+ from ....types import InnerField, Schema, SchemaField, float_format, int_format
13
+ from ..vector_database import DEFAULT_INDEX_NAME, IndexConfig, IndexType, SearchResult, VectorType
14
+
15
# Maps modaic VectorType members to Milvus vector DataType members.
# NOTE: despite the name, the lookup direction is modaic -> Milvus (keys are VectorType).
milvus_to_modaic_vector = {
    VectorType.FLOAT: DataType.FLOAT_VECTOR,
    VectorType.FLOAT16: DataType.FLOAT16_VECTOR,
    VectorType.BFLOAT16: DataType.BFLOAT16_VECTOR,
    VectorType.BINARY: DataType.BINARY_VECTOR,
    VectorType.FLOAT_SPARSE: DataType.SPARSE_FLOAT_VECTOR,
    # VectorType.INT8: DataType.INT8_VECTOR,
}

# Maps modaic IndexType members to the index-type name string passed to Milvus.
# IndexType.DEFAULT is translated to Milvus' "AUTOINDEX".
modaic_to_milvus_index = {
    IndexType.DEFAULT: "AUTOINDEX",
    IndexType.HNSW: "HNSW",
    IndexType.FLAT: "FLAT",
    IndexType.IVF_FLAT: "IVF_FLAT",
    IndexType.IVF_SQ8: "IVF_SQ8",
    IndexType.IVF_PQ: "IVF_PQ",
    IndexType.IVF_RABITQ: "IVF_RABITQ",
    IndexType.GPU_IVF_FLAT: "GPU_IVF_FLAT",
    IndexType.GPU_IVF_PQ: "GPU_IVF_PQ",
    IndexType.DISKANN: "DISKANN",
    IndexType.BIN_FLAT: "BIN_FLAT",
    IndexType.BIN_IVF_FLAT: "BIN_IVF_FLAT",
    IndexType.MINHASH_LSH: "MINHASH_LSH",
    IndexType.SPARSE_INVERTED_INDEX: "SPARSE_INVERTED_INDEX",
    IndexType.INVERTED: "INVERTED",
    IndexType.BITMAP: "BITMAP",
    IndexType.TRIE: "TRIE",
    IndexType.STL_SORT: "STL_SORT",
}

# Name for field that tracks which fields are null for a record (only used for milvus lite)
NULL_FIELD_NAME = "null_fields"
47
+
48
+
49
class MilvusTranslator(MilvusTranslator_):
    """
    Variant of langchain_community's MilvusTranslator whose comparisons also accept
    lists of strings as values (value formatting is delegated to ``process_value``).
    """

    def visit_comparison(self, comparison: Comparison) -> str:
        """Render a single comparison as a parenthesised Milvus filter expression."""
        operator_text = self._format_func(comparison.comparator)
        value_text = process_value(comparison.value, comparison.comparator)
        # Pieces are separated by single spaces: "( <attribute> <operator> <value> )".
        return " ".join(("(", comparison.attribute, operator_text, value_text, ")"))
60
+
61
+
62
+ class MilvusBackend:
63
+ _name: ClassVar[Literal["milvus"]] = "milvus"
64
+ mql_translator: Visitor = MilvusTranslator()
65
+
66
+ def __init__(
67
+ self,
68
+ uri: str = "http://localhost:19530",
69
+ user: str = "",
70
+ password: str = "",
71
+ db_name: str = "",
72
+ token: str = "",
73
+ timeout: Optional[float] = None,
74
+ **kwargs,
75
+ ):
76
+ """
77
+ Initialize a Milvus vector database.
78
+ """
79
+
80
+ if uri.startswith(("http://", "https://", "tcp://")):
81
+ self.milvus_lite = False
82
+ elif uri.endswith(".db"):
83
+ self.milvus_lite = True
84
+ else:
85
+ raise ValueError(
86
+ f"Invalid URI: {uri}, must start with http://, https://, or tcp:// for milvus server or end with .db for milvus lite"
87
+ )
88
+ self._client = MilvusClient(
89
+ uri=uri,
90
+ user=user,
91
+ password=password,
92
+ db_name=db_name,
93
+ token=token,
94
+ timeout=timeout,
95
+ **kwargs,
96
+ )
97
+
98
+ def create_record(self, embedding_map: Dict[str, np.ndarray], context: Context) -> Any:
99
+ """
100
+ Convert a Context to a record for Milvus.
101
+ """
102
+ # CAVEAT: users can optionally hide fields from model_dump(). Use include_hidden=True to get all fields.
103
+ record = context.model_dump(include_hidden=True)
104
+ # NOTE: Track null values if using milvus lite since null values are not supported in milvus lite
105
+ if self.milvus_lite:
106
+ schema = context.schema().as_dict()
107
+ null_fields = []
108
+ for field_name, field_value in record.items():
109
+ if field_value is None:
110
+ null_fields.append(field_name)
111
+ if schema[field_name].type == "string":
112
+ record[field_name] = ""
113
+ elif schema[field_name].type == "array":
114
+ record[field_name] = []
115
+ elif schema[field_name].type == "object":
116
+ record[field_name] = {}
117
+ elif schema[field_name].type == "number" or schema[field_name].type == "integer":
118
+ record[field_name] = 0
119
+ elif schema[field_name].type == "boolean":
120
+ record[field_name] = False
121
+
122
+ record[NULL_FIELD_NAME] = null_fields
123
+
124
+ for index_name, embedding in embedding_map.items():
125
+ record[index_name] = embedding.tolist()
126
+ return record
127
+
128
+ def add_records(self, collection_name: str, records: List[Any]):
129
+ """
130
+ Add records to a Milvus collection.
131
+ """
132
+ self._client.insert(collection_name, records)
133
+
134
+ def list_collections(self) -> List[str]:
135
+ return self._client.list_collections()
136
+
137
+ def drop_collection(self, collection_name: str):
138
+ """
139
+ Drop a Milvus collection.
140
+ """
141
+ self._client.drop_collection(collection_name)
142
+
143
+ def create_collection(
144
+ self,
145
+ collection_name: str,
146
+ payload_class: Type[Context],
147
+ index: IndexConfig = IndexConfig(), # noqa: B008
148
+ ):
149
+ """
150
+ Create a Milvus collection.
151
+ """
152
+ if not issubclass(payload_class, Context):
153
+ raise TypeError(f"Payload class {payload_class} is must be a subclass of Context")
154
+
155
+ schema = _modaic_to_milvus_schema(self._client, payload_class.schema(), self.milvus_lite)
156
+ modaic_to_milvus_vector = {
157
+ VectorType.FLOAT: DataType.FLOAT_VECTOR,
158
+ VectorType.FLOAT16: DataType.FLOAT16_VECTOR,
159
+ VectorType.BFLOAT16: DataType.BFLOAT16_VECTOR,
160
+ VectorType.BINARY: DataType.BINARY_VECTOR,
161
+ VectorType.FLOAT_SPARSE: DataType.SPARSE_FLOAT_VECTOR,
162
+ # VectorType.INT8: DataType.INT8_VECTOR,
163
+ }
164
+
165
+ try:
166
+ vector_type = modaic_to_milvus_vector[index.vector_type]
167
+ except KeyError:
168
+ raise ValueError(f"Milvus does not support vector type: {index.vector_type}") from None
169
+ kwargs = {
170
+ "field_name": DEFAULT_INDEX_NAME,
171
+ "datatype": vector_type,
172
+ }
173
+ # NOTE: sparse vectors don't have a dim in milvus
174
+ if index.vector_type != VectorType.FLOAT_SPARSE:
175
+ kwargs["dim"] = index.embedder.embedding_dim
176
+ schema.add_field(**kwargs)
177
+
178
+ index_params = self._client.prepare_index_params()
179
+ index_type = modaic_to_milvus_index[index.index_type]
180
+ try:
181
+ metric_type = index.metric.supported_libraries["milvus"]
182
+ except KeyError:
183
+ raise ValueError(f"Milvus does not support metric type: {index.metric}") from None
184
+ index_params.add_index(
185
+ field_name=DEFAULT_INDEX_NAME,
186
+ index_name=f"{DEFAULT_INDEX_NAME}_index",
187
+ index_type=index_type,
188
+ metric_type=metric_type,
189
+ )
190
+
191
+ self._client.create_collection(collection_name, schema=schema, index_params=index_params)
192
+
193
+ def has_collection(self, collection_name: str) -> bool:
194
+ """
195
+ Check if a collection exists in Milvus.
196
+
197
+ Args:
198
+ client: The Milvus client instance
199
+ collection_name: The name of the collection to check
200
+
201
+ Returns:
202
+ bool: True if the collection exists, False otherwise
203
+ """
204
+ return self._client.has_collection(collection_name)
205
+
206
+ def search(
207
+ self,
208
+ collection_name: str,
209
+ vectors: List[np.ndarray],
210
+ payload_class: Type[Context],
211
+ k: int = 10,
212
+ filter: Optional[str] = None,
213
+ ) -> List[List[SearchResult]]:
214
+ """
215
+ Retrieve records from the vector database.
216
+ """
217
+ if not issubclass(payload_class, Context):
218
+ raise TypeError(f"Payload class {payload_class} is must be a subclass of Context")
219
+
220
+ output_fields = [field_name for field_name in payload_class.model_fields]
221
+ if self.milvus_lite:
222
+ output_fields.append(NULL_FIELD_NAME)
223
+ listified_vectors = [vector.tolist() for vector in vectors]
224
+
225
+ searches = self._client.search(
226
+ collection_name=collection_name,
227
+ data=listified_vectors,
228
+ limit=k,
229
+ filter=filter,
230
+ anns_field=DEFAULT_INDEX_NAME, # Use the same field name as in create_collection
231
+ output_fields=output_fields,
232
+ )
233
+
234
+ all_results = []
235
+ for search in searches:
236
+ context_list = []
237
+ for result in search:
238
+ match result:
239
+ case {"id": id, "distance": distance, "entity": entity}:
240
+ context_list.append(
241
+ SearchResult(
242
+ id=id, score=distance, context=payload_class.model_validate(self._process_null(entity))
243
+ )
244
+ )
245
+ case _:
246
+ raise ValueError(f"Failed to parse search results to {payload_class.__name__}: {result}")
247
+ all_results.append(context_list)
248
+
249
+ return all_results
250
+
251
+ def get_records(self, collection_name: str, payload_class: Type[Context], record_ids: List[str]) -> List[Context]:
252
+ output_fields = [field_name for field_name in payload_class.model_fields]
253
+ if self.milvus_lite:
254
+ output_fields.append(NULL_FIELD_NAME)
255
+ records = self._client.get(collection_name=collection_name, ids=record_ids, output_fields=output_fields)
256
+ return [payload_class.model_validate(self._process_null(record)) for record in records]
257
+
258
+ @staticmethod
259
+ def from_local(file_path: str) -> "MilvusBackend":
260
+ return MilvusBackend(uri=file_path)
261
+
262
+ def _process_null(self, record: dict) -> dict:
263
+ if self.milvus_lite and NULL_FIELD_NAME in record:
264
+ for field_name in record[NULL_FIELD_NAME]:
265
+ record[field_name] = None
266
+ del record[NULL_FIELD_NAME]
267
+ return record
268
+
269
+
270
def _modaic_to_milvus_schema(client: MilvusClient, modaic_schema: Schema, milvus_lite: bool) -> CollectionSchema:
    """
    Convert a Modaic schema to a Milvus collection schema.

    Args:
        client: The Milvus client instance
        modaic_schema: The Modaic schema to convert
        milvus_lite: Whether the schema is for a milvus lite database

    Returns:
        The Milvus collection schema. For milvus lite it additionally contains the
        ``NULL_FIELD_NAME`` array field used to track null values.

    Raises:
        BackendCompatibilityError: If ``milvus_lite`` is set and the Modaic schema
            already defines a field named ``NULL_FIELD_NAME``.
        ValueError: If a field (or array element) type has no Milvus equivalent.
    """
    # Maps types that can contain the 'format' keyword to the default milvus data type
    formatted_types: Mapping[Literal["integer", "number"], DataType] = {
        "integer": DataType.INT64,
        "number": DataType.DOUBLE,
    }
    # Maps types that do not contain the 'format' keyword to the milvus data type
    non_formatted_types: Mapping[Literal["string", "boolean"], DataType] = {
        "string": DataType.VARCHAR,
        "boolean": DataType.BOOL,
    }
    # Maps values for the 'format' keyword to the milvus data type
    format_to_milvus: Mapping[int_format | float_format, DataType] = {
        "int8": DataType.INT8,
        "int16": DataType.INT16,
        "int32": DataType.INT32,
        "int64": DataType.INT64,
        "float": DataType.FLOAT,
        "double": DataType.DOUBLE,
        "bool": DataType.BOOL,
    }

    MAX_STR_LENGTH = 65_535  # noqa: N806
    MAX_ARRAY_CAPACITY = 4096  # noqa: N806

    def get_milvus_type(sf: SchemaField | InnerField) -> DataType:
        """Resolve the Milvus scalar type for a field, preferring its 'format' when usable."""
        type_ = sf.type
        format_ = sf.format
        if type_ in formatted_types and format_ in format_to_milvus:
            return format_to_milvus[format_]
        if type_ in formatted_types:
            return formatted_types[type_]
        if type_ in non_formatted_types:
            return non_formatted_types[type_]
        raise ValueError(f"Milvus does not support field type: {type_}")

    def is_nullable(sf: SchemaField | InnerField) -> bool:
        """Milvus lite never gets nullable fields; otherwise follow the field's optionality."""
        if milvus_lite:
            return False
        return sf.optional

    # Convert once and reuse for the reserved-name check, the loop and the capacity below.
    fields = modaic_schema.as_dict()

    # The reserved name must be compared against field *names*: the Milvus schema's
    # ``fields`` attribute holds field objects, so a string is never found in it.
    if milvus_lite and NULL_FIELD_NAME in fields:
        raise BackendCompatibilityError(
            f"Milvus lite vector databases reserve the field '{NULL_FIELD_NAME}' for tracking null values"
        )

    milvus_schema = client.create_schema(auto_id=False, enable_dynamic_field=True)
    for field_name, schema_field in fields.items():
        if schema_field.type == "array":
            if schema_field.inner_type.type == "string":
                # String arrays additionally need a per-element max_length.
                milvus_schema.add_field(
                    field_name=field_name,
                    datatype=DataType.ARRAY,
                    nullable=is_nullable(schema_field),
                    element_type=DataType.VARCHAR,
                    max_capacity=schema_field.size or MAX_ARRAY_CAPACITY,
                    max_length=schema_field.inner_type.size or MAX_STR_LENGTH,
                )
            else:
                milvus_schema.add_field(
                    field_name=field_name,
                    datatype=DataType.ARRAY,
                    nullable=is_nullable(schema_field),
                    element_type=get_milvus_type(schema_field.inner_type),
                    max_capacity=schema_field.size or MAX_ARRAY_CAPACITY,
                )
        elif schema_field.type == "string":
            # Only string fields can be flagged as the primary key here.
            milvus_schema.add_field(
                field_name=field_name,
                datatype=DataType.VARCHAR,
                max_length=schema_field.size or MAX_STR_LENGTH,
                nullable=is_nullable(schema_field),
                is_primary=schema_field.is_id,
            )
        elif schema_field.type == "object":
            milvus_schema.add_field(
                field_name=field_name,
                datatype=DataType.JSON,
                nullable=is_nullable(schema_field),
            )
        else:
            milvus_schema.add_field(
                field_name=field_name,
                datatype=get_milvus_type(schema_field),
                nullable=is_nullable(schema_field),
            )

    if milvus_lite:
        # At most every field of a record can be null; keep the capacity at least 1
        # so an empty schema does not request a zero-capacity array.
        milvus_schema.add_field(
            field_name=NULL_FIELD_NAME,
            datatype=DataType.ARRAY,
            element_type=DataType.VARCHAR,
            max_capacity=max(1, len(fields)),
            max_length=255,
        )
    return milvus_schema
380
+
381
+
382
def process_value(value: Union[int, float, str, List[str]], comparator: Comparator) -> str:
    """Convert a value to its textual form inside a Milvus filter expression.

    Strings are wrapped in double quotes, with embedded backslashes and double
    quotes escaped so the value cannot terminate the literal early. A non-empty
    list whose first element is a string is rendered as a bracketed list of quoted
    strings. Everything else is rendered with ``str()``.

    Args:
        value: The value to convert.
        comparator: The comparator. For ``Comparator.LIKE`` a ``%`` is appended to
            string values for prefix matching.

    Returns:
        The converted value as a string.
    """

    def quote(text: object, suffix: str = "") -> str:
        # Escape the backslash first so the escapes added for quotes are not doubled.
        escaped = str(text).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}{suffix}"'

    if isinstance(value, str):
        if comparator is Comparator.LIKE:
            # If the comparator is LIKE, add a percent sign after it for prefix matching
            return quote(value, "%")
        return quote(value)
    if isinstance(value, list) and len(value) > 0 and isinstance(value[0], str):
        inside = ", ".join(quote(item) for item in value)
        return f"[{inside}]"
    # If the value is not a string, convert it to a string without double quotes
    return str(value)
File without changes
File without changes
modaic/exceptions.py ADDED
@@ -0,0 +1,38 @@
1
class ModaicError(Exception):
    """Base class for the modaic exceptions defined in this module."""
3
+
4
+
5
class ModaicHubError(ModaicError):
    """Common parent of every hub-related error."""
9
+
10
+
11
class RepositoryExistsError(ModaicHubError):
    """Signals that the repository already exists."""
15
+
16
+
17
class AuthenticationError(ModaicHubError):
    """Signals that authentication failed."""
21
+
22
+
23
class RepositoryNotFoundError(ModaicHubError):
    """Signals that the repository does not exist."""
27
+
28
+
29
class SchemaError(ModaicError):
    """Signals that a schema is invalid."""
33
+
34
+
35
class BackendCompatibilityError(ModaicError):
    """Signals that a backend does not support a requested feature."""