sqlspec 0.13.1__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +39 -1
- sqlspec/__main__.py +12 -0
- sqlspec/adapters/adbc/config.py +16 -40
- sqlspec/adapters/adbc/driver.py +43 -16
- sqlspec/adapters/adbc/transformers.py +108 -0
- sqlspec/adapters/aiosqlite/config.py +2 -20
- sqlspec/adapters/aiosqlite/driver.py +36 -18
- sqlspec/adapters/asyncmy/config.py +2 -33
- sqlspec/adapters/asyncmy/driver.py +23 -16
- sqlspec/adapters/asyncpg/config.py +5 -39
- sqlspec/adapters/asyncpg/driver.py +41 -18
- sqlspec/adapters/bigquery/config.py +2 -43
- sqlspec/adapters/bigquery/driver.py +26 -14
- sqlspec/adapters/duckdb/config.py +2 -49
- sqlspec/adapters/duckdb/driver.py +35 -16
- sqlspec/adapters/oracledb/config.py +4 -83
- sqlspec/adapters/oracledb/driver.py +54 -27
- sqlspec/adapters/psqlpy/config.py +2 -55
- sqlspec/adapters/psqlpy/driver.py +28 -8
- sqlspec/adapters/psycopg/config.py +4 -73
- sqlspec/adapters/psycopg/driver.py +69 -24
- sqlspec/adapters/sqlite/config.py +3 -21
- sqlspec/adapters/sqlite/driver.py +50 -26
- sqlspec/cli.py +248 -0
- sqlspec/config.py +18 -20
- sqlspec/driver/_async.py +28 -10
- sqlspec/driver/_common.py +5 -4
- sqlspec/driver/_sync.py +28 -10
- sqlspec/driver/mixins/__init__.py +6 -0
- sqlspec/driver/mixins/_cache.py +114 -0
- sqlspec/driver/mixins/_pipeline.py +0 -4
- sqlspec/{service/base.py → driver/mixins/_query_tools.py} +86 -421
- sqlspec/driver/mixins/_result_utils.py +0 -2
- sqlspec/driver/mixins/_sql_translator.py +0 -2
- sqlspec/driver/mixins/_storage.py +4 -18
- sqlspec/driver/mixins/_type_coercion.py +0 -2
- sqlspec/driver/parameters.py +4 -4
- sqlspec/extensions/aiosql/adapter.py +4 -4
- sqlspec/extensions/litestar/__init__.py +2 -1
- sqlspec/extensions/litestar/cli.py +48 -0
- sqlspec/extensions/litestar/plugin.py +3 -0
- sqlspec/loader.py +1 -1
- sqlspec/migrations/__init__.py +23 -0
- sqlspec/migrations/base.py +390 -0
- sqlspec/migrations/commands.py +525 -0
- sqlspec/migrations/runner.py +215 -0
- sqlspec/migrations/tracker.py +153 -0
- sqlspec/migrations/utils.py +89 -0
- sqlspec/protocols.py +37 -3
- sqlspec/statement/builder/__init__.py +8 -8
- sqlspec/statement/builder/{column.py → _column.py} +82 -52
- sqlspec/statement/builder/{ddl.py → _ddl.py} +5 -5
- sqlspec/statement/builder/_ddl_utils.py +1 -1
- sqlspec/statement/builder/{delete.py → _delete.py} +1 -1
- sqlspec/statement/builder/{insert.py → _insert.py} +1 -1
- sqlspec/statement/builder/{merge.py → _merge.py} +1 -1
- sqlspec/statement/builder/_parsing_utils.py +5 -3
- sqlspec/statement/builder/{select.py → _select.py} +59 -61
- sqlspec/statement/builder/{update.py → _update.py} +2 -2
- sqlspec/statement/builder/mixins/__init__.py +24 -30
- sqlspec/statement/builder/mixins/{_set_ops.py → _cte_and_set_ops.py} +86 -2
- sqlspec/statement/builder/mixins/{_delete_from.py → _delete_operations.py} +2 -0
- sqlspec/statement/builder/mixins/{_insert_values.py → _insert_operations.py} +70 -1
- sqlspec/statement/builder/mixins/{_merge_clauses.py → _merge_operations.py} +2 -0
- sqlspec/statement/builder/mixins/_order_limit_operations.py +123 -0
- sqlspec/statement/builder/mixins/{_pivot.py → _pivot_operations.py} +71 -2
- sqlspec/statement/builder/mixins/_select_operations.py +612 -0
- sqlspec/statement/builder/mixins/{_update_set.py → _update_operations.py} +73 -2
- sqlspec/statement/builder/mixins/_where_clause.py +536 -0
- sqlspec/statement/cache.py +50 -0
- sqlspec/statement/filters.py +37 -8
- sqlspec/statement/parameters.py +143 -54
- sqlspec/statement/pipelines/__init__.py +1 -1
- sqlspec/statement/pipelines/context.py +4 -10
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +3 -3
- sqlspec/statement/pipelines/validators/_parameter_style.py +22 -22
- sqlspec/statement/pipelines/validators/_performance.py +1 -5
- sqlspec/statement/sql.py +246 -176
- sqlspec/utils/__init__.py +2 -1
- sqlspec/utils/statement_hashing.py +203 -0
- sqlspec/utils/type_guards.py +32 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/METADATA +1 -1
- sqlspec-0.14.1.dist-info/RECORD +145 -0
- sqlspec-0.14.1.dist-info/entry_points.txt +2 -0
- sqlspec/service/__init__.py +0 -4
- sqlspec/service/_util.py +0 -147
- sqlspec/service/pagination.py +0 -26
- sqlspec/statement/builder/mixins/_aggregate_functions.py +0 -250
- sqlspec/statement/builder/mixins/_case_builder.py +0 -91
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -90
- sqlspec/statement/builder/mixins/_from.py +0 -63
- sqlspec/statement/builder/mixins/_group_by.py +0 -118
- sqlspec/statement/builder/mixins/_having.py +0 -35
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -47
- sqlspec/statement/builder/mixins/_insert_into.py +0 -36
- sqlspec/statement/builder/mixins/_limit_offset.py +0 -53
- sqlspec/statement/builder/mixins/_order_by.py +0 -46
- sqlspec/statement/builder/mixins/_returning.py +0 -37
- sqlspec/statement/builder/mixins/_select_columns.py +0 -61
- sqlspec/statement/builder/mixins/_unpivot.py +0 -77
- sqlspec/statement/builder/mixins/_update_from.py +0 -55
- sqlspec/statement/builder/mixins/_update_table.py +0 -29
- sqlspec/statement/builder/mixins/_where.py +0 -401
- sqlspec/statement/builder/mixins/_window_functions.py +0 -86
- sqlspec/statement/parameter_manager.py +0 -220
- sqlspec/statement/sql_compiler.py +0 -140
- sqlspec-0.13.1.dist-info/RECORD +0 -150
- /sqlspec/statement/builder/{base.py → _base.py} +0 -0
- /sqlspec/statement/builder/mixins/{_join.py → _join_operations.py} +0 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.13.1.dist-info → sqlspec-0.14.1.dist-info}/licenses/NOTICE +0 -0
|
@@ -12,12 +12,13 @@ from sqlspec.driver import SyncDriverAdapterProtocol
|
|
|
12
12
|
from sqlspec.driver.connection import managed_transaction_sync
|
|
13
13
|
from sqlspec.driver.mixins import (
|
|
14
14
|
SQLTranslatorMixin,
|
|
15
|
+
SyncAdapterCacheMixin,
|
|
15
16
|
SyncPipelinedExecutionMixin,
|
|
16
17
|
SyncStorageMixin,
|
|
17
18
|
ToSchemaMixin,
|
|
18
19
|
TypeCoercionMixin,
|
|
19
20
|
)
|
|
20
|
-
from sqlspec.driver.parameters import
|
|
21
|
+
from sqlspec.driver.parameters import convert_parameter_sequence
|
|
21
22
|
from sqlspec.statement.parameters import ParameterStyle
|
|
22
23
|
from sqlspec.statement.result import ArrowResult, SQLResult
|
|
23
24
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
@@ -38,6 +39,7 @@ logger = get_logger("adapters.duckdb")
|
|
|
38
39
|
|
|
39
40
|
class DuckDBDriver(
|
|
40
41
|
SyncDriverAdapterProtocol["DuckDBConnection", RowT],
|
|
42
|
+
SyncAdapterCacheMixin,
|
|
41
43
|
SQLTranslatorMixin,
|
|
42
44
|
TypeCoercionMixin,
|
|
43
45
|
SyncStorageMixin,
|
|
@@ -63,7 +65,6 @@ class DuckDBDriver(
|
|
|
63
65
|
supports_native_arrow_import: ClassVar[bool] = True
|
|
64
66
|
supports_native_parquet_export: ClassVar[bool] = True
|
|
65
67
|
supports_native_parquet_import: ClassVar[bool] = True
|
|
66
|
-
__slots__ = ()
|
|
67
68
|
|
|
68
69
|
def __init__(
|
|
69
70
|
self,
|
|
@@ -86,10 +87,10 @@ class DuckDBDriver(
|
|
|
86
87
|
self, statement: SQL, connection: Optional["DuckDBConnection"] = None, **kwargs: Any
|
|
87
88
|
) -> SQLResult[RowT]:
|
|
88
89
|
if statement.is_script:
|
|
89
|
-
sql, _ =
|
|
90
|
+
sql, _ = self._get_compiled_sql(statement, ParameterStyle.STATIC)
|
|
90
91
|
return self._execute_script(sql, connection=connection, **kwargs)
|
|
91
92
|
|
|
92
|
-
sql, params =
|
|
93
|
+
sql, params = self._get_compiled_sql(statement, self.default_parameter_style)
|
|
93
94
|
params = self._process_parameters(params)
|
|
94
95
|
|
|
95
96
|
if statement.is_many:
|
|
@@ -104,9 +105,9 @@ class DuckDBDriver(
|
|
|
104
105
|
conn = connection if connection is not None else self._connection(None)
|
|
105
106
|
|
|
106
107
|
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
107
|
-
#
|
|
108
|
-
|
|
109
|
-
final_params =
|
|
108
|
+
# Convert parameters using consolidated utility
|
|
109
|
+
converted_params = convert_parameter_sequence(parameters)
|
|
110
|
+
final_params = converted_params or []
|
|
110
111
|
|
|
111
112
|
if self.returns_rows(statement.expression):
|
|
112
113
|
result = txn_conn.execute(sql, final_params)
|
|
@@ -157,12 +158,12 @@ class DuckDBDriver(
|
|
|
157
158
|
|
|
158
159
|
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
159
160
|
# Normalize parameter list using consolidated utility
|
|
160
|
-
|
|
161
|
-
final_param_list =
|
|
161
|
+
converted_param_list = convert_parameter_sequence(param_list)
|
|
162
|
+
final_param_list = converted_param_list or []
|
|
162
163
|
|
|
163
164
|
# DuckDB throws an error if executemany is called with empty parameter list
|
|
164
165
|
if not final_param_list:
|
|
165
|
-
return SQLResult(
|
|
166
|
+
return SQLResult( # pyright: ignore
|
|
166
167
|
statement=SQL(sql, _dialect=self.dialect),
|
|
167
168
|
data=[],
|
|
168
169
|
rows_affected=0,
|
|
@@ -176,7 +177,7 @@ class DuckDBDriver(
|
|
|
176
177
|
# For executemany, fetchone() only returns the count from the last operation,
|
|
177
178
|
# so use parameter list length as the most accurate estimate
|
|
178
179
|
rows_affected = cursor.rowcount if cursor.rowcount >= 0 else len(final_param_list)
|
|
179
|
-
return SQLResult(
|
|
180
|
+
return SQLResult( # pyright: ignore
|
|
180
181
|
statement=SQL(sql, _dialect=self.dialect),
|
|
181
182
|
data=[],
|
|
182
183
|
rows_affected=rows_affected,
|
|
@@ -191,20 +192,38 @@ class DuckDBDriver(
|
|
|
191
192
|
conn = connection if connection is not None else self._connection(None)
|
|
192
193
|
|
|
193
194
|
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
195
|
+
# Split script into individual statements for validation
|
|
196
|
+
statements = self._split_script_statements(script)
|
|
197
|
+
suppress_warnings = kwargs.get("_suppress_warnings", False)
|
|
198
|
+
|
|
199
|
+
executed_count = 0
|
|
200
|
+
total_rows = 0
|
|
201
|
+
|
|
194
202
|
with self._get_cursor(txn_conn) as cursor:
|
|
195
|
-
|
|
203
|
+
for statement in statements:
|
|
204
|
+
if statement.strip():
|
|
205
|
+
# Validate each statement unless warnings suppressed
|
|
206
|
+
if not suppress_warnings:
|
|
207
|
+
# Run validation through pipeline
|
|
208
|
+
temp_sql = SQL(statement, config=self.config)
|
|
209
|
+
temp_sql._ensure_processed()
|
|
210
|
+
# Validation errors are logged as warnings by default
|
|
211
|
+
|
|
212
|
+
cursor.execute(statement)
|
|
213
|
+
executed_count += 1
|
|
214
|
+
total_rows += cursor.rowcount or 0
|
|
196
215
|
|
|
197
216
|
return SQLResult(
|
|
198
217
|
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
199
218
|
data=[],
|
|
200
|
-
rows_affected=
|
|
219
|
+
rows_affected=total_rows,
|
|
201
220
|
operation_type="SCRIPT",
|
|
202
221
|
metadata={
|
|
203
222
|
"status_message": "Script executed successfully.",
|
|
204
223
|
"description": "The script was sent to the database.",
|
|
205
224
|
},
|
|
206
|
-
total_statements
|
|
207
|
-
successful_statements
|
|
225
|
+
total_statements=executed_count,
|
|
226
|
+
successful_statements=executed_count,
|
|
208
227
|
)
|
|
209
228
|
|
|
210
229
|
# ============================================================================
|
|
@@ -214,7 +233,7 @@ class DuckDBDriver(
|
|
|
214
233
|
def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
|
|
215
234
|
"""Enhanced DuckDB native Arrow table fetching with streaming support."""
|
|
216
235
|
conn = self._connection(connection)
|
|
217
|
-
sql_string, parameters =
|
|
236
|
+
sql_string, parameters = self._get_compiled_sql(sql, self.default_parameter_style)
|
|
218
237
|
parameters = self._process_parameters(parameters)
|
|
219
238
|
result = conn.execute(sql_string, parameters or [])
|
|
220
239
|
|
|
@@ -23,7 +23,6 @@ if TYPE_CHECKING:
|
|
|
23
23
|
|
|
24
24
|
from oracledb import AuthMode
|
|
25
25
|
from oracledb.pool import AsyncConnectionPool, ConnectionPool
|
|
26
|
-
from sqlglot.dialects.dialect import DialectType
|
|
27
26
|
|
|
28
27
|
|
|
29
28
|
__all__ = ("CONNECTION_FIELDS", "POOL_FIELDS", "OracleAsyncConfig", "OracleSyncConfig")
|
|
@@ -73,43 +72,6 @@ POOL_FIELDS = CONNECTION_FIELDS.union(
|
|
|
73
72
|
class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "ConnectionPool", OracleSyncDriver]):
|
|
74
73
|
"""Configuration for Oracle synchronous database connections with direct field-based configuration."""
|
|
75
74
|
|
|
76
|
-
__slots__ = (
|
|
77
|
-
"_dialect",
|
|
78
|
-
"config_dir",
|
|
79
|
-
"default_row_type",
|
|
80
|
-
"dsn",
|
|
81
|
-
"edition",
|
|
82
|
-
"events",
|
|
83
|
-
"extras",
|
|
84
|
-
"getmode",
|
|
85
|
-
"homogeneous",
|
|
86
|
-
"host",
|
|
87
|
-
"increment",
|
|
88
|
-
"max",
|
|
89
|
-
"max_lifetime_session",
|
|
90
|
-
"max_sessions_per_shard",
|
|
91
|
-
"min",
|
|
92
|
-
"mode",
|
|
93
|
-
"password",
|
|
94
|
-
"ping_interval",
|
|
95
|
-
"pool_instance",
|
|
96
|
-
"port",
|
|
97
|
-
"retry_count",
|
|
98
|
-
"retry_delay",
|
|
99
|
-
"service_name",
|
|
100
|
-
"session_callback",
|
|
101
|
-
"sid",
|
|
102
|
-
"soda_metadata_cache",
|
|
103
|
-
"statement_config",
|
|
104
|
-
"tcp_connect_timeout",
|
|
105
|
-
"threaded",
|
|
106
|
-
"timeout",
|
|
107
|
-
"user",
|
|
108
|
-
"wait_timeout",
|
|
109
|
-
"wallet_location",
|
|
110
|
-
"wallet_password",
|
|
111
|
-
)
|
|
112
|
-
|
|
113
75
|
is_async: ClassVar[bool] = False
|
|
114
76
|
supports_connection_pooling: ClassVar[bool] = True
|
|
115
77
|
|
|
@@ -120,7 +82,7 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "ConnectionPool"
|
|
|
120
82
|
supported_parameter_styles: ClassVar[tuple[str, ...]] = ("named_colon", "positional_colon")
|
|
121
83
|
"""OracleDB supports :name (named_colon) and :1 (positional_colon) parameter styles."""
|
|
122
84
|
|
|
123
|
-
|
|
85
|
+
default_parameter_style: ClassVar[str] = "named_colon"
|
|
124
86
|
"""OracleDB's preferred parameter style is :name (named_colon)."""
|
|
125
87
|
|
|
126
88
|
def __init__(
|
|
@@ -236,8 +198,6 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "ConnectionPool"
|
|
|
236
198
|
# Store other config
|
|
237
199
|
self.statement_config = statement_config or SQLConfig()
|
|
238
200
|
self.default_row_type = default_row_type
|
|
239
|
-
self.pool_instance = pool_instance
|
|
240
|
-
self._dialect: DialectType = None
|
|
241
201
|
|
|
242
202
|
super().__init__()
|
|
243
203
|
|
|
@@ -300,7 +260,7 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "ConnectionPool"
|
|
|
300
260
|
statement_config = replace(
|
|
301
261
|
statement_config,
|
|
302
262
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
303
|
-
|
|
263
|
+
default_parameter_style=self.default_parameter_style,
|
|
304
264
|
)
|
|
305
265
|
driver = self.driver_type(connection=conn, config=statement_config)
|
|
306
266
|
yield driver
|
|
@@ -368,43 +328,6 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "ConnectionPool"
|
|
|
368
328
|
class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "AsyncConnectionPool", OracleAsyncDriver]):
|
|
369
329
|
"""Configuration for Oracle asynchronous database connections with direct field-based configuration."""
|
|
370
330
|
|
|
371
|
-
__slots__ = (
|
|
372
|
-
"_dialect",
|
|
373
|
-
"config_dir",
|
|
374
|
-
"default_row_type",
|
|
375
|
-
"dsn",
|
|
376
|
-
"edition",
|
|
377
|
-
"events",
|
|
378
|
-
"extras",
|
|
379
|
-
"getmode",
|
|
380
|
-
"homogeneous",
|
|
381
|
-
"host",
|
|
382
|
-
"increment",
|
|
383
|
-
"max",
|
|
384
|
-
"max_lifetime_session",
|
|
385
|
-
"max_sessions_per_shard",
|
|
386
|
-
"min",
|
|
387
|
-
"mode",
|
|
388
|
-
"password",
|
|
389
|
-
"ping_interval",
|
|
390
|
-
"pool_instance",
|
|
391
|
-
"port",
|
|
392
|
-
"retry_count",
|
|
393
|
-
"retry_delay",
|
|
394
|
-
"service_name",
|
|
395
|
-
"session_callback",
|
|
396
|
-
"sid",
|
|
397
|
-
"soda_metadata_cache",
|
|
398
|
-
"statement_config",
|
|
399
|
-
"tcp_connect_timeout",
|
|
400
|
-
"threaded",
|
|
401
|
-
"timeout",
|
|
402
|
-
"user",
|
|
403
|
-
"wait_timeout",
|
|
404
|
-
"wallet_location",
|
|
405
|
-
"wallet_password",
|
|
406
|
-
)
|
|
407
|
-
|
|
408
331
|
is_async: ClassVar[bool] = True
|
|
409
332
|
supports_connection_pooling: ClassVar[bool] = True
|
|
410
333
|
|
|
@@ -415,7 +338,7 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "AsyncConnect
|
|
|
415
338
|
supported_parameter_styles: ClassVar[tuple[str, ...]] = ("named_colon", "positional_colon")
|
|
416
339
|
"""OracleDB supports :name (named_colon) and :1 (positional_colon) parameter styles."""
|
|
417
340
|
|
|
418
|
-
|
|
341
|
+
default_parameter_style: ClassVar[str] = "named_colon"
|
|
419
342
|
"""OracleDB's preferred parameter style is :name (named_colon)."""
|
|
420
343
|
|
|
421
344
|
def __init__(
|
|
@@ -531,8 +454,6 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "AsyncConnect
|
|
|
531
454
|
# Store other config
|
|
532
455
|
self.statement_config = statement_config or SQLConfig()
|
|
533
456
|
self.default_row_type = default_row_type
|
|
534
|
-
self.pool_instance: Optional[AsyncConnectionPool] = pool_instance
|
|
535
|
-
self._dialect: DialectType = None
|
|
536
457
|
|
|
537
458
|
super().__init__()
|
|
538
459
|
|
|
@@ -623,7 +544,7 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "AsyncConnect
|
|
|
623
544
|
statement_config = replace(
|
|
624
545
|
statement_config,
|
|
625
546
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
626
|
-
|
|
547
|
+
default_parameter_style=self.default_parameter_style,
|
|
627
548
|
)
|
|
628
549
|
driver = self.driver_type(connection=conn, config=statement_config)
|
|
629
550
|
yield driver
|
|
@@ -8,15 +8,17 @@ from sqlglot.dialects.dialect import DialectType
|
|
|
8
8
|
from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
|
|
9
9
|
from sqlspec.driver.connection import managed_transaction_async, managed_transaction_sync
|
|
10
10
|
from sqlspec.driver.mixins import (
|
|
11
|
+
AsyncAdapterCacheMixin,
|
|
11
12
|
AsyncPipelinedExecutionMixin,
|
|
12
13
|
AsyncStorageMixin,
|
|
13
14
|
SQLTranslatorMixin,
|
|
15
|
+
SyncAdapterCacheMixin,
|
|
14
16
|
SyncPipelinedExecutionMixin,
|
|
15
17
|
SyncStorageMixin,
|
|
16
18
|
ToSchemaMixin,
|
|
17
19
|
TypeCoercionMixin,
|
|
18
20
|
)
|
|
19
|
-
from sqlspec.driver.parameters import
|
|
21
|
+
from sqlspec.driver.parameters import convert_parameter_sequence
|
|
20
22
|
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
21
23
|
from sqlspec.statement.result import ArrowResult, SQLResult
|
|
22
24
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
@@ -63,6 +65,7 @@ def _process_oracle_parameters(params: Any) -> Any:
|
|
|
63
65
|
|
|
64
66
|
class OracleSyncDriver(
|
|
65
67
|
SyncDriverAdapterProtocol[OracleSyncConnection, RowT],
|
|
68
|
+
SyncAdapterCacheMixin,
|
|
66
69
|
SQLTranslatorMixin,
|
|
67
70
|
TypeCoercionMixin,
|
|
68
71
|
SyncStorageMixin,
|
|
@@ -78,7 +81,6 @@ class OracleSyncDriver(
|
|
|
78
81
|
)
|
|
79
82
|
default_parameter_style: ParameterStyle = ParameterStyle.NAMED_COLON
|
|
80
83
|
support_native_arrow_export = True
|
|
81
|
-
__slots__ = ()
|
|
82
84
|
|
|
83
85
|
def __init__(
|
|
84
86
|
self,
|
|
@@ -109,7 +111,7 @@ class OracleSyncDriver(
|
|
|
109
111
|
self, statement: SQL, connection: Optional[OracleSyncConnection] = None, **kwargs: Any
|
|
110
112
|
) -> SQLResult[RowT]:
|
|
111
113
|
if statement.is_script:
|
|
112
|
-
sql, _ =
|
|
114
|
+
sql, _ = self._get_compiled_sql(statement, ParameterStyle.STATIC)
|
|
113
115
|
return self._execute_script(sql, connection=connection, **kwargs)
|
|
114
116
|
|
|
115
117
|
detected_styles = set()
|
|
@@ -132,11 +134,11 @@ class OracleSyncDriver(
|
|
|
132
134
|
break
|
|
133
135
|
|
|
134
136
|
if statement.is_many:
|
|
135
|
-
sql, params =
|
|
137
|
+
sql, params = self._get_compiled_sql(statement, target_style)
|
|
136
138
|
params = self._process_parameters(params)
|
|
137
139
|
return self._execute_many(sql, params, connection=connection, **kwargs)
|
|
138
140
|
|
|
139
|
-
sql, params =
|
|
141
|
+
sql, params = self._get_compiled_sql(statement, target_style)
|
|
140
142
|
return self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
141
143
|
|
|
142
144
|
def _execute(
|
|
@@ -191,19 +193,19 @@ class OracleSyncDriver(
|
|
|
191
193
|
|
|
192
194
|
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
193
195
|
# Normalize parameter list using consolidated utility
|
|
194
|
-
|
|
196
|
+
converted_param_list = convert_parameter_sequence(param_list)
|
|
195
197
|
|
|
196
198
|
# Process parameters for Oracle
|
|
197
|
-
if
|
|
199
|
+
if converted_param_list is None:
|
|
198
200
|
processed_param_list = []
|
|
199
|
-
elif
|
|
201
|
+
elif converted_param_list and not isinstance(converted_param_list, list):
|
|
200
202
|
# Single parameter set, wrap it
|
|
201
|
-
processed_param_list = [
|
|
202
|
-
elif
|
|
203
|
+
processed_param_list = [converted_param_list]
|
|
204
|
+
elif converted_param_list and not isinstance(converted_param_list[0], (list, tuple, dict)):
|
|
203
205
|
# Already a flat list, likely from incorrect usage
|
|
204
|
-
processed_param_list = [
|
|
206
|
+
processed_param_list = [converted_param_list]
|
|
205
207
|
else:
|
|
206
|
-
processed_param_list =
|
|
208
|
+
processed_param_list = converted_param_list
|
|
207
209
|
|
|
208
210
|
# Parameters have already been processed in _execute_statement
|
|
209
211
|
with self._get_cursor(txn_conn) as cursor:
|
|
@@ -224,19 +226,32 @@ class OracleSyncDriver(
|
|
|
224
226
|
|
|
225
227
|
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
226
228
|
statements = self._split_script_statements(script, strip_trailing_semicolon=True)
|
|
229
|
+
suppress_warnings = kwargs.get("_suppress_warnings", False)
|
|
230
|
+
successful = 0
|
|
231
|
+
total_rows = 0
|
|
232
|
+
|
|
227
233
|
with self._get_cursor(txn_conn) as cursor:
|
|
228
234
|
for statement in statements:
|
|
229
235
|
if statement and statement.strip():
|
|
236
|
+
# Validate each statement unless warnings suppressed
|
|
237
|
+
if not suppress_warnings:
|
|
238
|
+
# Run validation through pipeline
|
|
239
|
+
temp_sql = SQL(statement.strip(), config=self.config)
|
|
240
|
+
temp_sql._ensure_processed()
|
|
241
|
+
# Validation errors are logged as warnings by default
|
|
242
|
+
|
|
230
243
|
cursor.execute(statement.strip())
|
|
244
|
+
successful += 1
|
|
245
|
+
total_rows += cursor.rowcount or 0
|
|
231
246
|
|
|
232
247
|
return SQLResult(
|
|
233
248
|
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
234
249
|
data=[],
|
|
235
|
-
rows_affected=
|
|
250
|
+
rows_affected=total_rows,
|
|
236
251
|
operation_type="SCRIPT",
|
|
237
252
|
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
238
253
|
total_statements=len(statements),
|
|
239
|
-
successful_statements=
|
|
254
|
+
successful_statements=successful,
|
|
240
255
|
)
|
|
241
256
|
|
|
242
257
|
def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
|
|
@@ -306,6 +321,7 @@ class OracleSyncDriver(
|
|
|
306
321
|
|
|
307
322
|
class OracleAsyncDriver(
|
|
308
323
|
AsyncDriverAdapterProtocol[OracleAsyncConnection, RowT],
|
|
324
|
+
AsyncAdapterCacheMixin,
|
|
309
325
|
SQLTranslatorMixin,
|
|
310
326
|
TypeCoercionMixin,
|
|
311
327
|
AsyncStorageMixin,
|
|
@@ -322,7 +338,6 @@ class OracleAsyncDriver(
|
|
|
322
338
|
default_parameter_style: ParameterStyle = ParameterStyle.NAMED_COLON
|
|
323
339
|
__supports_arrow__: ClassVar[bool] = True
|
|
324
340
|
__supports_parquet__: ClassVar[bool] = False
|
|
325
|
-
__slots__ = ()
|
|
326
341
|
|
|
327
342
|
def __init__(
|
|
328
343
|
self,
|
|
@@ -355,7 +370,7 @@ class OracleAsyncDriver(
|
|
|
355
370
|
self, statement: SQL, connection: Optional[OracleAsyncConnection] = None, **kwargs: Any
|
|
356
371
|
) -> SQLResult[RowT]:
|
|
357
372
|
if statement.is_script:
|
|
358
|
-
sql, _ =
|
|
373
|
+
sql, _ = self._get_compiled_sql(statement, ParameterStyle.STATIC)
|
|
359
374
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
360
375
|
|
|
361
376
|
detected_styles = set()
|
|
@@ -378,7 +393,7 @@ class OracleAsyncDriver(
|
|
|
378
393
|
break
|
|
379
394
|
|
|
380
395
|
if statement.is_many:
|
|
381
|
-
sql, params =
|
|
396
|
+
sql, params = self._get_compiled_sql(statement, target_style)
|
|
382
397
|
params = self._process_parameters(params)
|
|
383
398
|
# Oracle doesn't like underscores in bind parameter names
|
|
384
399
|
if isinstance(params, list) and params and isinstance(params[0], dict):
|
|
@@ -392,7 +407,7 @@ class OracleAsyncDriver(
|
|
|
392
407
|
param_set[new_key] = param_set.pop(key)
|
|
393
408
|
return await self._execute_many(sql, params, connection=connection, **kwargs)
|
|
394
409
|
|
|
395
|
-
sql, params =
|
|
410
|
+
sql, params = self._get_compiled_sql(statement, target_style)
|
|
396
411
|
return await self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
397
412
|
|
|
398
413
|
async def _execute(
|
|
@@ -451,19 +466,19 @@ class OracleAsyncDriver(
|
|
|
451
466
|
|
|
452
467
|
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
453
468
|
# Normalize parameter list using consolidated utility
|
|
454
|
-
|
|
469
|
+
converted_param_list = convert_parameter_sequence(param_list)
|
|
455
470
|
|
|
456
471
|
# Process parameters for Oracle
|
|
457
|
-
if
|
|
472
|
+
if converted_param_list is None:
|
|
458
473
|
processed_param_list = []
|
|
459
|
-
elif
|
|
474
|
+
elif converted_param_list and not isinstance(converted_param_list, list):
|
|
460
475
|
# Single parameter set, wrap it
|
|
461
|
-
processed_param_list = [
|
|
462
|
-
elif
|
|
476
|
+
processed_param_list = [converted_param_list]
|
|
477
|
+
elif converted_param_list and not isinstance(converted_param_list[0], (list, tuple, dict)):
|
|
463
478
|
# Already a flat list, likely from incorrect usage
|
|
464
|
-
processed_param_list = [
|
|
479
|
+
processed_param_list = [converted_param_list]
|
|
465
480
|
else:
|
|
466
|
-
processed_param_list =
|
|
481
|
+
processed_param_list = converted_param_list
|
|
467
482
|
|
|
468
483
|
# Parameters have already been processed in _execute_statement
|
|
469
484
|
async with self._get_cursor(txn_conn) as cursor:
|
|
@@ -486,20 +501,32 @@ class OracleAsyncDriver(
|
|
|
486
501
|
# Oracle doesn't support multi-statement scripts in a single execute
|
|
487
502
|
# The splitter now handles PL/SQL blocks correctly when strip_trailing_semicolon=True
|
|
488
503
|
statements = self._split_script_statements(script, strip_trailing_semicolon=True)
|
|
504
|
+
suppress_warnings = kwargs.get("_suppress_warnings", False)
|
|
505
|
+
successful = 0
|
|
506
|
+
total_rows = 0
|
|
489
507
|
|
|
490
508
|
async with self._get_cursor(txn_conn) as cursor:
|
|
491
509
|
for statement in statements:
|
|
492
510
|
if statement and statement.strip():
|
|
511
|
+
# Validate each statement unless warnings suppressed
|
|
512
|
+
if not suppress_warnings:
|
|
513
|
+
# Run validation through pipeline
|
|
514
|
+
temp_sql = SQL(statement.strip(), config=self.config)
|
|
515
|
+
temp_sql._ensure_processed()
|
|
516
|
+
# Validation errors are logged as warnings by default
|
|
517
|
+
|
|
493
518
|
await cursor.execute(statement.strip())
|
|
519
|
+
successful += 1
|
|
520
|
+
total_rows += cursor.rowcount or 0
|
|
494
521
|
|
|
495
522
|
return SQLResult(
|
|
496
523
|
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
497
524
|
data=[],
|
|
498
|
-
rows_affected=
|
|
525
|
+
rows_affected=total_rows,
|
|
499
526
|
operation_type="SCRIPT",
|
|
500
527
|
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
501
528
|
total_statements=len(statements),
|
|
502
|
-
successful_statements=
|
|
529
|
+
successful_statements=successful,
|
|
503
530
|
)
|
|
504
531
|
|
|
505
532
|
async def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
|
|
@@ -15,7 +15,6 @@ from sqlspec.typing import DictRow, Empty
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
16
|
from collections.abc import Callable
|
|
17
17
|
|
|
18
|
-
from sqlglot.dialects.dialect import DialectType
|
|
19
18
|
|
|
20
19
|
logger = logging.getLogger("sqlspec.adapters.psqlpy")
|
|
21
20
|
|
|
@@ -69,56 +68,6 @@ __all__ = ("CONNECTION_FIELDS", "POOL_FIELDS", "PsqlpyConfig")
|
|
|
69
68
|
class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyDriver]):
|
|
70
69
|
"""Configuration for Psqlpy asynchronous database connections with direct field-based configuration."""
|
|
71
70
|
|
|
72
|
-
__slots__ = (
|
|
73
|
-
"_dialect",
|
|
74
|
-
"application_name",
|
|
75
|
-
"ca_file",
|
|
76
|
-
"channel_binding",
|
|
77
|
-
"client_encoding",
|
|
78
|
-
"configure",
|
|
79
|
-
"conn_recycling_method",
|
|
80
|
-
"connect_timeout_nanosec",
|
|
81
|
-
"connect_timeout_sec",
|
|
82
|
-
"db_name",
|
|
83
|
-
"default_row_type",
|
|
84
|
-
"dsn",
|
|
85
|
-
"extras",
|
|
86
|
-
"gssdelegation",
|
|
87
|
-
"gssencmode",
|
|
88
|
-
"gsslib",
|
|
89
|
-
"host",
|
|
90
|
-
"hosts",
|
|
91
|
-
"keepalives",
|
|
92
|
-
"keepalives_idle_nanosec",
|
|
93
|
-
"keepalives_idle_sec",
|
|
94
|
-
"keepalives_interval_nanosec",
|
|
95
|
-
"keepalives_interval_sec",
|
|
96
|
-
"keepalives_retries",
|
|
97
|
-
"krbsrvname",
|
|
98
|
-
"load_balance_hosts",
|
|
99
|
-
"max_db_pool_size",
|
|
100
|
-
"options",
|
|
101
|
-
"password",
|
|
102
|
-
"pool_instance",
|
|
103
|
-
"port",
|
|
104
|
-
"ports",
|
|
105
|
-
"require_auth",
|
|
106
|
-
"service",
|
|
107
|
-
"ssl_mode",
|
|
108
|
-
"sslcert",
|
|
109
|
-
"sslcompression",
|
|
110
|
-
"sslcrl",
|
|
111
|
-
"sslkey",
|
|
112
|
-
"sslnegotiation",
|
|
113
|
-
"sslpassword",
|
|
114
|
-
"sslrootcert",
|
|
115
|
-
"statement_config",
|
|
116
|
-
"target_session_attrs",
|
|
117
|
-
"tcp_user_timeout_nanosec",
|
|
118
|
-
"tcp_user_timeout_sec",
|
|
119
|
-
"username",
|
|
120
|
-
)
|
|
121
|
-
|
|
122
71
|
is_async: ClassVar[bool] = True
|
|
123
72
|
supports_connection_pooling: ClassVar[bool] = True
|
|
124
73
|
|
|
@@ -128,7 +77,7 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
128
77
|
supported_parameter_styles: ClassVar[tuple[str, ...]] = ("numeric",)
|
|
129
78
|
"""Psqlpy only supports $1, $2, ... (numeric) parameter style."""
|
|
130
79
|
|
|
131
|
-
|
|
80
|
+
default_parameter_style: ClassVar[str] = "numeric"
|
|
132
81
|
"""Psqlpy's native parameter style is $1, $2, ... (numeric)."""
|
|
133
82
|
|
|
134
83
|
def __init__(
|
|
@@ -283,8 +232,6 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
283
232
|
# Store other config
|
|
284
233
|
self.statement_config = statement_config or SQLConfig()
|
|
285
234
|
self.default_row_type = default_row_type
|
|
286
|
-
self.pool_instance: Optional[ConnectionPool] = pool_instance
|
|
287
|
-
self._dialect: DialectType = None
|
|
288
235
|
|
|
289
236
|
super().__init__()
|
|
290
237
|
|
|
@@ -399,7 +346,7 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
399
346
|
statement_config = replace(
|
|
400
347
|
statement_config,
|
|
401
348
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
402
|
-
|
|
349
|
+
default_parameter_style=self.default_parameter_style,
|
|
403
350
|
)
|
|
404
351
|
driver = self.driver_type(connection=conn, config=statement_config)
|
|
405
352
|
yield driver
|
|
@@ -9,6 +9,7 @@ from psqlpy import Connection
|
|
|
9
9
|
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
10
10
|
from sqlspec.driver.connection import managed_transaction_async
|
|
11
11
|
from sqlspec.driver.mixins import (
|
|
12
|
+
AsyncAdapterCacheMixin,
|
|
12
13
|
AsyncPipelinedExecutionMixin,
|
|
13
14
|
AsyncStorageMixin,
|
|
14
15
|
SQLTranslatorMixin,
|
|
@@ -31,6 +32,7 @@ logger = logging.getLogger("sqlspec")
|
|
|
31
32
|
|
|
32
33
|
class PsqlpyDriver(
|
|
33
34
|
AsyncDriverAdapterProtocol[PsqlpyConnection, RowT],
|
|
35
|
+
AsyncAdapterCacheMixin,
|
|
34
36
|
SQLTranslatorMixin,
|
|
35
37
|
TypeCoercionMixin,
|
|
36
38
|
AsyncStorageMixin,
|
|
@@ -45,7 +47,6 @@ class PsqlpyDriver(
|
|
|
45
47
|
dialect: "DialectType" = "postgres"
|
|
46
48
|
supported_parameter_styles: "tuple[ParameterStyle, ...]" = (ParameterStyle.NUMERIC,)
|
|
47
49
|
default_parameter_style: ParameterStyle = ParameterStyle.NUMERIC
|
|
48
|
-
__slots__ = ()
|
|
49
50
|
|
|
50
51
|
def __init__(
|
|
51
52
|
self,
|
|
@@ -79,7 +80,7 @@ class PsqlpyDriver(
|
|
|
79
80
|
self, statement: SQL, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
|
|
80
81
|
) -> SQLResult[RowT]:
|
|
81
82
|
if statement.is_script:
|
|
82
|
-
sql, _ =
|
|
83
|
+
sql, _ = self._get_compiled_sql(statement, ParameterStyle.STATIC)
|
|
83
84
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
84
85
|
|
|
85
86
|
# Detect parameter styles in the SQL
|
|
@@ -106,7 +107,7 @@ class PsqlpyDriver(
|
|
|
106
107
|
break
|
|
107
108
|
|
|
108
109
|
# Compile with the determined style
|
|
109
|
-
sql, params =
|
|
110
|
+
sql, params = self._get_compiled_sql(statement, target_style)
|
|
110
111
|
params = self._process_parameters(params)
|
|
111
112
|
|
|
112
113
|
if statement.is_many:
|
|
@@ -198,16 +199,35 @@ class PsqlpyDriver(
|
|
|
198
199
|
conn = connection if connection is not None else self._connection(None)
|
|
199
200
|
|
|
200
201
|
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
201
|
-
#
|
|
202
|
-
|
|
202
|
+
# Split script into individual statements for validation
|
|
203
|
+
statements = self._split_script_statements(script)
|
|
204
|
+
suppress_warnings = kwargs.get("_suppress_warnings", False)
|
|
205
|
+
|
|
206
|
+
executed_count = 0
|
|
207
|
+
total_rows = 0
|
|
208
|
+
|
|
209
|
+
# Execute each statement individually for better control and validation
|
|
210
|
+
for statement in statements:
|
|
211
|
+
if statement.strip():
|
|
212
|
+
# Validate each statement unless warnings suppressed
|
|
213
|
+
if not suppress_warnings:
|
|
214
|
+
# Run validation through pipeline
|
|
215
|
+
temp_sql = SQL(statement, config=self.config)
|
|
216
|
+
temp_sql._ensure_processed()
|
|
217
|
+
# Validation errors are logged as warnings by default
|
|
218
|
+
|
|
219
|
+
await txn_conn.execute(statement)
|
|
220
|
+
executed_count += 1
|
|
221
|
+
# psqlpy doesn't provide row count from execute()
|
|
222
|
+
|
|
203
223
|
return SQLResult(
|
|
204
224
|
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
205
225
|
data=[],
|
|
206
|
-
rows_affected=
|
|
226
|
+
rows_affected=total_rows,
|
|
207
227
|
operation_type="SCRIPT",
|
|
208
228
|
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
209
|
-
total_statements
|
|
210
|
-
successful_statements
|
|
229
|
+
total_statements=executed_count,
|
|
230
|
+
successful_statements=executed_count,
|
|
211
231
|
)
|
|
212
232
|
|
|
213
233
|
async def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
|