datus-sqlalchemy 0.1.3__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datus-sqlalchemy
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: SQLAlchemy base connector for Datus database adapters
5
5
  Project-URL: Homepage, https://github.com/Datus-ai/datus-db-adapters
6
6
  Project-URL: Repository, https://github.com/Datus-ai/datus-db-adapters
@@ -14,7 +14,7 @@ Classifier: License :: OSI Approved :: Apache Software License
14
14
  Classifier: Programming Language :: Python :: 3
15
15
  Classifier: Programming Language :: Python :: 3.12
16
16
  Requires-Python: >=3.12
17
- Requires-Dist: datus-agent>0.2.5
17
+ Requires-Dist: datus-db-core>=0.1.0
18
18
  Requires-Dist: pandas>=2.1.4
19
19
  Requires-Dist: pyarrow<19.0.0,>=14.0.0
20
20
  Requires-Dist: sqlalchemy>=2.0.23
@@ -4,14 +4,18 @@
4
4
 
5
5
  from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, override
6
6
 
7
- from datus.schemas.base import TABLE_TYPE
8
- from datus.schemas.node_models import ExecuteSQLResult
9
- from datus.tools.db_tools.base import BaseSqlConnector
10
- from datus.tools.db_tools.config import ConnectionConfig
11
- from datus.utils.constants import DBType, SQLType
12
- from datus.utils.exceptions import DatusException, ErrorCode
13
- from datus.utils.loggings import get_logger
14
- from datus.utils.sql_utils import parse_context_switch, parse_sql_type
7
+ from datus_db_core import (
8
+ TABLE_TYPE,
9
+ BaseSqlConnector,
10
+ ConnectionConfig,
11
+ DatusDbException,
12
+ ErrorCode,
13
+ ExecuteSQLResult,
14
+ SQLType,
15
+ get_logger,
16
+ parse_context_switch,
17
+ parse_sql_type,
18
+ )
15
19
  from pandas import DataFrame
16
20
  from pyarrow import Table
17
21
  from sqlalchemy import create_engine, inspect, text
@@ -71,15 +75,15 @@ class SQLAlchemyConnector(BaseSqlConnector):
71
75
 
72
76
  @override
73
77
  def connect(self):
74
- """Initialize the connection pool (engine only, no persistent connection)."""
75
- if self.engine and self._owns_engine:
78
+ """Establish connection to the database."""
79
+ if self.engine and self.connection and self._owns_engine:
76
80
  return
77
81
 
78
82
  try:
79
83
  self._safe_close()
80
84
 
81
85
  # Create engine with connection pool
82
- if self.dialect not in (DBType.DUCKDB, DBType.SQLITE):
86
+ if self.dialect not in ("duckdb", "sqlite"):
83
87
  self.engine = create_engine(
84
88
  self.connection_string,
85
89
  pool_size=10, # Increased for parallel execution
@@ -91,22 +95,26 @@ class SQLAlchemyConnector(BaseSqlConnector):
91
95
  else:
92
96
  self.engine = create_engine(self.connection_string)
93
97
 
98
+ self.connection = self.engine.connect()
94
99
  self._owns_engine = True
95
100
 
96
101
  except Exception as e:
97
102
  self._force_reset()
98
103
  raise self._handle_exception(e, "", "connection") from e
99
104
 
100
- if not self.engine:
105
+ if not (self.engine and self.connection):
101
106
  self._force_reset()
102
- raise DatusException(
107
+ raise DatusDbException(
103
108
  ErrorCode.DB_CONNECTION_FAILED, message_args={"error_message": "Failed to establish connection"}
104
109
  )
105
110
 
106
111
  @override
107
112
  def close(self):
108
- """Dispose the connection pool."""
113
+ """Close the database connection."""
109
114
  try:
115
+ if self.connection:
116
+ self.connection.close()
117
+ self.connection = None
110
118
  if self.engine:
111
119
  self.engine.dispose()
112
120
  self.engine = None
@@ -124,6 +132,13 @@ class SQLAlchemyConnector(BaseSqlConnector):
124
132
  def _force_reset(self):
125
133
  """Force reset engine on error."""
126
134
  try:
135
+ self._safe_rollback()
136
+ if self.connection:
137
+ try:
138
+ self.connection.close()
139
+ except Exception:
140
+ pass
141
+ self.connection = None
127
142
  if self.engine:
128
143
  try:
129
144
  self.engine.dispose()
@@ -137,9 +152,9 @@ class SQLAlchemyConnector(BaseSqlConnector):
137
152
 
138
153
  # ==================== Error Handling ====================
139
154
 
140
- def _handle_exception(self, e: Exception, sql: str = "", operation: str = "SQL execution") -> DatusException:
155
+ def _handle_exception(self, e: Exception, sql: str = "", operation: str = "SQL execution") -> DatusDbException:
141
156
  """Map SQLAlchemy exceptions to Datus exceptions."""
142
- if isinstance(e, DatusException):
157
+ if isinstance(e, DatusDbException):
143
158
  return e
144
159
 
145
160
  # Extract error message
@@ -155,11 +170,11 @@ class SQLAlchemyConnector(BaseSqlConnector):
155
170
 
156
171
  # Syntax errors
157
172
  if any(kw in error_msg_lower for kw in ["syntax", "parse error", "sql error"]):
158
- return DatusException(ErrorCode.DB_EXECUTION_SYNTAX_ERROR, message_args=message_args)
173
+ return DatusDbException(ErrorCode.DB_EXECUTION_SYNTAX_ERROR, message_args=message_args)
159
174
 
160
175
  # Table not found
161
176
  if isinstance(e, NoSuchTableError):
162
- return DatusException(ErrorCode.DB_TABLE_NOT_EXISTS, message_args={"table_name": str(e)})
177
+ return DatusDbException(ErrorCode.DB_TABLE_NOT_EXISTS, message_args={"table_name": str(e)})
163
178
 
164
179
  # Connection and operational errors
165
180
  if isinstance(e, (OperationalError, InterfaceError)):
@@ -167,47 +182,47 @@ class SQLAlchemyConnector(BaseSqlConnector):
167
182
  if any(kw in error_msg_lower for kw in ["invalid transaction", "can't reconnect"]):
168
183
  logger.warning("Invalid transaction state detected, resetting connection")
169
184
  self._force_reset()
170
- return DatusException(ErrorCode.DB_TRANSACTION_FAILED, message_args=message_args)
185
+ return DatusDbException(ErrorCode.DB_TRANSACTION_FAILED, message_args=message_args)
171
186
 
172
187
  # Timeout errors
173
188
  if any(kw in error_msg_lower for kw in ["timeout", "timed out"]):
174
- return DatusException(ErrorCode.DB_CONNECTION_TIMEOUT, message_args=message_args)
189
+ return DatusDbException(ErrorCode.DB_CONNECTION_TIMEOUT, message_args=message_args)
175
190
 
176
191
  # Authentication errors
177
192
  if any(kw in error_msg_lower for kw in ["authentication", "access denied", "login failed"]):
178
- return DatusException(ErrorCode.DB_AUTHENTICATION_FAILED, message_args=message_args)
193
+ return DatusDbException(ErrorCode.DB_AUTHENTICATION_FAILED, message_args=message_args)
179
194
 
180
195
  # Permission errors
181
196
  if any(kw in error_msg_lower for kw in ["permission denied", "insufficient privilege"]):
182
197
  message_args["operation"] = operation
183
- return DatusException(ErrorCode.DB_PERMISSION_DENIED, message_args=message_args)
198
+ return DatusDbException(ErrorCode.DB_PERMISSION_DENIED, message_args=message_args)
184
199
 
185
200
  # Connection errors
186
201
  if any(kw in error_msg_lower for kw in ["connection refused", "connection failed", "unable to open"]):
187
- return DatusException(ErrorCode.DB_CONNECTION_FAILED, message_args=message_args)
202
+ return DatusDbException(ErrorCode.DB_CONNECTION_FAILED, message_args=message_args)
188
203
 
189
- return DatusException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)
204
+ return DatusDbException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)
190
205
 
191
206
  # Programming errors
192
207
  if isinstance(e, ProgrammingError):
193
208
  if any(kw in error_msg_lower for kw in ["syntax", "parse error", "sql error"]):
194
- return DatusException(ErrorCode.DB_EXECUTION_SYNTAX_ERROR, message_args=message_args)
195
- return DatusException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)
209
+ return DatusDbException(ErrorCode.DB_EXECUTION_SYNTAX_ERROR, message_args=message_args)
210
+ return DatusDbException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)
196
211
 
197
212
  # Constraint violations
198
213
  if isinstance(e, IntegrityError):
199
- return DatusException(ErrorCode.DB_CONSTRAINT_VIOLATION, message_args=message_args)
214
+ return DatusDbException(ErrorCode.DB_CONSTRAINT_VIOLATION, message_args=message_args)
200
215
 
201
216
  # Timeout errors
202
217
  if isinstance(e, TimeoutError):
203
- return DatusException(ErrorCode.DB_EXECUTION_TIMEOUT, message_args=message_args)
218
+ return DatusDbException(ErrorCode.DB_EXECUTION_TIMEOUT, message_args=message_args)
204
219
 
205
220
  # Other database errors
206
221
  if isinstance(e, (DatabaseError, DataError, InternalError, NotSupportedError)):
207
- return DatusException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)
222
+ return DatusDbException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)
208
223
 
209
224
  # Fallback
210
- return DatusException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)
225
+ return DatusDbException(ErrorCode.DB_EXECUTION_ERROR, message_args=message_args)
211
226
 
212
227
  # ==================== Core Execute Methods ====================
213
228
 
@@ -234,7 +249,7 @@ class SQLAlchemyConnector(BaseSqlConnector):
234
249
  success=True, sql_query=sql, sql_return=result, row_count=row_count, result_format=result_format
235
250
  )
236
251
  except Exception as e:
237
- ex = e if isinstance(e, DatusException) else self._handle_exception(e, sql)
252
+ ex = e if isinstance(e, DatusDbException) else self._handle_exception(e, sql)
238
253
  return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql)
239
254
 
240
255
  def _execute_query(self, sql: str) -> List[Dict[str, Any]]:
@@ -247,16 +262,16 @@ class SQLAlchemyConnector(BaseSqlConnector):
247
262
  SQLType.CONTENT_SET,
248
263
  SQLType.UNKNOWN,
249
264
  ):
250
- raise DatusException(ErrorCode.DB_EXECUTION_ERROR, message="Only SELECT and metadata queries are supported")
265
+ raise DatusDbException(
266
+ ErrorCode.DB_EXECUTION_ERROR, message="Only SELECT and metadata queries are supported"
267
+ )
251
268
 
252
269
  self.connect()
253
270
  try:
254
- # Get connection from pool for this query
255
- with self.engine.connect() as conn:
256
- result = conn.execute(text(sql))
257
- rows = result.fetchall()
258
- return [row._asdict() for row in rows]
259
- except DatusException:
271
+ result = self.connection.execute(text(sql))
272
+ rows = result.fetchall()
273
+ return [row._asdict() for row in rows]
274
+ except DatusDbException:
260
275
  self._safe_rollback()
261
276
  raise
262
277
  except Exception as e:
@@ -268,26 +283,24 @@ class SQLAlchemyConnector(BaseSqlConnector):
268
283
  """Execute INSERT statement."""
269
284
  try:
270
285
  self.connect()
271
- with self.engine.connect() as conn:
272
- res = conn.execute(text(sql))
273
- conn.commit()
286
+ res = self.connection.execute(text(sql))
287
+ self.connection.commit()
274
288
 
275
- # Get inserted primary key or row count
276
- inserted_pk = None
277
- try:
278
- if hasattr(res, "inserted_primary_key") and res.inserted_primary_key:
279
- inserted_pk = res.inserted_primary_key
280
- except Exception:
281
- pass
289
+ # Get inserted primary key or row count
290
+ inserted_pk = None
291
+ try:
292
+ if hasattr(res, "inserted_primary_key") and res.inserted_primary_key:
293
+ inserted_pk = res.inserted_primary_key
294
+ except Exception:
295
+ pass
282
296
 
283
- lastrowid = getattr(res, "lastrowid", None)
284
- return_value = inserted_pk if inserted_pk else (lastrowid if lastrowid else res.rowcount)
297
+ lastrowid = getattr(res, "lastrowid", None)
298
+ return_value = inserted_pk if inserted_pk else (lastrowid if lastrowid else res.rowcount)
285
299
 
286
- return ExecuteSQLResult(
287
- success=True, sql_query=sql, sql_return=str(return_value), row_count=res.rowcount
288
- )
300
+ return ExecuteSQLResult(success=True, sql_query=sql, sql_return=str(return_value), row_count=res.rowcount)
289
301
  except Exception as e:
290
- ex = e if isinstance(e, DatusException) else self._handle_exception(e, sql)
302
+ self._safe_rollback()
303
+ ex = e if isinstance(e, DatusDbException) else self._handle_exception(e, sql)
291
304
  return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql, sql_return="", row_count=0)
292
305
 
293
306
  @override
@@ -295,14 +308,12 @@ class SQLAlchemyConnector(BaseSqlConnector):
295
308
  """Execute UPDATE statement."""
296
309
  try:
297
310
  self.connect()
298
- with self.engine.connect() as conn:
299
- res = conn.execute(text(sql))
300
- conn.commit()
301
- return ExecuteSQLResult(
302
- success=True, sql_query=sql, sql_return=str(res.rowcount), row_count=res.rowcount
303
- )
311
+ res = self.connection.execute(text(sql))
312
+ self.connection.commit()
313
+ return ExecuteSQLResult(success=True, sql_query=sql, sql_return=str(res.rowcount), row_count=res.rowcount)
304
314
  except Exception as e:
305
- ex = e if isinstance(e, DatusException) else self._handle_exception(e, sql)
315
+ self._safe_rollback()
316
+ ex = e if isinstance(e, DatusDbException) else self._handle_exception(e, sql)
306
317
  return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql, sql_return="", row_count=0)
307
318
 
308
319
  @override
@@ -310,14 +321,12 @@ class SQLAlchemyConnector(BaseSqlConnector):
310
321
  """Execute DELETE statement."""
311
322
  try:
312
323
  self.connect()
313
- with self.engine.connect() as conn:
314
- res = conn.execute(text(sql))
315
- conn.commit()
316
- return ExecuteSQLResult(
317
- success=True, sql_query=sql, sql_return=str(res.rowcount), row_count=res.rowcount
318
- )
324
+ res = self.connection.execute(text(sql))
325
+ self.connection.commit()
326
+ return ExecuteSQLResult(success=True, sql_query=sql, sql_return=str(res.rowcount), row_count=res.rowcount)
319
327
  except Exception as e:
320
- ex = e if isinstance(e, DatusException) else self._handle_exception(e, sql)
328
+ self._safe_rollback()
329
+ ex = e if isinstance(e, DatusDbException) else self._handle_exception(e, sql)
321
330
  return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql, sql_return="", row_count=0)
322
331
 
323
332
  @override
@@ -325,14 +334,12 @@ class SQLAlchemyConnector(BaseSqlConnector):
325
334
  """Execute DDL statement (CREATE, ALTER, DROP, etc.)."""
326
335
  try:
327
336
  self.connect()
328
- with self.engine.connect() as conn:
329
- res = conn.execute(text(sql))
330
- conn.commit()
331
- return ExecuteSQLResult(
332
- success=True, sql_query=sql, sql_return=str(res.rowcount), row_count=res.rowcount
333
- )
337
+ res = self.connection.execute(text(sql))
338
+ self.connection.commit()
339
+ return ExecuteSQLResult(success=True, sql_query=sql, sql_return=str(res.rowcount), row_count=res.rowcount)
334
340
  except Exception as e:
335
- ex = e if isinstance(e, DatusException) else self._handle_exception(e, sql)
341
+ self._safe_rollback()
342
+ ex = e if isinstance(e, DatusDbException) else self._handle_exception(e, sql)
336
343
  return ExecuteSQLResult(success=False, sql_query=sql, error=str(ex))
337
344
 
338
345
  def execute_pandas(self, sql: str) -> ExecuteSQLResult:
@@ -343,7 +350,7 @@ class SQLAlchemyConnector(BaseSqlConnector):
343
350
  success=True, sql_query=sql, sql_return=df, row_count=len(df), result_format="pandas"
344
351
  )
345
352
  except Exception as e:
346
- ex = e if isinstance(e, DatusException) else self._handle_exception(e, sql)
353
+ ex = e if isinstance(e, DatusDbException) else self._handle_exception(e, sql)
347
354
  return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql)
348
355
 
349
356
  def _execute_pandas(self, sql: str) -> DataFrame:
@@ -359,7 +366,7 @@ class SQLAlchemyConnector(BaseSqlConnector):
359
366
  success=True, sql_query=sql, sql_return=df.to_csv(index=False), row_count=len(df), result_format="csv"
360
367
  )
361
368
  except Exception as e:
362
- ex = e if isinstance(e, DatusException) else self._handle_exception(e, sql)
369
+ ex = e if isinstance(e, DatusDbException) else self._handle_exception(e, sql)
363
370
  return ExecuteSQLResult(
364
371
  success=False, sql_query=sql, sql_return="", row_count=0, error=str(ex), result_format="csv"
365
372
  )
@@ -379,7 +386,7 @@ class SQLAlchemyConnector(BaseSqlConnector):
379
386
  success=True, sql_query=sql, sql_return=result.rowcount, row_count=0, result_format="arrow"
380
387
  )
381
388
  except Exception as e:
382
- ex = e if isinstance(e, DatusException) else self._handle_exception(e, sql)
389
+ ex = e if isinstance(e, DatusDbException) else self._handle_exception(e, sql)
383
390
  return ExecuteSQLResult(
384
391
  success=False, error=str(ex), sql_query=sql, sql_return="", row_count=0, result_format="arrow"
385
392
  )
@@ -389,9 +396,8 @@ class SQLAlchemyConnector(BaseSqlConnector):
389
396
  """Execute USE/SET commands."""
390
397
  self.connect()
391
398
  try:
392
- with self.engine.connect() as conn:
393
- conn.execute(text(sql))
394
- conn.commit()
399
+ self.connection.execute(text(sql))
400
+ self.connection.commit()
395
401
 
396
402
  # Update context if applicable
397
403
  if self.dialect != "sqlite":
@@ -406,7 +412,8 @@ class SQLAlchemyConnector(BaseSqlConnector):
406
412
 
407
413
  return ExecuteSQLResult(success=True, sql_query=sql, sql_return="Successful", row_count=0)
408
414
  except Exception as e:
409
- ex = e if isinstance(e, DatusException) else self._handle_exception(e, sql)
415
+ self._safe_rollback()
416
+ ex = e if isinstance(e, DatusDbException) else self._handle_exception(e, sql)
410
417
  return ExecuteSQLResult(success=False, error=str(ex), sql_query=sql)
411
418
 
412
419
  def execute_queries(self, queries: List[str]) -> List[Any]:
@@ -414,31 +421,29 @@ class SQLAlchemyConnector(BaseSqlConnector):
414
421
  results = []
415
422
  self.connect()
416
423
  try:
417
- with self.engine.connect() as conn:
418
- for query in queries:
419
- result = conn.execute(text(query))
420
- if result.returns_rows:
421
- df = DataFrame(result.fetchall(), columns=list(result.keys()))
422
- results.append(df.to_dict(orient="records"))
424
+ for query in queries:
425
+ result = self.connection.execute(text(query))
426
+ if result.returns_rows:
427
+ df = DataFrame(result.fetchall(), columns=list(result.keys()))
428
+ results.append(df.to_dict(orient="records"))
429
+ else:
430
+ query_lower = query.strip().lower()
431
+ if query_lower.startswith("insert"):
432
+ inserted_pk = None
433
+ try:
434
+ if hasattr(result, "inserted_primary_key") and result.inserted_primary_key:
435
+ inserted_pk = result.inserted_primary_key
436
+ except Exception:
437
+ pass
438
+ lastrowid = getattr(result, "lastrowid", None)
439
+ results.append(inserted_pk if inserted_pk else (lastrowid if lastrowid else result.rowcount))
440
+ elif query_lower.startswith(("update", "delete")):
441
+ results.append(result.rowcount)
423
442
  else:
424
- query_lower = query.strip().lower()
425
- if query_lower.startswith("insert"):
426
- inserted_pk = None
427
- try:
428
- if hasattr(result, "inserted_primary_key") and result.inserted_primary_key:
429
- inserted_pk = result.inserted_primary_key
430
- except Exception:
431
- pass
432
- lastrowid = getattr(result, "lastrowid", None)
433
- results.append(
434
- inserted_pk if inserted_pk else (lastrowid if lastrowid else result.rowcount)
435
- )
436
- elif query_lower.startswith(("update", "delete")):
437
- results.append(result.rowcount)
438
- else:
439
- results.append(None)
440
- conn.commit()
443
+ results.append(None)
444
+ self.connection.commit()
441
445
  except SQLAlchemyError as e:
446
+ self._safe_rollback()
442
447
  raise self._handle_exception(e, "\n".join(queries), "batch query") from e
443
448
  return results
444
449
 
@@ -449,9 +454,9 @@ class SQLAlchemyConnector(BaseSqlConnector):
449
454
  self._execute_query("SELECT 1")
450
455
  return True
451
456
  except Exception as e:
452
- if isinstance(e, DatusException):
457
+ if isinstance(e, DatusDbException):
453
458
  raise
454
- raise DatusException(
459
+ raise DatusDbException(
455
460
  ErrorCode.DB_CONNECTION_FAILED, message_args={"error_message": "Connection test failed"}
456
461
  ) from e
457
462
 
@@ -480,11 +485,10 @@ class SQLAlchemyConnector(BaseSqlConnector):
480
485
  try:
481
486
  return inspector.get_view_names(schema=sqlalchemy_schema)
482
487
  except Exception as e:
483
- raise DatusException(
488
+ raise DatusDbException(
484
489
  ErrorCode.DB_FAILED, message_args={"operation": "get_views", "error_message": str(e)}
485
490
  ) from e
486
491
 
487
- @override
488
492
  def get_schemas(self, catalog_name: str = "", database_name: str = "", include_sys: bool = False) -> List[str]:
489
493
  """Get list of schemas."""
490
494
  schemas = self._inspector().get_schema_names()
@@ -579,7 +583,7 @@ class SQLAlchemyConnector(BaseSqlConnector):
579
583
  }
580
584
  )
581
585
  return samples
582
- except DatusException:
586
+ except DatusDbException:
583
587
  raise
584
588
  except Exception as e:
585
589
  raise self._handle_exception(e) from e
@@ -602,20 +606,19 @@ class SQLAlchemyConnector(BaseSqlConnector):
602
606
  """Execute query and return CSV rows in batches."""
603
607
  self.connect()
604
608
  try:
605
- with self.engine.connect() as conn:
606
- result = conn.execute(text(sql).execution_options(stream_results=True, max_row_buffer=max_rows))
607
- if result.returns_rows:
608
- if with_header:
609
- yield result.keys()
610
- while True:
611
- batch_rows = result.fetchmany(max_rows)
612
- if not batch_rows:
613
- break
614
- for row in batch_rows:
615
- yield row
616
- else:
617
- if with_header:
618
- yield ()
619
- yield from []
609
+ result = self.connection.execute(text(sql).execution_options(stream_results=True, max_row_buffer=max_rows))
610
+ if result.returns_rows:
611
+ if with_header:
612
+ yield result.keys()
613
+ while True:
614
+ batch_rows = result.fetchmany(max_rows)
615
+ if not batch_rows:
616
+ break
617
+ for row in batch_rows:
618
+ yield row
619
+ else:
620
+ if with_header:
621
+ yield ()
622
+ yield from []
620
623
  except Exception as e:
621
624
  raise self._handle_exception(e) from e
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "datus-sqlalchemy"
3
- version = "0.1.3"
3
+ version = "0.1.4"
4
4
  description = "SQLAlchemy base connector for Datus database adapters"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
@@ -18,7 +18,7 @@ classifiers = [
18
18
  ]
19
19
 
20
20
  dependencies = [
21
- "datus-agent>0.2.5",
21
+ "datus-db-core>=0.1.0",
22
22
  "sqlalchemy>=2.0.23",
23
23
  "pyarrow>=14.0.0,<19.0.0",
24
24
  "pandas>=2.1.4",
@@ -29,6 +29,9 @@ Homepage = "https://github.com/Datus-ai/datus-db-adapters"
29
29
  Repository = "https://github.com/Datus-ai/datus-db-adapters"
30
30
  Issues = "https://github.com/Datus-ai/datus-db-adapters/issues"
31
31
 
32
+ [tool.uv.sources]
33
+ datus-db-core = { workspace = true }
34
+
32
35
  [build-system]
33
36
  requires = ["hatchling"]
34
37
  build-backend = "hatchling.build"
@@ -0,0 +1,33 @@
1
+ # Copyright 2025-present DatusAI, Inc.
2
+ # Licensed under the Apache License, Version 2.0.
3
+ # See http://www.apache.org/licenses/LICENSE-2.0 for details.
4
+
5
+ from unittest.mock import MagicMock, patch
6
+
7
+ from datus_sqlalchemy import SQLAlchemyConnector
8
+
9
+
10
+ class DummySQLAlchemyConnector(SQLAlchemyConnector):
11
+ def get_databases(self, catalog_name: str = "", include_sys: bool = False):
12
+ return []
13
+
14
+
15
+ def test_execute_content_set_and_query_share_persistent_connection():
16
+ """USE/SET statements must affect later queries executed by the connector."""
17
+ connector = DummySQLAlchemyConnector("sqlite://", dialect="mysql")
18
+ persistent_conn = MagicMock()
19
+ query_result = MagicMock()
20
+ query_result.fetchall.return_value = [MagicMock(_asdict=lambda: {"id": 1})]
21
+
22
+ engine = MagicMock()
23
+ engine.connect.return_value = persistent_conn
24
+ persistent_conn.execute.side_effect = [MagicMock(), query_result]
25
+
26
+ with patch("datus_sqlalchemy.connector.create_engine", return_value=engine):
27
+ set_result = connector.execute_content_set("USE analytics")
28
+ query_rows = connector._execute_query("SELECT id FROM users")
29
+
30
+ assert set_result.success is True
31
+ assert query_rows == [{"id": 1}]
32
+ assert engine.connect.call_count == 1
33
+ assert persistent_conn.execute.call_count == 2