modaic 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modaic might be problematic. Click here for more details.

Files changed (39) hide show
  1. modaic/__init__.py +25 -0
  2. modaic/agents/rag_agent.py +33 -0
  3. modaic/agents/registry.py +84 -0
  4. modaic/auto_agent.py +228 -0
  5. modaic/context/__init__.py +34 -0
  6. modaic/context/base.py +1064 -0
  7. modaic/context/dtype_mapping.py +25 -0
  8. modaic/context/table.py +585 -0
  9. modaic/context/text.py +94 -0
  10. modaic/databases/__init__.py +35 -0
  11. modaic/databases/graph_database.py +269 -0
  12. modaic/databases/sql_database.py +355 -0
  13. modaic/databases/vector_database/__init__.py +12 -0
  14. modaic/databases/vector_database/benchmarks/baseline.py +123 -0
  15. modaic/databases/vector_database/benchmarks/common.py +48 -0
  16. modaic/databases/vector_database/benchmarks/fork.py +132 -0
  17. modaic/databases/vector_database/benchmarks/threaded.py +119 -0
  18. modaic/databases/vector_database/vector_database.py +722 -0
  19. modaic/databases/vector_database/vendors/milvus.py +408 -0
  20. modaic/databases/vector_database/vendors/mongodb.py +0 -0
  21. modaic/databases/vector_database/vendors/pinecone.py +0 -0
  22. modaic/databases/vector_database/vendors/qdrant.py +1 -0
  23. modaic/exceptions.py +38 -0
  24. modaic/hub.py +305 -0
  25. modaic/indexing.py +127 -0
  26. modaic/module_utils.py +341 -0
  27. modaic/observability.py +275 -0
  28. modaic/precompiled.py +429 -0
  29. modaic/query_language.py +321 -0
  30. modaic/storage/__init__.py +3 -0
  31. modaic/storage/file_store.py +239 -0
  32. modaic/storage/pickle_store.py +25 -0
  33. modaic/types.py +287 -0
  34. modaic/utils.py +21 -0
  35. modaic-0.1.0.dist-info/METADATA +281 -0
  36. modaic-0.1.0.dist-info/RECORD +39 -0
  37. modaic-0.1.0.dist-info/WHEEL +5 -0
  38. modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
  39. modaic-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,35 @@
1
+ from .graph_database import GraphDatabase, MemgraphConfig, Neo4jConfig
2
+ from .sql_database import SQLDatabase, SQLiteBackend
3
+ from .vector_database.vector_database import (
4
+ CollectionConfig,
5
+ IndexConfig,
6
+ IndexType,
7
+ Metric,
8
+ SearchResult,
9
+ SupportsHybridSearch,
10
+ VDBExtensions,
11
+ VectorDatabase,
12
+ VectorDBBackend,
13
+ VectorType,
14
+ )
15
+ from .vector_database.vendors.milvus import MilvusBackend
16
+
17
# Public API of the `modaic.databases` package.
# BUG FIX: "VectorDBBackend" was previously listed twice; the list is now
# de-duplicated and kept in alphabetical order so additions are easy to audit.
__all__ = [
    "CollectionConfig",
    "GraphDatabase",
    "IndexConfig",
    "IndexType",
    "MemgraphConfig",
    "Metric",
    "MilvusBackend",
    "Neo4jConfig",
    "SQLDatabase",
    "SQLiteBackend",
    "SearchResult",
    "SupportsHybridSearch",
    "VDBExtensions",
    "VectorDBBackend",
    "VectorDatabase",
    "VectorType",
]
@@ -0,0 +1,269 @@
1
+ import os
2
+ from dataclasses import asdict, dataclass
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ ClassVar,
7
+ Dict,
8
+ Iterator,
9
+ List,
10
+ Optional,
11
+ Protocol,
12
+ Type,
13
+ )
14
+
15
+ from dotenv import load_dotenv
16
+
17
+ from ..context.base import Context, Relation
18
+ from ..observability import Trackable, track_modaic_obj
19
+
20
# Load environment variables from a local .env file so the NEO4J_*/MG_*
# connection defaults defined later in this module can be configured per
# deployment without exporting variables in the shell.
load_dotenv()


# gqlalchemy is an optional dependency: import it only for static type
# checking so this module still imports when gqlalchemy is not installed.
if TYPE_CHECKING:
    import gqlalchemy
25
+
26
+
27
class GraphDBConfig(Protocol):
    """
    Structural type for graph database backend configurations.

    Implementations (see `Neo4jConfig` and `MemgraphConfig` below) are
    dataclasses whose fields are passed verbatim to the gqlalchemy client
    constructor referenced by `_client_class`.
    """

    # The gqlalchemy client class that `GraphDatabase` instantiates with this
    # config's dataclass fields; None when gqlalchemy is not installed.
    _client_class: ClassVar[Type["gqlalchemy.DatabaseClient"]]
    # CAVEAT: checking for this attribute is currently the most reliable way to ascertain that something is a dataclass
    __dataclass_fields__: ClassVar[Dict[str, Any]]
31
+
32
+
33
class GraphDatabase(Trackable):
    """
    A database that stores context objects and relationships between them in a graph database.

    Thin wrapper over a gqlalchemy client (Neo4j or Memgraph): every method
    delegates to the underlying client, converting `Context` / `Relation`
    objects to gqlalchemy nodes/relationships where needed.
    """

    def __init__(self, config: GraphDBConfig, **kwargs):
        """
        Instantiate the gqlalchemy client from `config`'s dataclass fields.

        Args:
            config: Backend configuration (e.g. `Neo4jConfig`, `MemgraphConfig`).
            **kwargs: Forwarded to `Trackable.__init__`.

        Raises:
            ImportError: If gqlalchemy is not installed (the config's
                `_client_class` is None in that case).
        """
        Trackable.__init__(self, **kwargs)
        self.config = config
        if getattr(self.config, "_client_class", None) is None:
            raise ImportError("gqlalchemy is not installed. Install 'gqlalchemy' to use GraphDatabase backends.")
        self._client = self.config._client_class(**asdict(self.config))

    @track_modaic_obj
    def execute_and_fetch(self, query: str) -> List[Dict[str, Any]]:
        """Run a Cypher query and return its result rows.

        NOTE(review): gqlalchemy clients typically return a generator here —
        confirm the `List` annotation, or wrap in `list(...)`.
        """
        return self._client.execute_and_fetch(query)

    def execute(
        self,
        query: str,
        parameters: Optional[Dict[str, Any]] = None,
        connection: Optional["gqlalchemy.Connection"] = None,
    ) -> None:
        """Run a Cypher query for its side effects, discarding any results."""
        self._client.execute(query, parameters or {}, connection)

    # --- Index management (thin delegates to the gqlalchemy client) ---

    def create_index(self, index: "gqlalchemy.Index") -> None:
        self._client.create_index(index)

    def drop_index(self, index: "gqlalchemy.Index") -> None:
        self._client.drop_index(index)

    def get_indexes(self) -> List["gqlalchemy.Index"]:
        return self._client.get_indexes()

    def ensure_indexes(self, indexes: List["gqlalchemy.Index"]) -> None:
        self._client.ensure_indexes(indexes)

    def drop_indexes(self) -> None:
        self._client.drop_indexes()

    # --- Constraint management (thin delegates to the gqlalchemy client) ---

    def create_constraint(self, constraint: "gqlalchemy.Constraint") -> None:
        self._client.create_constraint(constraint)

    def drop_constraint(self, constraint: "gqlalchemy.Constraint") -> None:
        self._client.drop_constraint(constraint)

    def get_constraints(self) -> List["gqlalchemy.Constraint"]:
        return self._client.get_constraints()

    def get_exists_constraints(self) -> List["gqlalchemy.Constraint"]:
        return self._client.get_exists_constraints()

    def get_unique_constraints(self) -> List["gqlalchemy.Constraint"]:
        return self._client.get_unique_constraints()

    def ensure_constraints(self, constraints: List["gqlalchemy.Constraint"]) -> None:
        self._client.ensure_constraints(constraints)

    # --- Connection / database lifecycle ---

    def drop_database(self) -> None:
        # Destructive: delegates straight to the client's drop_database.
        self._client.drop_database()

    def new_connection(self) -> "gqlalchemy.Connection":
        return self._client.new_connection()

    def get_variable_assume_one(self, query_result: Iterator[Dict[str, Any]], variable_name: str) -> Any:
        return self._client.get_variable_assume_one(query_result, variable_name)

    # --- Node operations: convert Context <-> gqlalchemy.Node and delegate ---

    def create_node(self, node: Context) -> Optional[Context]:
        """
        Create `node` in the graph and return it converted back to a Context.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        # NOTE(review): unlike every other method in this class, this call does
        # not pass the database to to_gqlalchemy() — confirm this is intentional.
        node = node.to_gqlalchemy()
        created_node = self._client.create_node(node)
        if created_node is not None:
            return Context.from_gqlalchemy(created_node)

    def save_node(self, node: Context) -> "gqlalchemy.Node":
        """
        Save (create or update) `node` and return the resulting gqlalchemy node.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.save_node(node)
        return result

    def save_nodes(self, nodes: List[Context]) -> None:
        """
        Save a batch of nodes.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        nodes = [node.to_gqlalchemy(self) for node in nodes]
        self._client.save_nodes(nodes)

    def save_node_with_id(self, node: Context) -> Optional["gqlalchemy.Node"]:
        """
        Save `node`, matching on its id.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.save_node_with_id(node)
        return result

    def load_node(self, node: Context) -> Optional["gqlalchemy.Node"]:
        """
        Load the stored node matching `node`, or None if absent.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.load_node(node)
        return result

    def load_node_with_all_properties(self, node: Context) -> Optional["gqlalchemy.Node"]:
        """
        Load the stored node matching all of `node`'s properties, or None.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.load_node_with_all_properties(node)
        return result

    def load_node_with_id(self, node: Context) -> Optional["gqlalchemy.Node"]:
        """
        Load the stored node matching `node`'s id, or None.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.load_node_with_id(node)
        return result

    # --- Relationship operations: convert Relation and delegate ---

    def load_relationship(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Load the stored relationship matching `relationship`, or None.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.load_relationship(relationship)
        return result

    def load_relationship_with_id(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Load the stored relationship matching `relationship`'s id, or None.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.load_relationship_with_id(relationship)
        return result

    def load_relationship_with_start_node_id_and_end_node_id(
        self, relationship: Relation
    ) -> Optional["gqlalchemy.Relationship"]:
        """
        Load the stored relationship matching endpoints' ids, or None.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.load_relationship_with_start_node_id_and_end_node_id(relationship)
        return result

    def save_relationship(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Save (create or update) `relationship`.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)

        result = self._client.save_relationship(relationship)
        return result

    def save_relationships(self, relationships: List[Relation]) -> None:
        """
        Save a batch of relationships.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationships = [relationship.to_gqlalchemy(self) for relationship in relationships]
        self._client.save_relationships(relationships)

    def save_relationship_with_id(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Save `relationship`, matching on its id.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.save_relationship_with_id(relationship)
        return result

    def create_relationship(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Create `relationship` in the graph.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.create_relationship(relationship)
        return result
212
+
213
+
214
+ NEO4J_HOST = os.getenv("NEO4J_HOST", "localhost")
215
+ NEO4J_PORT = int(os.getenv("NEO4J_PORT", "7687"))
216
+ NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
217
+ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "test")
218
+ NEO4J_ENCRYPTED = os.getenv("NEO4J_ENCRYPT", "false").lower() == "true"
219
+ NEO4J_CLIENT_NAME = os.getenv("NEO4J_CLIENT_NAME", "neo4j")
220
+
221
+
222
+ def load_neo4j() -> Optional[Type[Any]]:
223
+ try:
224
+ import gqlalchemy
225
+ except ImportError:
226
+ return None
227
+ return gqlalchemy.Neo4j
228
+
229
+
230
+ def load_memgraph() -> Optional[Type[Any]]:
231
+ try:
232
+ import gqlalchemy
233
+ except ImportError:
234
+ return None
235
+ return gqlalchemy.Memgraph
236
+
237
+
238
+ @dataclass
239
+ class Neo4jConfig:
240
+ host: str = NEO4J_HOST
241
+ port: int = NEO4J_PORT
242
+ username: str = NEO4J_USERNAME
243
+ password: str = NEO4J_PASSWORD
244
+ encrypted: bool = (NEO4J_ENCRYPTED,)
245
+ client_name: str = (NEO4J_CLIENT_NAME,)
246
+
247
+ _client_class: ClassVar[Optional[Type[Any]]] = load_neo4j()
248
+
249
+
250
+ MG_HOST = os.getenv("MG_HOST", "127.0.0.1")
251
+ MG_PORT = int(os.getenv("MG_PORT", "7687"))
252
+ MG_USERNAME = os.getenv("MG_USERNAME", "")
253
+ MG_PASSWORD = os.getenv("MG_PASSWORD", "")
254
+ MG_ENCRYPTED = os.getenv("MG_ENCRYPT", "false").lower() == "true"
255
+ MG_CLIENT_NAME = os.getenv("MG_CLIENT_NAME", "GQLAlchemy")
256
+ MG_LAZY = os.getenv("MG_LAZY", "false").lower() == "true"
257
+
258
+
259
+ @dataclass
260
+ class MemgraphConfig:
261
+ host: str = MG_HOST
262
+ port: int = MG_PORT
263
+ username: str = MG_USERNAME
264
+ password: str = MG_PASSWORD
265
+ encrypted: bool = MG_ENCRYPTED
266
+ client_name: str = MG_CLIENT_NAME
267
+ lazy: bool = MG_LAZY
268
+
269
+ _client_class: ClassVar[Optional[Type[Any]]] = load_memgraph()
@@ -0,0 +1,355 @@
1
+ import json
2
+ from contextlib import contextmanager
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple
5
+ from urllib.parse import urlencode
6
+
7
+ import pandas as pd
8
+ from sqlalchemy import Column, CursorResult, MetaData, String, Text, create_engine, inspect, text
9
+ from sqlalchemy import Table as SQLTable
10
+ from sqlalchemy.dialects import sqlite
11
+ from sqlalchemy.orm import sessionmaker
12
+ from sqlalchemy.sql.compiler import IdentifierPreparer
13
+ from tqdm import tqdm
14
+
15
+ from ..context.table import BaseTable, Table, TableFile
16
+ from ..storage import FileStore
17
+
18
+
19
@dataclass
class SQLDatabaseBackend:
    """
    Base class for SQL database backends.

    Concrete backends override the `url` property to produce the SQLAlchemy
    connection URL for the database they describe.
    """

    @property
    def url(self) -> str:
        raise NotImplementedError("Subclasses must implement this method")


@dataclass
class SQLServerBackend(SQLDatabaseBackend):
    """
    Backend configuration for a SQL served over a port or remote connection. (MySQL, PostgreSQL, etc.)

    Args:
        user: The username to connect to the database.
        password: The password to connect to the database.
        host: The host of the database.
        database: The name of the database.
        port: The port of the database.
        dialect: SQLAlchemy dialect name (e.g. "mysql", "postgresql").
        driver: Optional DBAPI driver appended as "+driver".
        query_params: Optional query parameters appended to the URL.
    """

    user: str
    password: str
    host: str
    database: str
    port: Optional[str] = None
    dialect: str = "mysql"
    driver: Optional[str] = None
    query_params: Optional[dict] = None

    @property
    def url(self) -> str:
        """Assemble ``dialect[+driver]://user:password@host[:port]/database[?params]``."""
        pieces = [self.dialect]
        if self.driver:
            pieces.append(f"+{self.driver}")
        pieces.append(f"://{self.user}:{self.password}@{self.host}")
        if self.port:
            pieces.append(f":{self.port}")
        pieces.append(f"/{self.database}")
        if self.query_params:
            pieces.append(f"?{urlencode(self.query_params)}")
        return "".join(pieces)


@dataclass
class SQLiteBackend(SQLDatabaseBackend):
    """
    Backend configuration for a SQLite database.

    Args:
        db_path: Path to the SQLite database file.
        in_memory: Whether to create an in-memory SQLite database.
        query_params: Query parameters to pass to the database.
    """

    db_path: Optional[str] = None
    in_memory: bool = False
    query_params: Optional[dict] = None

    @property
    def url(self) -> str:
        """Assemble ``sqlite:///<path-or-:memory:>[?params]``."""
        if self.in_memory:
            location = ":memory:"
        else:
            location = self.db_path
        suffix = ""
        if self.query_params:
            suffix = f"?{urlencode(self.query_params)}"
        return f"sqlite:///{location}{suffix}"
81
+
82
+
83
class SQLDatabase:
    """
    Wrapper around a SQLAlchemy engine that stores `BaseTable` objects together
    with a bookkeeping "metadata" table mapping table names to JSON metadata.

    Connections are managed lazily: each operation opens a temporary connection
    unless a persistent one exists (see `open_persistent_connection`). Explicit
    transactions are available via `begin()` / `connect_and_begin()`.
    """

    def __init__(
        self,
        backend: SQLDatabaseBackend | str,
        engine_kwargs: Optional[dict] = None,  # TODO: This may not be a smart idea, may want to enforce specific kwargs
        session_kwargs: Optional[dict] = None,  # TODO: This may not be a smart idea, may want to enforce specific kwargs
    ):
        """
        Create the engine and session factory, and ensure the metadata table exists.

        Args:
            backend: A backend config exposing `.url`, or a raw SQLAlchemy URL string.
            engine_kwargs: Extra keyword arguments forwarded to `create_engine`.
            session_kwargs: Extra keyword arguments forwarded to `sessionmaker`.
        """
        self.url = backend.url if isinstance(backend, SQLDatabaseBackend) else backend
        self.engine = create_engine(self.url, **(engine_kwargs or {}))
        self.metadata = MetaData()
        self.session = sessionmaker(bind=self.engine, **(session_kwargs or {}))
        self.inspector = inspect(self.engine)
        # NOTE(review): identifiers are always quoted with sqlite's rules, even
        # for non-sqlite backends — confirm this is intended.
        self.preparer = IdentifierPreparer(sqlite.dialect())

        # Create metadata table to store table metadata
        self.metadata_table = SQLTable(
            "metadata",
            self.metadata,
            Column("table_name", String(255), primary_key=True),
            Column("metadata_json", Text),
        )
        self.metadata.create_all(self.engine)
        # Lazily managed connection; set by connect()/open_persistent_connection().
        self.connection = None
        self._in_transaction = False

    def add_table(
        self,
        table: BaseTable,
        if_exists: Literal["fail", "replace", "append"] = "replace",
        schema: Optional[str] = None,
    ):
        """
        Write `table`'s dataframe to the database and upsert its metadata row.

        Args:
            table: Table to persist; its dataframe is written under `table.name`.
            if_exists: Behavior when the table already exists (see `DataFrame.to_sql`).
            schema: Target database schema, forwarded to `DataFrame.to_sql`.
                (Previously accepted but silently ignored.)
        """
        # TODO: support batch inserting for large dataframes
        with self.connect() as connection:
            # Use the connection for to_sql to respect transaction context
            table._df.to_sql(table.name, connection, schema=schema, if_exists=if_exists, index=False)

            # Remove existing metadata for this table if it exists
            connection.execute(self.metadata_table.delete().where(self.metadata_table.c.table_name == table.name))

            # Insert new metadata
            connection.execute(
                self.metadata_table.insert().values(
                    table_name=table.name,
                    metadata_json=json.dumps(table.metadata),
                )
            )
            if self._should_commit():
                connection.commit()

    def add_tables(
        self,
        tables: Iterable[BaseTable],
        if_exists: Literal["fail", "replace", "append"] = "replace",
        schema: Optional[str] = None,
    ):
        """Add multiple tables; see `add_table` for parameter semantics."""
        for table in tables:
            self.add_table(table, if_exists, schema)

    def drop_table(self, name: str, must_exist: bool = False):
        """
        Drop a table from the database and remove its metadata.

        Args:
            name: The name of the table to drop
            must_exist: If True, fail when the table does not exist; otherwise
                issue `DROP TABLE IF EXISTS`.
        """
        if_exists = "IF EXISTS" if not must_exist else ""
        safe_name = self.preparer.quote(name)
        with self.connect() as connection:
            command = text(f"DROP TABLE {if_exists} {safe_name}")
            connection.execute(command)
            # Also remove metadata for this table
            connection.execute(self.metadata_table.delete().where(self.metadata_table.c.table_name == name))
            if self._should_commit():
                connection.commit()

    def drop_tables(self, names: Iterable[str], must_exist: bool = False):
        """Drop multiple tables; see `drop_table`."""
        for name in names:
            self.drop_table(name, must_exist)

    def list_tables(self) -> List[str]:
        """
        List all tables currently in the database.

        Returns:
            List of table names in the database.
        """
        # Refresh the inspector to ensure we get current table list
        self.inspector = inspect(self.engine)
        return self.inspector.get_table_names()

    def get_table(self, name: str) -> BaseTable:
        """Load table `name` and its stored metadata back as a `Table` object."""
        df = pd.read_sql_table(name, self.engine)

        return Table(df=df, name=name, metadata=self.get_table_metadata(name))

    def get_table_schema(self, name: str) -> List[Column]:
        """
        Return column schema for a given table.

        Args:
            name: The name of the table to get schema for

        Returns:
            Column schema information for the table.
        """
        return self.inspector.get_columns(name)

    def get_table_metadata(self, name: str) -> dict:
        """
        Get metadata for a specific table.

        Args:
            name: The name of the table to get metadata for

        Returns:
            Dictionary containing the table's metadata, or empty dict if not found.
        """
        with self.connect() as connection:
            result = connection.execute(
                self.metadata_table.select().where(self.metadata_table.c.table_name == name)
            ).fetchone()

        if result:
            return json.loads(result.metadata_json)
        return {}

    def query(self, query: str) -> CursorResult:
        """
        Execute a raw SQL string and return the `CursorResult`.

        NOTE(review): when no persistent connection is open, the temporary
        connection is closed before the caller can fetch from the result, which
        may fail depending on the driver — confirm callers either hold a
        persistent connection or that results are buffered.
        """
        with self.connect() as connection:
            result = connection.execute(text(query))
            return result

    def fetchall(self, query: str) -> List[Tuple]:
        """Execute `query` and return all result rows."""
        result = self.query(query)
        return result.fetchall()

    def fetchone(self, query: str) -> Tuple:
        """Execute `query` and return the first result row."""
        result = self.query(query)
        return result.fetchone()

    @classmethod
    def from_file_store(
        cls,
        file_store: FileStore,
        backend: SQLDatabaseBackend,
        folder: Optional[str] = None,
        table_created_hook: Optional[Callable[[TableFile], Any]] = None,
    ) -> "SQLDatabase":
        # TODO: support batch inserting and parallel processing
        """
        Initializes a new SQLDatabase from a file store.

        Args:
            file_store: File store containing files to load
            backend: SQL database backend
            folder: Folder in the file store to load
            table_created_hook: Optional callback invoked after each table is added.

        Returns:
            New SQLDatabase instance loaded with data from the file store.
        """
        # TODO: make sure the loaded sql database is empty if not raise error and tell user to use __init__ for an already existing database
        instance = cls(backend)
        instance.add_file_store(file_store, folder, table_created_hook)
        return instance

    def add_file_store(
        self,
        file_store: FileStore,
        folder: Optional[str] = None,
        table_created_hook: Optional[Callable[[TableFile], Any]] = None,
    ):
        """
        Load every file in `file_store` (optionally under `folder`) as a table,
        all inside a single transaction.
        """
        # BUG FIX: previously used `self.begin()`, which raises RuntimeError when
        # no connection is open — breaking `from_file_store`, which calls this
        # immediately after __init__ (where `self.connection` is None).
        # `connect_and_begin()` opens a temporary connection when needed.
        with self.connect_and_begin():
            for key, _ in tqdm(file_store.items(folder), desc="Uploading files to SQL database"):
                table = TableFile.from_file_store(key, file_store)
                self.add_table(table, if_exists="fail")
                if table_created_hook:
                    table_created_hook(table)

    @contextmanager
    def connect(self):
        """
        Context manager for database connections.
        Reuses existing connection if available, otherwise creates a temporary one.
        """
        connection_existed = self.connection is not None
        if not connection_existed:
            self.connection = self.engine.connect()

        try:
            yield self.connection
        finally:
            # Only close if we created the connection for this operation
            if not connection_existed:
                self.close()

    def open_persistent_connection(self):
        """
        Opens a persistent connection that will be reused across operations.
        Call close() to close the persistent connection.
        """
        if self.connection is None:
            self.connection = self.engine.connect()

    def close(self):
        """
        Closes the current connection if one exists.
        """
        if self.connection:
            self.connection.close()
            self.connection = None

    def _should_commit(self) -> bool:
        """
        Returns True if operations should commit immediately.
        Returns False if we're within an explicit transaction context.
        """
        return not self._in_transaction

    @contextmanager
    def begin(self):
        """
        Context manager for database transactions using existing connection.
        Requires an active connection. Commits on success, rolls back on exception.

        Raises:
            RuntimeError: If no active connection exists
        """
        if self.connection is None:
            raise RuntimeError("No active connection. Use connect_and_begin() or open a connection first.")

        transaction = self.connection.begin()
        old_in_transaction = self._in_transaction
        self._in_transaction = True

        try:
            yield self.connection
            transaction.commit()
        except Exception:
            transaction.rollback()
            raise
        finally:
            # Restore the previous flag so nested usage behaves correctly.
            self._in_transaction = old_in_transaction

    @contextmanager
    def connect_and_begin(self):
        """
        Context manager that establishes a connection and starts a transaction.
        Reuses existing connection if available, otherwise creates a temporary one.
        Commits on success, rolls back on exception.
        """
        connection_existed = self.connection is not None
        if not connection_existed:
            self.connection = self.engine.connect()

        transaction = self.connection.begin()
        old_in_transaction = self._in_transaction
        self._in_transaction = True

        try:
            yield self.connection
            transaction.commit()
        except Exception:
            transaction.rollback()
            raise
        finally:
            self._in_transaction = old_in_transaction
            # Only close if we created the connection for this operation
            if not connection_existed:
                self.close()
351
+
352
+
353
class MultiTenantSQLDatabase:
    """Placeholder for a future multi-tenant variant of `SQLDatabase`.

    Instantiation always fails until the feature is implemented.
    """

    def __init__(self) -> None:
        raise NotImplementedError("Not implemented")
@@ -0,0 +1,12 @@
1
+ from .vector_database import IndexConfig, IndexType, Metric, SupportsHybridSearch, VectorDatabase, VectorType
2
+ from .vendors.milvus import MilvusBackend
3
+
4
# Public API of the vector_database subpackage, kept alphabetized for easy auditing.
__all__ = [
    "IndexConfig",
    "IndexType",
    "Metric",
    "MilvusBackend",
    "SupportsHybridSearch",
    "VectorDatabase",
    "VectorType",
]