datus-sqlalchemy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,627 @@
|
|
|
1
|
+
# Copyright 2025-present DatusAI, Inc.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0.
|
|
3
|
+
# See http://www.apache.org/licenses/LICENSE-2.0 for details.
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, override
|
|
6
|
+
|
|
7
|
+
from datus.schemas.base import TABLE_TYPE
|
|
8
|
+
from datus.schemas.node_models import ExecuteSQLResult
|
|
9
|
+
from datus.tools.db_tools.base import BaseSqlConnector
|
|
10
|
+
from datus.tools.db_tools.config import ConnectionConfig
|
|
11
|
+
from datus.utils.constants import DBType, SQLType
|
|
12
|
+
from datus.utils.exceptions import DatusException, ErrorCode
|
|
13
|
+
from datus.utils.loggings import get_logger
|
|
14
|
+
from datus.utils.sql_utils import parse_context_switch, parse_sql_type
|
|
15
|
+
from pandas import DataFrame
|
|
16
|
+
from pyarrow import Table
|
|
17
|
+
from sqlalchemy import create_engine, inspect, text
|
|
18
|
+
from sqlalchemy.engine import Inspector
|
|
19
|
+
from sqlalchemy.exc import (
|
|
20
|
+
DatabaseError,
|
|
21
|
+
DataError,
|
|
22
|
+
IntegrityError,
|
|
23
|
+
InterfaceError,
|
|
24
|
+
InternalError,
|
|
25
|
+
NoSuchTableError,
|
|
26
|
+
NotSupportedError,
|
|
27
|
+
OperationalError,
|
|
28
|
+
ProgrammingError,
|
|
29
|
+
SQLAlchemyError,
|
|
30
|
+
TimeoutError,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Module-level logger used by SQLAlchemyConnector for connection/cleanup warnings and debug output.
logger = get_logger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class SQLAlchemyConnector(BaseSqlConnector):
    """
    Base SQLAlchemy connector for database adapters.
    Provides common SQLAlchemy functionality with Arrow support.

    Each instance holds one engine and one long-lived ``Connection`` (created
    lazily by ``connect()``), maps SQLAlchemy/DBAPI errors to
    ``DatusException`` and exposes query, DML/DDL, metadata and streaming
    helpers on top of that connection.
    """

    def __init__(self, connection_string: str, dialect: str = "", timeout_seconds: int = 30):
        """
        Initialize SQLAlchemyConnector.

        The engine and connection are not created here; ``connect()`` creates
        them on first use.

        Args:
            connection_string: SQLAlchemy connection string
            dialect: Database dialect (mysql, postgresql, etc.). When empty it is
                derived from the URL scheme: ``mysql+pymysql`` maps to
                ``DBType.MYSQL``; for any other scheme the ``+driver`` suffix is
                dropped (``postgresql+psycopg2`` -> ``postgresql``).
            timeout_seconds: Connection timeout in seconds
        """
        # Auto-detect dialect from connection string if not provided
        if not dialect:
            prefix = connection_string.split(":", 1)[0] if isinstance(connection_string, str) else "unknown"
            # Strip the driver part so the value lines up with DBType comparisons
            # (e.g. "sqlite+pysqlite" now hits the SQLite branch in connect()).
            dialect = DBType.MYSQL if prefix == "mysql+pymysql" else prefix.split("+", 1)[0]

        config = ConnectionConfig(timeout_seconds=timeout_seconds)
        super().__init__(config, dialect)
        self.connection_string = connection_string
        self.engine = None
        self.connection = None
        # True only after connect() itself created the engine/connection pair.
        self._owns_engine = False

    def __del__(self):
        """Destructor to ensure connections are properly closed."""
        try:
            self.close()
        except Exception:
            # A finalizer must never propagate cleanup errors.
            pass

    # ==================== Connection Management ====================

    def _connection_is_open(self) -> bool:
        """Return True when a connection object exists and is neither closed nor invalidated."""
        conn = getattr(self, "connection", None)
        if conn is None:
            return False
        # ``is True`` keeps objects without a real boolean flag (proxies, stand-ins) counted as open.
        return getattr(conn, "closed", False) is not True and getattr(conn, "invalidated", False) is not True

    @override
    def connect(self):
        """
        Establish connection to the database.

        An existing connection created by this instance is reused while it is
        still open; a closed or invalidated one is discarded and rebuilt.

        Raises:
            DatusException: when the engine or the connection cannot be created.
        """
        if self.engine and self._owns_engine and self._connection_is_open():
            return

        try:
            self._safe_close()

            # Create engine based on dialect
            if self.dialect not in (DBType.DUCKDB, DBType.SQLITE):
                # Server databases: small bounded pool with recycling of stale connections.
                self.engine = create_engine(
                    self.connection_string,
                    pool_size=3,
                    max_overflow=5,
                    pool_timeout=self.timeout_seconds,
                    pool_recycle=3600,
                )
            else:
                # Embedded databases: use SQLAlchemy's default pool for the dialect.
                self.engine = create_engine(self.connection_string)

            self.connection = self.engine.connect()
            self._owns_engine = True

        except Exception as e:
            self._force_reset()
            raise self._handle_exception(e, "", "connection") from e

        if not (self.engine and self.connection):
            self._force_reset()
            raise DatusException(
                ErrorCode.DB_CONNECTION_FAILED, message_args={"error_message": "Failed to establish connection"}
            )

    @override
    def close(self):
        """
        Close the database connection and dispose of the engine.

        Both steps are attempted independently, so a failure while closing the
        connection still disposes the engine. Instance state is always cleared,
        and errors are logged as warnings rather than raised.
        """
        # getattr: close() can run from __del__ on a partially initialised object.
        connection = getattr(self, "connection", None)
        engine = getattr(self, "engine", None)
        self.connection = None
        self.engine = None
        self._owns_engine = False

        if connection is not None:
            try:
                connection.close()
            except Exception as e:
                logger.warning(f"Error closing connection: {str(e)}")
        if engine is not None:
            try:
                engine.dispose()
            except Exception as e:
                logger.warning(f"Error disposing engine: {str(e)}")

    def _safe_close(self):
        """Safely close connection, ignoring errors."""
        try:
            self.close()
        except Exception:
            pass

    def _force_reset(self):
        """Force reset connection on error: roll back, close, dispose and clear all state."""
        try:
            self._safe_rollback()
            if self.connection:
                try:
                    self.connection.close()
                except Exception:
                    pass
                self.connection = None
            if self.engine:
                try:
                    self.engine.dispose()
                except Exception:
                    pass
                self.engine = None
            self._owns_engine = False
        except Exception:
            self.connection = None
            self.engine = None
            self._owns_engine = False

    def _safe_rollback(self):
        """Safely rollback the current transaction, ignoring errors."""
        if getattr(self, "connection", None):
            try:
                self.connection.rollback()
            except Exception:
                pass

    # ==================== Error Handling ====================

    def _handle_exception(self, e: Exception, sql: str = "", operation: str = "SQL execution") -> DatusException:
        """
        Map SQLAlchemy exceptions to Datus exceptions.

        Args:
            e: The exception raised by SQLAlchemy / the DBAPI driver (a
                ``DatusException`` is returned unchanged).
            sql: SQL text included in the error arguments.
            operation: Operation label used for permission errors.

        Returns:
            The ``DatusException`` to raise or report; it is not raised here.
        """
        if isinstance(e, DatusException):
            return e

        # Extract error message: extra detail lines, then the driver error, then the exception itself.
        detail = getattr(e, "detail", None)
        orig = getattr(e, "orig", None)
        if detail:
            error_message = "\n".join(str(d) for d in detail) if isinstance(detail, list) else str(detail)
        elif orig is not None:
            error_message = str(orig)
        else:
            error_message = str(e)

        message_args = {"error_message": error_message, "sql": sql}
        error_msg_lower = error_message.lower()

        # Syntax errors (drivers report these under several exception classes).
        if any(kw in error_msg_lower for kw in ("syntax", "parse error", "sql error")):
            return DatusException(ErrorCode.DB_EXECUTION_SYNTAX_ERROR, message_args=message_args)

        # Table not found
        if isinstance(e, NoSuchTableError):
            return DatusException(ErrorCode.DB_TABLE_NOT_EXISTS, message_args={"table_name": str(e)})

        # Invalid/pending transaction. SQLAlchemy raises "Can't reconnect until invalid transaction is
        # rolled back" as PendingRollbackError (an InvalidRequestError, not an OperationalError), so
        # this is matched for every exception type. The connection cannot be reused without a reset.
        if any(kw in error_msg_lower for kw in ("invalid transaction", "can't reconnect")):
            logger.warning("Invalid transaction state detected, resetting connection")
            self._force_reset()
            return DatusException(ErrorCode.DB_TRANSACTION_FAILED, message_args=message_args)

        # Connection and operational errors
        if isinstance(e, (OperationalError, InterfaceError)):
            if any(kw in error_msg_lower for kw in ("timeout", "timed out")):
                return DatusException(ErrorCode.DB_CONNECTION_TIMEOUT, message_args=message_args)
            if any(kw in error_msg_lower for kw in ("authentication", "access denied", "login failed")):
                return DatusException(ErrorCode.DB_AUTHENTICATION_FAILED, message_args=message_args)
            if any(kw in error_msg_lower for kw in ("permission denied", "insufficient privilege")):
                message_args["operation"] = operation
                return DatusException(ErrorCode.DB_PERMISSION_DENIED, message_args=message_args)
            if any(kw in error_msg_lower for kw in ("connection refused", "connection failed", "unable to open")):
                return DatusException(ErrorCode.DB_CONNECTION_FAILED, message_args=message_args)
            return DatusException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)

        # Programming errors (syntax variants were already handled above)
        if isinstance(e, ProgrammingError):
            return DatusException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)

        # Constraint violations
        if isinstance(e, IntegrityError):
            return DatusException(ErrorCode.DB_CONSTRAINT_VIOLATION, message_args=message_args)

        # sqlalchemy.exc.TimeoutError is raised when the connection pool times out handing out a
        # connection, i.e. a connection timeout rather than a statement-execution timeout.
        if isinstance(e, TimeoutError):
            return DatusException(ErrorCode.DB_CONNECTION_TIMEOUT, message_args=message_args)

        # Other database errors
        if isinstance(e, (DatabaseError, DataError, InternalError, NotSupportedError)):
            return DatusException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)

        # Fallback
        return DatusException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)

    # ==================== Core Execute Methods ====================

    @override
    def execute_query(
        self, sql: str, result_format: Literal["csv", "arrow", "pandas", "list"] = "csv"
    ) -> ExecuteSQLResult:
        """
        Execute SELECT query.

        Args:
            sql: Read-only SQL statement.
            result_format: "csv" (string), "arrow" (pyarrow Table), "pandas"
                (DataFrame) or "list" (list of row dicts).

        Returns:
            ExecuteSQLResult; failures are reported via ``success=False``.
        """
        try:
            self.connect()
            result = self._execute_query(sql)
            row_count = len(result)

            # Format result based on requested format
            if result_format == "csv":
                result = DataFrame(result).to_csv(index=False)
            elif result_format == "arrow":
                result = Table.from_pylist(result)
            elif result_format == "pandas":
                result = DataFrame(result)

            return ExecuteSQLResult(
                success=True, sql_query=sql, sql_return=result, row_count=row_count, result_format=result_format
            )
        except Exception as e:
            ex = self._handle_exception(e, sql)
            return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql)

    def _execute_query(self, sql: str) -> List[Dict[str, Any]]:
        """
        Internal query execution returning list of dicts.

        Raises:
            DatusException: for DML/context-switch/unknown statements, or when
                execution fails.
        """
        if parse_sql_type(sql, self.dialect) in (
            SQLType.INSERT,
            SQLType.UPDATE,
            SQLType.DELETE,
            SQLType.MERGE,
            SQLType.CONTENT_SET,
            SQLType.UNKNOWN,
        ):
            raise DatusException(ErrorCode.DB_EXECUTION_ERROR, message="Only SELECT and metadata queries are supported")

        self.connect()
        try:
            result = self.connection.execute(text(sql))
            rows = result.fetchall()
            return [row._asdict() for row in rows]
        except DatusException:
            raise
        except Exception as e:
            # The autobegun transaction stays open after a failure; without a rollback
            # some backends (e.g. PostgreSQL) reject every following statement.
            self._safe_rollback()
            raise self._handle_exception(e, sql, "query") from e

    @staticmethod
    def _insert_return_value(result: Any) -> Any:
        """Best-effort INSERT result: inserted primary key, else lastrowid, else rowcount."""
        inserted_pk = None
        try:
            if hasattr(result, "inserted_primary_key") and result.inserted_primary_key:
                inserted_pk = result.inserted_primary_key
        except Exception:
            # May raise for textual (non insert() construct) statements; fall back below.
            pass
        lastrowid = getattr(result, "lastrowid", None)
        return inserted_pk if inserted_pk else (lastrowid if lastrowid else result.rowcount)

    @override
    def execute_insert(self, sql: str) -> ExecuteSQLResult:
        """Execute INSERT statement and commit; rolls back on failure."""
        try:
            self.connect()
            res = self.connection.execute(text(sql))
            self.connection.commit()

            return_value = self._insert_return_value(res)
            return ExecuteSQLResult(success=True, sql_query=sql, sql_return=str(return_value), row_count=res.rowcount)
        except Exception as e:
            self._safe_rollback()
            ex = self._handle_exception(e, sql)
            return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql, sql_return="", row_count=0)

    @override
    def execute_update(self, sql: str) -> ExecuteSQLResult:
        """Execute UPDATE statement and commit; rolls back on failure."""
        try:
            self.connect()
            res = self.connection.execute(text(sql))
            self.connection.commit()
            return ExecuteSQLResult(success=True, sql_query=sql, sql_return=str(res.rowcount), row_count=res.rowcount)
        except Exception as e:
            self._safe_rollback()
            ex = self._handle_exception(e, sql)
            return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql, sql_return="", row_count=0)

    @override
    def execute_delete(self, sql: str) -> ExecuteSQLResult:
        """Execute DELETE statement and commit; rolls back on failure."""
        try:
            self.connect()
            res = self.connection.execute(text(sql))
            self.connection.commit()
            return ExecuteSQLResult(success=True, sql_query=sql, sql_return=str(res.rowcount), row_count=res.rowcount)
        except Exception as e:
            self._safe_rollback()
            ex = self._handle_exception(e, sql)
            return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql, sql_return="", row_count=0)

    @override
    def execute_ddl(self, sql: str) -> ExecuteSQLResult:
        """Execute DDL statement (CREATE, ALTER, DROP, etc.) and commit; rolls back on failure."""
        try:
            self.connect()
            res = self.connection.execute(text(sql))
            self.connection.commit()
            return ExecuteSQLResult(success=True, sql_query=sql, sql_return=str(res.rowcount), row_count=res.rowcount)
        except Exception as e:
            self._safe_rollback()
            ex = self._handle_exception(e, sql)
            return ExecuteSQLResult(success=False, sql_query=sql, error=str(ex))

    def execute_pandas(self, sql: str) -> ExecuteSQLResult:
        """Execute query and return pandas DataFrame."""
        try:
            df = self._execute_pandas(sql)
            return ExecuteSQLResult(
                success=True, sql_query=sql, sql_return=df, row_count=len(df), result_format="pandas"
            )
        except Exception as e:
            ex = self._handle_exception(e, sql)
            return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql)

    def _execute_pandas(self, sql: str) -> DataFrame:
        """Internal pandas execution (raises DatusException on failure)."""
        return DataFrame(self._execute_query(sql))

    def execute_csv(self, sql: str) -> ExecuteSQLResult:
        """Execute query and return CSV format."""
        try:
            self.connect()
            df = self._execute_pandas(sql)
            return ExecuteSQLResult(
                success=True, sql_query=sql, sql_return=df.to_csv(index=False), row_count=len(df), result_format="csv"
            )
        except Exception as e:
            ex = self._handle_exception(e, sql)
            return ExecuteSQLResult(
                success=False, sql_query=sql, sql_return="", row_count=0, error=str(ex), result_format="csv"
            )

    def execute_arrow(self, sql: str) -> ExecuteSQLResult:
        """
        Execute query and return Arrow table.

        Statements that return no rows yield their rowcount as ``sql_return``.
        """
        try:
            self.connect()
            result = self.connection.execute(text(sql))
            if result.returns_rows:
                df = DataFrame(result.fetchall(), columns=list(result.keys()))
                table = Table.from_pandas(df)
                return ExecuteSQLResult(
                    success=True, sql_query=sql, sql_return=table, row_count=len(df), result_format="arrow"
                )
            return ExecuteSQLResult(
                success=True, sql_query=sql, sql_return=result.rowcount, row_count=0, result_format="arrow"
            )
        except Exception as e:
            # Leave the connection usable for the next statement.
            self._safe_rollback()
            ex = self._handle_exception(e, sql)
            return ExecuteSQLResult(
                success=False, error=str(ex), sql_query=sql, sql_return="", row_count=0, result_format="arrow"
            )

    @override
    def execute_content_set(self, sql: str) -> ExecuteSQLResult:
        """
        Execute USE/SET commands.

        On success (non-SQLite dialects) the parsed catalog/database/schema
        context is stored on the connector.
        """
        self.connect()
        try:
            self.connection.execute(text(sql))
            self.connection.commit()

            # Update context if applicable
            if self.dialect != DBType.SQLITE.value:
                context = parse_context_switch(sql=sql, dialect=self.dialect)
                if context:
                    if catalog := context.get("catalog_name"):
                        self.catalog_name = catalog
                    if database := context.get("database_name"):
                        self.database_name = database
                    if schema := context.get("schema_name"):
                        self.schema_name = schema

            return ExecuteSQLResult(success=True, sql_query=sql, sql_return="Successful", row_count=0)
        except Exception as e:
            self._safe_rollback()
            ex = self._handle_exception(e, sql)
            return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql)

    def execute_queries(self, queries: List[str]) -> List[Any]:
        """
        Execute multiple queries in one transaction and commit once at the end.

        Returns:
            One entry per query: a list of row dicts for row-returning queries,
            the INSERT return value, the UPDATE/DELETE rowcount, or None.

        Raises:
            DatusException: on a SQLAlchemy error (the batch is rolled back).
        """
        results = []
        self.connect()
        try:
            for query in queries:
                result = self.connection.execute(text(query))
                if result.returns_rows:
                    df = DataFrame(result.fetchall(), columns=list(result.keys()))
                    results.append(df.to_dict(orient="records"))
                    continue
                query_lower = query.strip().lower()
                if query_lower.startswith("insert"):
                    results.append(self._insert_return_value(result))
                elif query_lower.startswith(("update", "delete")):
                    results.append(result.rowcount)
                else:
                    results.append(None)
            self.connection.commit()
        except SQLAlchemyError as e:
            self._safe_rollback()
            raise self._handle_exception(e, "\n".join(queries), "batch query") from e
        return results

    def test_connection(self) -> bool:
        """
        Test database connection by running ``SELECT 1``.

        The connection is always closed afterwards.

        Raises:
            DatusException: when the connection or the probe query fails.
        """
        self.connect()
        try:
            self._execute_query("SELECT 1")
            return True
        except DatusException:
            raise
        except Exception as e:
            raise DatusException(
                ErrorCode.DB_CONNECTION_FAILED, message_args={"error_message": "Connection test failed"}
            ) from e
        finally:
            self._safe_close()

    # ==================== Metadata Methods ====================

    def _inspector(self) -> Inspector:
        """Get SQLAlchemy inspector bound to the engine (connecting first if needed)."""
        self.connect()
        try:
            return inspect(self.engine)
        except Exception as e:
            raise self._handle_exception(e, operation="inspector creation") from e

    def get_tables(self, catalog_name: str = "", database_name: str = "", schema_name: str = "") -> List[str]:
        """Get list of tables."""
        self.connect()
        sqlalchemy_schema = self._sqlalchemy_schema(catalog_name, database_name, schema_name)
        inspector = self._inspector()
        return inspector.get_table_names(schema=sqlalchemy_schema)

    def get_views(self, catalog_name: str = "", database_name: str = "", schema_name: str = "") -> List[str]:
        """Get list of views."""
        self.connect()
        sqlalchemy_schema = self._sqlalchemy_schema(catalog_name, database_name, schema_name)
        inspector = self._inspector()
        try:
            return inspector.get_view_names(schema=sqlalchemy_schema)
        except Exception as e:
            raise DatusException(
                ErrorCode.DB_FAILED, message_args={"operation": "get_views", "error_message": str(e)}
            ) from e

    @override
    def get_schemas(self, catalog_name: str = "", database_name: str = "", include_sys: bool = False) -> List[str]:
        """Get list of schemas, excluding system schemas unless ``include_sys`` is set."""
        schemas = self._inspector().get_schema_names()
        if not include_sys:
            system_schemas = self._sys_schemas()
            schemas = [s for s in schemas if s.lower() not in system_schemas]
        return schemas

    def get_schema(
        self, catalog_name: str = "", database_name: str = "", schema_name: str = "", table_name: str = ""
    ) -> List[Dict[str, Any]]:
        """
        Get table schema information.

        Empty name arguments fall back to the connector's current context.

        Returns:
            One dict per column with keys cid, name, type, comment (None when
            absent), nullable, pk and default_value.
        """
        sqlalchemy_schema = self._sqlalchemy_schema(
            catalog_name or self.catalog_name, database_name or self.database_name, schema_name or self.schema_name
        )
        inspector = self._inspector()
        try:
            pk_constraint = inspector.get_pk_constraint(table_name=table_name, schema=sqlalchemy_schema) or {}
            # Tables without a primary key may report an empty/None column list.
            pk_columns = set(pk_constraint.get("constrained_columns") or [])
            columns = inspector.get_columns(table_name=table_name, schema=sqlalchemy_schema)

            schemas: List[Dict[str, Any]] = []
            for i, col in enumerate(columns):
                comment = col.get("comment")
                schemas.append(
                    {
                        "cid": i,
                        "name": col["name"],
                        "type": str(col["type"]),
                        # A present-but-None comment must stay None, not become the string "None".
                        "comment": str(comment) if comment is not None else None,
                        "nullable": col["nullable"],
                        "pk": col["name"] in pk_columns,
                        "default_value": col.get("default"),
                    }
                )
            return schemas
        except Exception as e:
            raise self._handle_exception(e, sql="", operation="get schema") from e

    def get_materialized_views(
        self, catalog_name: str = "", database_name: str = "", schema_name: str = ""
    ) -> List[str]:
        """Get list of materialized views; returns [] when the dialect does not support them."""
        inspector = self._inspector()
        try:
            if hasattr(inspector, "get_materialized_view_names"):
                # Resolve the schema the same way as get_tables/get_views.
                sqlalchemy_schema = self._sqlalchemy_schema(catalog_name, database_name, schema_name)
                return inspector.get_materialized_view_names(schema=sqlalchemy_schema)
            return []
        except Exception as e:
            logger.debug(f"Materialized views not supported: {str(e)}")
            return []

    def get_sample_rows(
        self,
        tables: Optional[List[str]] = None,
        top_n: int = 5,
        catalog_name: str = "",
        database_name: str = "",
        schema_name: str = "",
        table_type: TABLE_TYPE = "table",
    ) -> List[Dict[str, str]]:
        """
        Get sample data from tables.

        Args:
            tables: Explicit object names; when empty, objects are discovered
                according to ``table_type``.
            top_n: Maximum rows sampled per object.
            table_type: "table", "view", "mv" or "full" (all three kinds).

        Returns:
            One dict per non-empty object. When objects are discovered, each
            entry's "table_type" is the kind it was found as ("table", "view"
            or "mv"); explicit ``tables`` keep the ``table_type`` argument.
        """
        self._inspector()
        try:
            # (name, kind) pairs so "full" sampling reports each object's real kind.
            if tables:
                targets = [(name, table_type) for name in tables]
            else:
                targets = []
                if table_type in ("table", "full"):
                    targets.extend((name, "table") for name in self.get_tables(catalog_name, database_name, schema_name))
                if table_type in ("view", "full"):
                    targets.extend((name, "view") for name in self.get_views(catalog_name, database_name, schema_name))
                if table_type in ("mv", "full"):
                    try:
                        targets.extend(
                            (name, "mv")
                            for name in self.get_materialized_views(catalog_name, database_name, schema_name)
                        )
                    except Exception as e:
                        logger.debug(f"Materialized views not supported: {e}")

            logger.info(f"Getting sample data from {len(targets)} tables, limit {top_n}")
            samples = []
            for table_name, kind in targets:
                full_name = self.full_name(catalog_name, database_name, schema_name, table_name)
                query = f"SELECT * FROM {full_name} LIMIT {top_n}"
                result = self._execute_pandas(query)
                if result.empty:
                    continue
                samples.append(
                    {
                        "identifier": self.identifier(catalog_name, database_name, schema_name, table_name),
                        "catalog_name": catalog_name,
                        "database_name": database_name,
                        "schema_name": schema_name,
                        "table_name": table_name,
                        "table_type": kind,
                        "sample_rows": result.to_csv(index=False),
                    }
                )
            return samples
        except DatusException:
            raise
        except Exception as e:
            raise self._handle_exception(e) from e

    def _sqlalchemy_schema(
        self, catalog_name: str = "", database_name: str = "", schema_name: str = ""
    ) -> Optional[str]:
        """
        Get schema name for SQLAlchemy Inspector.

        Returns the database name, else the schema name, else None. The
        Inspector treats ``schema=None`` as "default schema" whereas an empty
        string is passed through as a literal (and invalid) schema name.
        """
        return database_name or schema_name or None

    def full_name(
        self, catalog_name: str = "", database_name: str = "", schema_name: str = "", table_name: str = ""
    ) -> str:
        """Build fully-qualified table name."""
        return self.identifier(catalog_name, database_name, schema_name, table_name)

    # ==================== Streaming Methods ====================

    @staticmethod
    def _close_result_quietly(result: Any) -> None:
        """Close a (possibly already closed) result object, ignoring errors."""
        if result is None:
            return
        try:
            result.close()
        except Exception:
            pass

    def execute_csv_iterator(self, sql: str, max_rows: int = 100, with_header: bool = True) -> Iterator[Tuple]:
        """
        Execute query and return CSV rows in batches.

        Yields the column keys first (when ``with_header``), then each row,
        fetched ``max_rows`` at a time from a streaming result. For statements
        that return no rows, only an empty header tuple is yielded (when
        ``with_header``). The result is always closed, including when the
        consumer stops iterating early.

        Raises:
            DatusException: when execution or fetching fails.
        """
        self.connect()
        result = None
        try:
            result = self.connection.execute(text(sql).execution_options(stream_results=True, max_row_buffer=max_rows))
            if result.returns_rows:
                if with_header:
                    yield result.keys()
                while True:
                    batch_rows = result.fetchmany(max_rows)
                    if not batch_rows:
                        break
                    yield from batch_rows
            elif with_header:
                yield ()
        except DatusException:
            raise
        except Exception as e:
            # Release the server-side cursor before rolling back the failed transaction.
            self._close_result_quietly(result)
            result = None
            self._safe_rollback()
            raise self._handle_exception(e, sql) from e
        finally:
            # Runs on normal exhaustion and on GeneratorExit (consumer stopped early).
            self._close_result_quietly(result)
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: datus-sqlalchemy
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: SQLAlchemy base connector for Datus database adapters
|
|
5
|
+
Project-URL: Homepage, https://github.com/Datus-ai/datus-db-adapters
|
|
6
|
+
Project-URL: Repository, https://github.com/Datus-ai/datus-db-adapters
|
|
7
|
+
Project-URL: Issues, https://github.com/Datus-ai/datus-db-adapters/issues
|
|
8
|
+
Author-email: DatusAI <support@datus.ai>
|
|
9
|
+
License: Apache-2.0
|
|
10
|
+
Keywords: adapter,database,datus,sqlalchemy
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Requires-Python: >=3.12
|
|
17
|
+
Requires-Dist: datus-agent>0.2.1
|
|
18
|
+
Requires-Dist: pandas>=2.1.4
|
|
19
|
+
Requires-Dist: pyarrow<19.0.0,>=14.0.0
|
|
20
|
+
Requires-Dist: sqlalchemy>=2.0.23
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
|
|
23
|
+
# datus-sqlalchemy
|
|
24
|
+
|
|
25
|
+
Base SQLAlchemy connector for Datus database adapters.
|
|
26
|
+
|
|
27
|
+
## Overview
|
|
28
|
+
|
|
29
|
+
`datus-sqlalchemy` provides a common SQLAlchemy-based connector foundation for database adapters in the Datus ecosystem. It is not a standalone database adapter but serves as a shared base class for adapters like MySQL, PostgreSQL, and other SQLAlchemy-compatible databases.
|
|
30
|
+
|
|
31
|
+
## Features
|
|
32
|
+
|
|
33
|
+
- SQLAlchemy engine and connection management
|
|
34
|
+
- Unified error handling and exception mapping
|
|
35
|
+
- Support for multiple result formats (pandas, arrow, csv, list)
|
|
36
|
+
- Connection pooling and lifecycle management
|
|
37
|
+
- Streaming query execution
|
|
38
|
+
- Metadata retrieval methods
|
|
39
|
+
|
|
40
|
+
## Installation
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
pip install datus-sqlalchemy
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
Note: This package is typically installed as a dependency of specific database adapters (e.g., `datus-mysql`).
|
|
47
|
+
|
|
48
|
+
## Usage
|
|
49
|
+
|
|
50
|
+
This package is intended to be used as a base class for database adapters:
|
|
51
|
+
|
|
52
|
+
```python
|
|
53
|
+
from datus_sqlalchemy import SQLAlchemyConnector
|
|
54
|
+
|
|
55
|
+
class MyDatabaseConnector(SQLAlchemyConnector):
|
|
56
|
+
def __init__(self, host, port, user, password, database):
|
|
57
|
+
connection_string = f"mydb://{user}:{password}@{host}:{port}/{database}"
|
|
58
|
+
super().__init__(connection_string, dialect="mydb")
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Requirements
|
|
62
|
+
|
|
63
|
+
- Python >= 3.12
|
|
64
|
+
- datus-agent > 0.2.1
|
|
65
|
+
- sqlalchemy >= 2.0.23
|
|
66
|
+
- pyarrow >= 14.0.0, < 19.0.0
|
|
67
|
+
- pandas >= 2.1.4
|
|
68
|
+
|
|
69
|
+
## License
|
|
70
|
+
|
|
71
|
+
Apache License 2.0
|
|
72
|
+
|
|
73
|
+
## Related Packages
|
|
74
|
+
|
|
75
|
+
- `datus-mysql` - MySQL database adapter
|
|
76
|
+
- `datus-starrocks` - StarRocks database adapter
|
|
77
|
+
- `datus-snowflake` - Snowflake database adapter (uses native connector)
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
datus_sqlalchemy/__init__.py,sha256=4sKT6INfL5lOa6rynPZVG3KbhmIRjNndzuwVAAp5xMs,252
|
|
2
|
+
datus_sqlalchemy/connector.py,sha256=iFEFhVFrEbm6tBziTb_0Vr50hD8WUmg-Mh6-IwCUmDc,26223
|
|
3
|
+
datus_sqlalchemy-0.1.0.dist-info/METADATA,sha256=jicQRmtQjYDwqkz47SYBgDPJcogmM55_AiRWtbct8As,2449
|
|
4
|
+
datus_sqlalchemy-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
datus_sqlalchemy-0.1.0.dist-info/RECORD,,
|