modaic 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modaic might be problematic. Click here for more details.
- modaic/__init__.py +25 -0
- modaic/agents/rag_agent.py +33 -0
- modaic/agents/registry.py +84 -0
- modaic/auto_agent.py +228 -0
- modaic/context/__init__.py +34 -0
- modaic/context/base.py +1064 -0
- modaic/context/dtype_mapping.py +25 -0
- modaic/context/table.py +585 -0
- modaic/context/text.py +94 -0
- modaic/databases/__init__.py +35 -0
- modaic/databases/graph_database.py +269 -0
- modaic/databases/sql_database.py +355 -0
- modaic/databases/vector_database/__init__.py +12 -0
- modaic/databases/vector_database/benchmarks/baseline.py +123 -0
- modaic/databases/vector_database/benchmarks/common.py +48 -0
- modaic/databases/vector_database/benchmarks/fork.py +132 -0
- modaic/databases/vector_database/benchmarks/threaded.py +119 -0
- modaic/databases/vector_database/vector_database.py +722 -0
- modaic/databases/vector_database/vendors/milvus.py +408 -0
- modaic/databases/vector_database/vendors/mongodb.py +0 -0
- modaic/databases/vector_database/vendors/pinecone.py +0 -0
- modaic/databases/vector_database/vendors/qdrant.py +1 -0
- modaic/exceptions.py +38 -0
- modaic/hub.py +305 -0
- modaic/indexing.py +127 -0
- modaic/module_utils.py +341 -0
- modaic/observability.py +275 -0
- modaic/precompiled.py +429 -0
- modaic/query_language.py +321 -0
- modaic/storage/__init__.py +3 -0
- modaic/storage/file_store.py +239 -0
- modaic/storage/pickle_store.py +25 -0
- modaic/types.py +287 -0
- modaic/utils.py +21 -0
- modaic-0.1.0.dist-info/METADATA +281 -0
- modaic-0.1.0.dist-info/RECORD +39 -0
- modaic-0.1.0.dist-info/WHEEL +5 -0
- modaic-0.1.0.dist-info/licenses/LICENSE +31 -0
- modaic-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from .graph_database import GraphDatabase, MemgraphConfig, Neo4jConfig
|
|
2
|
+
from .sql_database import SQLDatabase, SQLiteBackend
|
|
3
|
+
from .vector_database.vector_database import (
|
|
4
|
+
CollectionConfig,
|
|
5
|
+
IndexConfig,
|
|
6
|
+
IndexType,
|
|
7
|
+
Metric,
|
|
8
|
+
SearchResult,
|
|
9
|
+
SupportsHybridSearch,
|
|
10
|
+
VDBExtensions,
|
|
11
|
+
VectorDatabase,
|
|
12
|
+
VectorDBBackend,
|
|
13
|
+
VectorType,
|
|
14
|
+
)
|
|
15
|
+
from .vector_database.vendors.milvus import MilvusBackend
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"CollectionConfig",
|
|
19
|
+
"SQLDatabase",
|
|
20
|
+
"SQLiteBackend",
|
|
21
|
+
"VectorDatabase",
|
|
22
|
+
"MilvusBackend",
|
|
23
|
+
"SearchResult",
|
|
24
|
+
"VectorDBBackend",
|
|
25
|
+
"IndexConfig",
|
|
26
|
+
"IndexType",
|
|
27
|
+
"Metric",
|
|
28
|
+
"SupportsHybridSearch",
|
|
29
|
+
"VDBExtensions",
|
|
30
|
+
"VectorDBBackend",
|
|
31
|
+
"VectorType",
|
|
32
|
+
"GraphDatabase",
|
|
33
|
+
"MemgraphConfig",
|
|
34
|
+
"Neo4jConfig",
|
|
35
|
+
]
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dataclasses import asdict, dataclass
|
|
3
|
+
from typing import (
|
|
4
|
+
TYPE_CHECKING,
|
|
5
|
+
Any,
|
|
6
|
+
ClassVar,
|
|
7
|
+
Dict,
|
|
8
|
+
Iterator,
|
|
9
|
+
List,
|
|
10
|
+
Optional,
|
|
11
|
+
Protocol,
|
|
12
|
+
Type,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from dotenv import load_dotenv
|
|
16
|
+
|
|
17
|
+
from ..context.base import Context, Relation
|
|
18
|
+
from ..observability import Trackable, track_modaic_obj
|
|
19
|
+
|
|
20
|
+
load_dotenv()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
import gqlalchemy
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GraphDBConfig(Protocol):
    """Structural type for graph-backend configs.

    A conforming object is a dataclass whose fields map onto the constructor
    kwargs of the gqlalchemy client class named in ``_client_class`` (see
    ``GraphDatabase.__init__``, which instantiates it via ``asdict(config)``).
    """

    # Client class to instantiate; None when gqlalchemy is not installed.
    _client_class: ClassVar[Type["gqlalchemy.DatabaseClient"]]
    # CAVEAT: checking for this attribute is currently the most reliable way to ascertain that something is a dataclass
    __dataclass_fields__: ClassVar[Dict[str, Any]]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class GraphDatabase(Trackable):
    """
    A database that stores context objects and relationships between them in a graph database.

    Thin wrapper around a gqlalchemy client (Neo4j or Memgraph): nearly every
    method delegates to the underlying client, converting ``Context`` /
    ``Relation`` objects to their gqlalchemy representations first.
    """

    def __init__(self, config: GraphDBConfig, **kwargs):
        # Trackable handles observability bookkeeping; extra kwargs go to it.
        Trackable.__init__(self, **kwargs)
        self.config = config
        # _client_class is None when gqlalchemy failed to import
        # (see load_neo4j / load_memgraph below).
        if getattr(self.config, "_client_class", None) is None:
            raise ImportError("gqlalchemy is not installed. Install 'gqlalchemy' to use GraphDatabase backends.")
        # Config dataclass fields map 1:1 onto the client constructor kwargs.
        self._client = self.config._client_class(**asdict(self.config))

    # --- Query execution --------------------------------------------------

    @track_modaic_obj
    def execute_and_fetch(self, query: str) -> List[Dict[str, Any]]:
        """Run a Cypher query and return its result rows."""
        return self._client.execute_and_fetch(query)

    def execute(
        self,
        query: str,
        parameters: Dict[str, Any] = None,  # normalized to {} below, avoiding the mutable-default pitfall
        connection: Optional["gqlalchemy.Connection"] = None,
    ) -> None:
        """Run a Cypher query for its side effects; no results are returned."""
        self._client.execute(query, parameters or {}, connection)

    # --- Index management (direct delegations to the gqlalchemy client) ---

    def create_index(self, index: "gqlalchemy.Index") -> None:
        self._client.create_index(index)

    def drop_index(self, index: "gqlalchemy.Index") -> None:
        self._client.drop_index(index)

    def get_indexes(self) -> List["gqlalchemy.Index"]:
        return self._client.get_indexes()

    def ensure_indexes(self, indexes: List["gqlalchemy.Index"]) -> None:
        self._client.ensure_indexes(indexes)

    def drop_indexes(self) -> None:
        self._client.drop_indexes()

    # --- Constraint management (direct delegations) -----------------------

    def create_constraint(self, constraint: "gqlalchemy.Constraint") -> None:
        self._client.create_constraint(constraint)

    def drop_constraint(self, constraint: "gqlalchemy.Constraint") -> None:
        self._client.drop_constraint(constraint)

    def get_constraints(self) -> List["gqlalchemy.Constraint"]:
        return self._client.get_constraints()

    def get_exists_constraints(self) -> List["gqlalchemy.Constraint"]:
        return self._client.get_exists_constraints()

    def get_unique_constraints(self) -> List["gqlalchemy.Constraint"]:
        return self._client.get_unique_constraints()

    def ensure_constraints(self, constraints: List["gqlalchemy.Constraint"]) -> None:
        self._client.ensure_constraints(constraints)

    # --- Database / connection management ---------------------------------

    def drop_database(self) -> None:
        # Destructive: delegates straight to the client's drop_database.
        self._client.drop_database()

    def new_connection(self) -> "gqlalchemy.Connection":
        return self._client.new_connection()

    def get_variable_assume_one(self, query_result: Iterator[Dict[str, Any]], variable_name: str) -> Any:
        """Extract a single named variable from a query result, assuming exactly one row."""
        return self._client.get_variable_assume_one(query_result, variable_name)

    # --- Node operations ---------------------------------------------------

    def create_node(self, node: Context) -> Optional[Context]:
        """
        Create `node` in the graph; returns the created node as a Context,
        or None (implicitly) when the client returns nothing.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        # NOTE(review): every other conversion in this class calls
        # `to_gqlalchemy(self)`; this one passes no argument — confirm intentional.
        node = node.to_gqlalchemy()
        created_node = self._client.create_node(node)
        if created_node is not None:
            return Context.from_gqlalchemy(created_node)

    def save_node(self, node: Context) -> "gqlalchemy.Node":
        """
        Save (upsert) `node` and return the resulting gqlalchemy node.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.save_node(node)
        return result

    def save_nodes(self, nodes: List[Context]) -> None:
        """
        Save (upsert) several nodes in one client call.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        nodes = [node.to_gqlalchemy(self) for node in nodes]
        self._client.save_nodes(nodes)

    def save_node_with_id(self, node: Context) -> Optional["gqlalchemy.Node"]:
        """
        Save `node` matching on its id.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.save_node_with_id(node)
        return result

    def load_node(self, node: Context) -> Optional["gqlalchemy.Node"]:
        """
        Load the graph node matching `node`.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.load_node(node)
        return result

    def load_node_with_all_properties(self, node: Context) -> Optional["gqlalchemy.Node"]:
        """
        Load the graph node matching all of `node`'s properties.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.load_node_with_all_properties(node)
        return result

    def load_node_with_id(self, node: Context) -> Optional["gqlalchemy.Node"]:
        """
        Load the graph node matching `node`'s id.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        node = node.to_gqlalchemy(self)
        result = self._client.load_node_with_id(node)
        return result

    # --- Relationship operations -------------------------------------------

    def load_relationship(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Load the graph relationship matching `relationship`.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.load_relationship(relationship)
        return result

    def load_relationship_with_id(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Load the graph relationship matching `relationship`'s id.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.load_relationship_with_id(relationship)
        return result

    def load_relationship_with_start_node_id_and_end_node_id(
        self, relationship: Relation
    ) -> Optional["gqlalchemy.Relationship"]:
        """
        Load the relationship matching its endpoints' node ids.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.load_relationship_with_start_node_id_and_end_node_id(relationship)
        return result

    def save_relationship(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Save (upsert) `relationship` and return the resulting gqlalchemy relationship.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)

        result = self._client.save_relationship(relationship)
        return result

    def save_relationships(self, relationships: List[Relation]) -> None:
        """
        Save (upsert) several relationships in one client call.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationships = [relationship.to_gqlalchemy(self) for relationship in relationships]
        self._client.save_relationships(relationships)

    def save_relationship_with_id(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Save `relationship` matching on its id.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.save_relationship_with_id(relationship)
        return result

    def create_relationship(self, relationship: Relation) -> Optional["gqlalchemy.Relationship"]:
        """
        Create `relationship` in the graph.

        <Danger>This method is not thread safe. We are actively working on a solution to make it thread safe.</Danger>
        """
        relationship = relationship.to_gqlalchemy(self)
        result = self._client.create_relationship(relationship)
        return result
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# Neo4j connection defaults, read from the environment once at import time
# (load_dotenv() above has already populated os.environ from any .env file).
NEO4J_HOST = os.getenv("NEO4J_HOST", "localhost")
NEO4J_PORT = int(os.getenv("NEO4J_PORT", "7687"))
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "test")
# NOTE: the env var is NEO4J_ENCRYPT (no "-ED"), unlike the constant's name.
NEO4J_ENCRYPTED = os.getenv("NEO4J_ENCRYPT", "false").lower() == "true"
NEO4J_CLIENT_NAME = os.getenv("NEO4J_CLIENT_NAME", "neo4j")
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def load_neo4j() -> Optional[Type[Any]]:
    """Return the ``gqlalchemy.Neo4j`` client class, or ``None`` when gqlalchemy is not installed."""
    try:
        import gqlalchemy
    except ImportError:
        # Optional dependency missing; Neo4jConfig._client_class stays None.
        return None
    else:
        return gqlalchemy.Neo4j
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def load_memgraph() -> Optional[Type[Any]]:
    """Return the ``gqlalchemy.Memgraph`` client class, or ``None`` when gqlalchemy is not installed."""
    try:
        import gqlalchemy
    except ImportError:
        # Optional dependency missing; MemgraphConfig._client_class stays None.
        return None
    else:
        return gqlalchemy.Memgraph
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@dataclass
class Neo4jConfig:
    """Connection settings for a Neo4j backend.

    Field defaults come from the NEO4J_* module constants (environment-driven,
    resolved at import time). Instances are expanded with ``asdict`` into the
    gqlalchemy client constructor by ``GraphDatabase.__init__``.
    """

    host: str = NEO4J_HOST
    port: int = NEO4J_PORT
    username: str = NEO4J_USERNAME
    password: str = NEO4J_PASSWORD
    # BUG FIX: these two defaults were 1-tuples — `(NEO4J_ENCRYPTED,)` and
    # `(NEO4J_CLIENT_NAME,)` — due to stray trailing commas, so the gqlalchemy
    # client received a tuple instead of a bool/str.
    encrypted: bool = NEO4J_ENCRYPTED
    client_name: str = NEO4J_CLIENT_NAME

    # None when gqlalchemy is not installed (checked in GraphDatabase.__init__).
    _client_class: ClassVar[Optional[Type[Any]]] = load_neo4j()
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
# Memgraph connection defaults, read from the environment once at import time.
MG_HOST = os.getenv("MG_HOST", "127.0.0.1")
MG_PORT = int(os.getenv("MG_PORT", "7687"))
MG_USERNAME = os.getenv("MG_USERNAME", "")
MG_PASSWORD = os.getenv("MG_PASSWORD", "")
# NOTE: the env var is MG_ENCRYPT (no "-ED"), unlike the constant's name.
MG_ENCRYPTED = os.getenv("MG_ENCRYPT", "false").lower() == "true"
MG_CLIENT_NAME = os.getenv("MG_CLIENT_NAME", "GQLAlchemy")
MG_LAZY = os.getenv("MG_LAZY", "false").lower() == "true"
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@dataclass
class MemgraphConfig:
    """Connection settings for a Memgraph backend.

    Field defaults come from the MG_* module constants (environment-driven,
    resolved at import time). Instances are expanded with ``asdict`` into the
    gqlalchemy client constructor by ``GraphDatabase.__init__``.
    """

    host: str = MG_HOST
    port: int = MG_PORT
    username: str = MG_USERNAME
    password: str = MG_PASSWORD
    encrypted: bool = MG_ENCRYPTED
    client_name: str = MG_CLIENT_NAME
    lazy: bool = MG_LAZY

    # None when gqlalchemy is not installed (checked in GraphDatabase.__init__).
    _client_class: ClassVar[Optional[Type[Any]]] = load_memgraph()
|
|
@@ -0,0 +1,355 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple
|
|
5
|
+
from urllib.parse import urlencode
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
from sqlalchemy import Column, CursorResult, MetaData, String, Text, create_engine, inspect, text
|
|
9
|
+
from sqlalchemy import Table as SQLTable
|
|
10
|
+
from sqlalchemy.dialects import sqlite
|
|
11
|
+
from sqlalchemy.orm import sessionmaker
|
|
12
|
+
from sqlalchemy.sql.compiler import IdentifierPreparer
|
|
13
|
+
from tqdm import tqdm
|
|
14
|
+
|
|
15
|
+
from ..context.table import BaseTable, Table, TableFile
|
|
16
|
+
from ..storage import FileStore
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class SQLDatabaseBackend:
    """
    Base class for SQL database backends.
    Each subclass must implement the `url` property.
    """

    @property
    def url(self) -> str:
        # Abstract in spirit: subclasses build a SQLAlchemy connection URL here.
        raise NotImplementedError("Subclasses must implement this method")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class SQLServerBackend(SQLDatabaseBackend):
    """
    Backend configuration for a SQL served over a port or remote connection. (MySQL, PostgreSQL, etc.)

    Args:
        user: The username to connect to the database.
        password: The password to connect to the database.
        host: The host of the database.
        database: The name of the database.
        port: The port of the database.
    """

    user: str
    password: str
    host: str
    database: str
    port: Optional[str] = None
    dialect: str = "mysql"
    driver: Optional[str] = None
    query_params: Optional[dict] = None

    @property
    def url(self) -> str:
        """Assemble the SQLAlchemy URL: dialect[+driver]://user:password@host[:port]/database[?params]."""
        pieces = [self.dialect]
        if self.driver:
            pieces.append(f"+{self.driver}")
        pieces.append(f"://{self.user}:{self.password}@{self.host}")
        if self.port:
            pieces.append(f":{self.port}")
        pieces.append(f"/{self.database}")
        if self.query_params:
            pieces.append(f"?{urlencode(self.query_params)}")
        return "".join(pieces)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
class SQLiteBackend(SQLDatabaseBackend):
    """
    Backend configuration for a SQLite database.

    Args:
        db_path: Path to the SQLite database file.
        in_memory: Whether to create an in-memory SQLite database.
        query_params: Query parameters to pass to the database.
    """

    db_path: Optional[str] = None
    in_memory: bool = False
    query_params: Optional[dict] = None

    @property
    def url(self) -> str:
        """Assemble the sqlite:/// URL, honoring in-memory mode and optional query params."""
        if self.in_memory:
            base = "sqlite:///:memory:"
        else:
            base = f"sqlite:///{self.db_path}"
        if not self.query_params:
            return base
        return f"{base}?{urlencode(self.query_params)}"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class SQLDatabase:
    """
    SQL database wrapper around a SQLAlchemy engine that stores `BaseTable`
    objects alongside a `metadata` side-table mapping each table name to its
    JSON-serialized metadata.

    Connection model: each operation opens a short-lived connection via
    `connect()` unless a persistent one exists (see
    `open_persistent_connection()` / `connect_and_begin()`). `_in_transaction`
    suppresses per-operation commits while an explicit transaction is active.
    """

    def __init__(
        self,
        backend: SQLDatabaseBackend | str,
        engine_kwargs: dict = None,  # TODO: This may not be a smart idea, may want to enforce specific kwargs
        session_kwargs: dict = None,  # TODO: This may not be a smart idea, may want to enforce specific kwargs
    ):
        """
        Args:
            backend: Backend config (provides `.url`) or a raw SQLAlchemy URL string.
            engine_kwargs: Extra kwargs forwarded to `create_engine`.
            session_kwargs: Extra kwargs forwarded to `sessionmaker`.
        """
        self.url = backend.url if isinstance(backend, SQLDatabaseBackend) else backend
        self.engine = create_engine(self.url, **(engine_kwargs or {}))
        self.metadata = MetaData()
        self.session = sessionmaker(bind=self.engine, **(session_kwargs or {}))
        self.inspector = inspect(self.engine)
        # NOTE(review): identifiers are always quoted with the *sqlite* dialect's
        # preparer, even for non-sqlite backends — confirm this is intentional.
        self.preparer = IdentifierPreparer(sqlite.dialect())

        # Create metadata table to store table metadata
        self.metadata_table = SQLTable(
            "metadata",
            self.metadata,
            Column("table_name", String(255), primary_key=True),
            Column("metadata_json", Text),
        )
        self.metadata.create_all(self.engine)
        self.connection = None  # persistent connection, when opened
        self._in_transaction = False

    def add_table(
        self,
        table: BaseTable,
        if_exists: Literal["fail", "replace", "append"] = "replace",
        schema: str = None,
    ):
        """
        Write `table`'s dataframe to SQL and upsert its metadata row.

        Args:
            table: Table to write; its `.name`, `._df` and `.metadata` are used.
            if_exists: Passed to `DataFrame.to_sql`.
            schema: NOTE(review) — currently unused (not forwarded to `to_sql`).
        """
        # TODO: support batch inserting for large dataframes
        with self.connect() as connection:
            # Use the connection for to_sql to respect transaction context
            table._df.to_sql(table.name, connection, if_exists=if_exists, index=False)

            # Remove existing metadata for this table if it exists
            connection.execute(self.metadata_table.delete().where(self.metadata_table.c.table_name == table.name))

            # Insert new metadata
            connection.execute(
                self.metadata_table.insert().values(
                    table_name=table.name,
                    metadata_json=json.dumps(table.metadata),
                )
            )
            if self._should_commit():
                connection.commit()

    def add_tables(
        self,
        tables: Iterable[BaseTable],
        if_exists: Literal["fail", "replace", "append"] = "replace",
        schema: str = None,
    ):
        """Add each table in `tables` via `add_table` (same arguments)."""
        for table in tables:
            self.add_table(table, if_exists, schema)

    def drop_table(self, name: str, must_exist: bool = False):
        """
        Drop a table from the database and remove its metadata.

        Args:
            name: The name of the table to drop
            must_exist: When False, `DROP TABLE IF EXISTS` is used so a missing
                table is not an error.
        """
        if_exists = "IF EXISTS" if not must_exist else ""
        # Quote the identifier — table names cannot be bound parameters.
        safe_name = self.preparer.quote(name)
        with self.connect() as connection:
            command = text(f"DROP TABLE {if_exists} {safe_name}")
            connection.execute(command)
            # Also remove metadata for this table
            connection.execute(self.metadata_table.delete().where(self.metadata_table.c.table_name == name))
            if self._should_commit():
                connection.commit()

    def drop_tables(self, names: Iterable[str], must_exist: bool = False):
        """Drop each table in `names` via `drop_table`."""
        for name in names:
            self.drop_table(name, must_exist)

    def list_tables(self) -> List[str]:
        """
        List all tables currently in the database.

        Returns:
            List of table names in the database.
        """
        # Refresh the inspector to ensure we get current table list
        self.inspector = inspect(self.engine)
        return self.inspector.get_table_names()

    def get_table(self, name: str) -> BaseTable:
        """Load table `name` (data + stored metadata) as a `Table`."""
        df = pd.read_sql_table(name, self.engine)

        return Table(df=df, name=name, metadata=self.get_table_metadata(name))

    def get_table_schema(self, name: str) -> List[Column]:
        """
        Return column schema for a given table.

        Args:
            name: The name of the table to get schema for

        Returns:
            Column schema information for the table.
        """
        return self.inspector.get_columns(name)

    def get_table_metadata(self, name: str) -> dict:
        """
        Get metadata for a specific table.

        Args:
            name: The name of the table to get metadata for

        Returns:
            Dictionary containing the table's metadata, or empty dict if not found.
        """
        with self.connect() as connection:
            result = connection.execute(
                self.metadata_table.select().where(self.metadata_table.c.table_name == name)
            ).fetchone()

            if result:
                return json.loads(result.metadata_json)
        return {}

    def query(self, query: str) -> CursorResult:
        """
        Execute raw SQL and return the CursorResult.

        Warning:
            When no persistent connection is open, the temporary connection is
            closed on return, so the returned result may no longer be
            consumable. Prefer `fetchall`/`fetchone`, or call
            `open_persistent_connection()` first.
        """
        with self.connect() as connection:
            result = connection.execute(text(query))
            return result

    def fetchall(self, query: str) -> List[Tuple]:
        """
        Execute raw SQL and return all rows.

        Fix: rows are fetched while the connection is still open — previously
        this fetched from `query()`'s result after its temporary connection
        had been closed, which fails on a closed CursorResult.
        """
        with self.connect() as connection:
            return connection.execute(text(query)).fetchall()

    def fetchone(self, query: str) -> Tuple:
        """
        Execute raw SQL and return the first row.

        Fix: the row is fetched while the connection is still open (see
        `fetchall`).
        """
        with self.connect() as connection:
            return connection.execute(text(query)).fetchone()

    @classmethod
    def from_file_store(
        cls,
        file_store: FileStore,
        backend: SQLDatabaseBackend,
        folder: Optional[str] = None,
        table_created_hook: Optional[Callable[[TableFile], Any]] = None,
    ) -> "SQLDatabase":
        # TODO: support batch inserting and parallel processing
        """
        Initializes a new SQLDatabase from a file store.

        Args:
            file_store: File store containing files to load
            backend: SQL database backend
            folder: Folder in the file store to load
            table_created_hook: Optional callback invoked with each TableFile
                after it is added.

        Returns:
            New SQLDatabase instance loaded with data from the file store.
        """
        # TODO: make sure the loaded sql database is empty if not raise error and tell user to use __init__ for an already existing database
        instance = cls(backend)
        instance.add_file_store(file_store, folder, table_created_hook)
        return instance

    def add_file_store(
        self,
        file_store: FileStore,
        folder: Optional[str] = None,
        table_created_hook: Optional[Callable[[TableFile], Any]] = None,
    ):
        """
        Load every file in `file_store` (optionally restricted to `folder`) as
        a table, inside a single transaction.

        Fix: previously used `self.begin()`, which raises RuntimeError when no
        connection is open — and `from_file_store` calls this on a freshly
        constructed instance. `connect_and_begin()` opens a connection if
        needed and closes it afterwards.
        """
        with self.connect_and_begin():
            for key, _ in tqdm(file_store.items(folder), desc="Uploading files to SQL database"):
                table = TableFile.from_file_store(key, file_store)
                self.add_table(table, if_exists="fail")
                if table_created_hook:
                    table_created_hook(table)

    @contextmanager
    def connect(self):
        """
        Context manager for database connections.
        Reuses existing connection if available, otherwise creates a temporary one.
        """
        connection_existed = self.connection is not None
        if not connection_existed:
            self.connection = self.engine.connect()

        try:
            yield self.connection
        finally:
            # Only close if we created the connection for this operation
            if not connection_existed:
                self.close()

    def open_persistent_connection(self):
        """
        Opens a persistent connection that will be reused across operations.
        Call close() to close the persistent connection.
        """
        if self.connection is None:
            self.connection = self.engine.connect()

    def close(self):
        """
        Closes the current connection if one exists.
        """
        if self.connection:
            self.connection.close()
            self.connection = None

    def _should_commit(self) -> bool:
        """
        Returns True if operations should commit immediately.
        Returns False if we're within an explicit transaction context.
        """
        return not self._in_transaction

    @contextmanager
    def begin(self):
        """
        Context manager for database transactions using existing connection.
        Requires an active connection. Commits on success, rolls back on exception.

        Raises:
            RuntimeError: If no active connection exists
        """
        if self.connection is None:
            raise RuntimeError("No active connection. Use connect_and_begin() or open a connection first.")

        transaction = self.connection.begin()
        # Save/restore the flag so nested transaction contexts unwind correctly.
        old_in_transaction = self._in_transaction
        self._in_transaction = True

        try:
            yield self.connection
            transaction.commit()
        except Exception:
            transaction.rollback()
            raise
        finally:
            self._in_transaction = old_in_transaction

    @contextmanager
    def connect_and_begin(self):
        """
        Context manager that establishes a connection and starts a transaction.
        Reuses existing connection if available, otherwise creates a temporary one.
        Commits on success, rolls back on exception.
        """
        connection_existed = self.connection is not None
        if not connection_existed:
            self.connection = self.engine.connect()

        transaction = self.connection.begin()
        old_in_transaction = self._in_transaction
        self._in_transaction = True

        try:
            yield self.connection
            transaction.commit()
        except Exception:
            transaction.rollback()
            raise
        finally:
            self._in_transaction = old_in_transaction
            # Only close if we created the connection for this operation
            if not connection_existed:
                self.close()
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
class MultiTenantSQLDatabase:
    """Placeholder for a multi-tenant variant of SQLDatabase; not yet implemented."""

    def __init__(self):
        # Intentionally unusable until implemented.
        raise NotImplementedError("Not implemented")
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .vector_database import IndexConfig, IndexType, Metric, SupportsHybridSearch, VectorDatabase, VectorType
|
|
2
|
+
from .vendors.milvus import MilvusBackend
|
|
3
|
+
|
|
4
|
+
# Public API of modaic.databases.vector_database — one entry per name
# imported above.
__all__ = [
    "VectorDatabase",
    "SupportsHybridSearch",
    "MilvusBackend",
    "IndexConfig",
    "IndexType",
    "VectorType",
    "Metric",
]
|