sqlspec 0.12.2__py3-none-any.whl → 0.13.1__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +16 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +17 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +17 -29
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +32 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +18 -9
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +44 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/METADATA +1 -1
- sqlspec-0.13.1.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.1.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
from collections.abc import AsyncGenerator, Generator
|
|
2
2
|
from contextlib import asynccontextmanager, contextmanager
|
|
3
|
-
from typing import Any, ClassVar, Optional,
|
|
3
|
+
from typing import Any, ClassVar, Optional, cast
|
|
4
4
|
|
|
5
5
|
from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor
|
|
6
6
|
from sqlglot.dialects.dialect import DialectType
|
|
7
7
|
|
|
8
8
|
from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
|
|
9
|
+
from sqlspec.driver.connection import managed_transaction_async, managed_transaction_sync
|
|
9
10
|
from sqlspec.driver.mixins import (
|
|
10
11
|
AsyncPipelinedExecutionMixin,
|
|
11
12
|
AsyncStorageMixin,
|
|
@@ -15,10 +16,11 @@ from sqlspec.driver.mixins import (
|
|
|
15
16
|
ToSchemaMixin,
|
|
16
17
|
TypeCoercionMixin,
|
|
17
18
|
)
|
|
18
|
-
from sqlspec.
|
|
19
|
-
from sqlspec.statement.
|
|
19
|
+
from sqlspec.driver.parameters import normalize_parameter_sequence
|
|
20
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
21
|
+
from sqlspec.statement.result import ArrowResult, SQLResult
|
|
20
22
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
21
|
-
from sqlspec.typing import DictRow,
|
|
23
|
+
from sqlspec.typing import DictRow, RowT, SQLParameterType
|
|
22
24
|
from sqlspec.utils.logging import get_logger
|
|
23
25
|
from sqlspec.utils.sync_tools import ensure_async_
|
|
24
26
|
|
|
@@ -41,30 +43,21 @@ def _process_oracle_parameters(params: Any) -> Any:
|
|
|
41
43
|
if params is None:
|
|
42
44
|
return None
|
|
43
45
|
|
|
44
|
-
# Handle TypedParameter objects
|
|
45
46
|
if isinstance(params, TypedParameter):
|
|
46
47
|
return _process_oracle_parameters(params.value)
|
|
47
48
|
|
|
48
49
|
if isinstance(params, tuple):
|
|
49
|
-
# Convert single tuple to list and process each element
|
|
50
50
|
return [_process_oracle_parameters(item) for item in params]
|
|
51
51
|
if isinstance(params, list):
|
|
52
|
-
# Process list of parameter sets
|
|
53
52
|
processed = []
|
|
54
53
|
for param_set in params:
|
|
55
|
-
if isinstance(param_set, tuple):
|
|
56
|
-
# Convert tuple to list and process each element
|
|
57
|
-
processed.append([_process_oracle_parameters(item) for item in param_set])
|
|
58
|
-
elif isinstance(param_set, list):
|
|
59
|
-
# Process each element in the list
|
|
54
|
+
if isinstance(param_set, (tuple, list)):
|
|
60
55
|
processed.append([_process_oracle_parameters(item) for item in param_set])
|
|
61
56
|
else:
|
|
62
57
|
processed.append(_process_oracle_parameters(param_set))
|
|
63
58
|
return processed
|
|
64
59
|
if isinstance(params, dict):
|
|
65
|
-
# Process dict values
|
|
66
60
|
return {key: _process_oracle_parameters(value) for key, value in params.items()}
|
|
67
|
-
# Return as-is for other types
|
|
68
61
|
return params
|
|
69
62
|
|
|
70
63
|
|
|
@@ -114,22 +107,24 @@ class OracleSyncDriver(
|
|
|
114
107
|
|
|
115
108
|
def _execute_statement(
|
|
116
109
|
self, statement: SQL, connection: Optional[OracleSyncConnection] = None, **kwargs: Any
|
|
117
|
-
) ->
|
|
110
|
+
) -> SQLResult[RowT]:
|
|
118
111
|
if statement.is_script:
|
|
119
112
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
120
113
|
return self._execute_script(sql, connection=connection, **kwargs)
|
|
121
114
|
|
|
122
|
-
|
|
123
|
-
|
|
115
|
+
detected_styles = set()
|
|
116
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
117
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
118
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
119
|
+
if param_infos:
|
|
120
|
+
detected_styles = {p.style for p in param_infos}
|
|
121
|
+
|
|
124
122
|
target_style = self.default_parameter_style
|
|
125
123
|
|
|
126
|
-
# Check if any detected style is not supported
|
|
127
124
|
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
128
125
|
if unsupported_styles:
|
|
129
|
-
# Convert to default style if we have unsupported styles
|
|
130
126
|
target_style = self.default_parameter_style
|
|
131
127
|
elif detected_styles:
|
|
132
|
-
# Use the first detected style if all are supported
|
|
133
128
|
# Prefer the first supported style found
|
|
134
129
|
for style in detected_styles:
|
|
135
130
|
if style in self.supported_parameter_styles:
|
|
@@ -138,32 +133,10 @@ class OracleSyncDriver(
|
|
|
138
133
|
|
|
139
134
|
if statement.is_many:
|
|
140
135
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
141
|
-
# Process parameters to convert tuples to lists for Oracle
|
|
142
136
|
params = self._process_parameters(params)
|
|
143
|
-
# Oracle doesn't like underscores in bind parameter names
|
|
144
|
-
if isinstance(params, list) and params and isinstance(params[0], dict):
|
|
145
|
-
# Fix the SQL and parameters
|
|
146
|
-
for key in list(params[0].keys()):
|
|
147
|
-
if key.startswith("_arg_"):
|
|
148
|
-
# Remove leading underscore: _arg_0 -> arg0
|
|
149
|
-
new_key = key[1:].replace("_", "")
|
|
150
|
-
sql = sql.replace(f":{key}", f":{new_key}")
|
|
151
|
-
# Update all parameter sets
|
|
152
|
-
for param_set in params:
|
|
153
|
-
if isinstance(param_set, dict) and key in param_set:
|
|
154
|
-
param_set[new_key] = param_set.pop(key)
|
|
155
137
|
return self._execute_many(sql, params, connection=connection, **kwargs)
|
|
156
138
|
|
|
157
139
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
158
|
-
# Oracle doesn't like underscores in bind parameter names
|
|
159
|
-
if isinstance(params, dict):
|
|
160
|
-
# Fix the SQL and parameters
|
|
161
|
-
for key in list(params.keys()):
|
|
162
|
-
if key.startswith("_arg_"):
|
|
163
|
-
# Remove leading underscore: _arg_0 -> arg0
|
|
164
|
-
new_key = key[1:].replace("_", "")
|
|
165
|
-
sql = sql.replace(f":{key}", f":{new_key}")
|
|
166
|
-
params[new_key] = params.pop(key)
|
|
167
140
|
return self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
168
141
|
|
|
169
142
|
def _execute(
|
|
@@ -173,65 +146,130 @@ class OracleSyncDriver(
|
|
|
173
146
|
statement: SQL,
|
|
174
147
|
connection: Optional[OracleSyncConnection] = None,
|
|
175
148
|
**kwargs: Any,
|
|
176
|
-
) ->
|
|
149
|
+
) -> SQLResult[RowT]:
|
|
150
|
+
# Use provided connection or driver's default connection
|
|
177
151
|
conn = self._connection(connection)
|
|
178
|
-
with self._get_cursor(conn) as cursor:
|
|
179
|
-
# Process parameters to extract values from TypedParameter objects
|
|
180
|
-
processed_params = self._process_parameters(parameters) if parameters else []
|
|
181
|
-
cursor.execute(sql, processed_params)
|
|
182
|
-
|
|
183
|
-
if self.returns_rows(statement.expression):
|
|
184
|
-
fetched_data = cursor.fetchall()
|
|
185
|
-
column_names = [col[0] for col in cursor.description or []]
|
|
186
|
-
return {"data": fetched_data, "column_names": column_names, "rows_affected": cursor.rowcount}
|
|
187
152
|
|
|
188
|
-
|
|
153
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
154
|
+
# Oracle requires special parameter handling
|
|
155
|
+
processed_params = self._process_parameters(parameters) if parameters is not None else []
|
|
156
|
+
|
|
157
|
+
with self._get_cursor(txn_conn) as cursor:
|
|
158
|
+
cursor.execute(sql, processed_params)
|
|
159
|
+
|
|
160
|
+
if self.returns_rows(statement.expression):
|
|
161
|
+
fetched_data = cursor.fetchall()
|
|
162
|
+
column_names = [col[0] for col in cursor.description or []]
|
|
163
|
+
|
|
164
|
+
# Convert to dict if default_row_type is dict
|
|
165
|
+
if self.default_row_type == DictRow or issubclass(self.default_row_type, dict):
|
|
166
|
+
data = cast("list[RowT]", [dict(zip(column_names, row)) for row in fetched_data])
|
|
167
|
+
else:
|
|
168
|
+
data = cast("list[RowT]", fetched_data)
|
|
169
|
+
|
|
170
|
+
return SQLResult(
|
|
171
|
+
statement=statement,
|
|
172
|
+
data=data,
|
|
173
|
+
column_names=column_names,
|
|
174
|
+
rows_affected=cursor.rowcount,
|
|
175
|
+
operation_type="SELECT",
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
return SQLResult(
|
|
179
|
+
statement=statement,
|
|
180
|
+
data=[],
|
|
181
|
+
rows_affected=cursor.rowcount,
|
|
182
|
+
operation_type=self._determine_operation_type(statement),
|
|
183
|
+
metadata={"status_message": "OK"},
|
|
184
|
+
)
|
|
189
185
|
|
|
190
186
|
def _execute_many(
|
|
191
187
|
self, sql: str, param_list: Any, connection: Optional[OracleSyncConnection] = None, **kwargs: Any
|
|
192
|
-
) ->
|
|
188
|
+
) -> SQLResult[RowT]:
|
|
189
|
+
# Use provided connection or driver's default connection
|
|
193
190
|
conn = self._connection(connection)
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
191
|
+
|
|
192
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
193
|
+
# Normalize parameter list using consolidated utility
|
|
194
|
+
normalized_param_list = normalize_parameter_sequence(param_list)
|
|
195
|
+
|
|
196
|
+
# Process parameters for Oracle
|
|
197
|
+
if normalized_param_list is None:
|
|
198
|
+
processed_param_list = []
|
|
199
|
+
elif normalized_param_list and not isinstance(normalized_param_list, list):
|
|
200
200
|
# Single parameter set, wrap it
|
|
201
|
-
|
|
202
|
-
elif
|
|
201
|
+
processed_param_list = [normalized_param_list]
|
|
202
|
+
elif normalized_param_list and not isinstance(normalized_param_list[0], (list, tuple, dict)):
|
|
203
203
|
# Already a flat list, likely from incorrect usage
|
|
204
|
-
|
|
204
|
+
processed_param_list = [normalized_param_list]
|
|
205
|
+
else:
|
|
206
|
+
processed_param_list = normalized_param_list
|
|
207
|
+
|
|
205
208
|
# Parameters have already been processed in _execute_statement
|
|
206
|
-
|
|
207
|
-
|
|
209
|
+
with self._get_cursor(txn_conn) as cursor:
|
|
210
|
+
cursor.executemany(sql, processed_param_list or [])
|
|
211
|
+
return SQLResult(
|
|
212
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
213
|
+
data=[],
|
|
214
|
+
rows_affected=cursor.rowcount,
|
|
215
|
+
operation_type="EXECUTE",
|
|
216
|
+
metadata={"status_message": "OK"},
|
|
217
|
+
)
|
|
208
218
|
|
|
209
219
|
def _execute_script(
|
|
210
220
|
self, script: str, connection: Optional[OracleSyncConnection] = None, **kwargs: Any
|
|
211
|
-
) ->
|
|
221
|
+
) -> SQLResult[RowT]:
|
|
222
|
+
# Use provided connection or driver's default connection
|
|
212
223
|
conn = self._connection(connection)
|
|
213
|
-
statements = self._split_script_statements(script, strip_trailing_semicolon=True)
|
|
214
|
-
with self._get_cursor(conn) as cursor:
|
|
215
|
-
for statement in statements:
|
|
216
|
-
if statement and statement.strip():
|
|
217
|
-
cursor.execute(statement.strip())
|
|
218
224
|
|
|
219
|
-
|
|
225
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
226
|
+
statements = self._split_script_statements(script, strip_trailing_semicolon=True)
|
|
227
|
+
with self._get_cursor(txn_conn) as cursor:
|
|
228
|
+
for statement in statements:
|
|
229
|
+
if statement and statement.strip():
|
|
230
|
+
cursor.execute(statement.strip())
|
|
231
|
+
|
|
232
|
+
return SQLResult(
|
|
233
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
234
|
+
data=[],
|
|
235
|
+
rows_affected=0,
|
|
236
|
+
operation_type="SCRIPT",
|
|
237
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
238
|
+
total_statements=len(statements),
|
|
239
|
+
successful_statements=len(statements),
|
|
240
|
+
)
|
|
220
241
|
|
|
221
242
|
def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
|
|
222
243
|
self._ensure_pyarrow_installed()
|
|
223
244
|
conn = self._connection(connection)
|
|
224
245
|
|
|
225
|
-
#
|
|
226
|
-
|
|
227
|
-
sql_str
|
|
228
|
-
if
|
|
229
|
-
|
|
246
|
+
# Use the exact same parameter style detection logic as _execute_statement
|
|
247
|
+
detected_styles = set()
|
|
248
|
+
sql_str = sql.to_sql(placeholder_style=None) # Get raw SQL
|
|
249
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
250
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
251
|
+
if param_infos:
|
|
252
|
+
detected_styles = {p.style for p in param_infos}
|
|
230
253
|
|
|
231
|
-
|
|
232
|
-
|
|
254
|
+
target_style = self.default_parameter_style
|
|
255
|
+
|
|
256
|
+
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
257
|
+
if unsupported_styles:
|
|
258
|
+
target_style = self.default_parameter_style
|
|
259
|
+
elif detected_styles:
|
|
260
|
+
# Prefer the first supported style found
|
|
261
|
+
for style in detected_styles:
|
|
262
|
+
if style in self.supported_parameter_styles:
|
|
263
|
+
target_style = style
|
|
264
|
+
break
|
|
265
|
+
|
|
266
|
+
sql_str, params = sql.compile(placeholder_style=target_style)
|
|
267
|
+
processed_params = self._process_parameters(params) if params is not None else []
|
|
268
|
+
|
|
269
|
+
# Use proper transaction management like other methods
|
|
270
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn:
|
|
271
|
+
oracle_df = txn_conn.fetch_df_all(sql_str, processed_params)
|
|
233
272
|
|
|
234
|
-
oracle_df = conn.fetch_df_all(sql_str, processed_params)
|
|
235
273
|
from pyarrow.interchange.from_dataframe import from_dataframe
|
|
236
274
|
|
|
237
275
|
arrow_table = from_dataframe(oracle_df)
|
|
@@ -242,7 +280,8 @@ class OracleSyncDriver(
|
|
|
242
280
|
self._ensure_pyarrow_installed()
|
|
243
281
|
conn = self._connection(None)
|
|
244
282
|
|
|
245
|
-
|
|
283
|
+
# Use proper transaction management like other methods
|
|
284
|
+
with managed_transaction_sync(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
|
|
246
285
|
if mode == "replace":
|
|
247
286
|
cursor.execute(f"TRUNCATE TABLE {table_name}")
|
|
248
287
|
elif mode == "create":
|
|
@@ -260,57 +299,9 @@ class OracleSyncDriver(
|
|
|
260
299
|
cursor.executemany(sql, data_for_ingest)
|
|
261
300
|
return cursor.rowcount
|
|
262
301
|
|
|
263
|
-
def
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
fetched_tuples = result.get("data", [])
|
|
267
|
-
column_names = result.get("column_names", [])
|
|
268
|
-
|
|
269
|
-
if not fetched_tuples:
|
|
270
|
-
return SQLResult[RowT](statement=statement, data=[], column_names=column_names, operation_type="SELECT")
|
|
271
|
-
|
|
272
|
-
rows_as_dicts: list[dict[str, Any]] = [dict(zip(column_names, row_tuple)) for row_tuple in fetched_tuples]
|
|
273
|
-
|
|
274
|
-
if schema_type:
|
|
275
|
-
converted_data = self.to_schema(rows_as_dicts, schema_type=schema_type)
|
|
276
|
-
return SQLResult[ModelDTOT](
|
|
277
|
-
statement=statement, data=list(converted_data), column_names=column_names, operation_type="SELECT"
|
|
278
|
-
)
|
|
279
|
-
|
|
280
|
-
return SQLResult[RowT](
|
|
281
|
-
statement=statement, data=rows_as_dicts, column_names=column_names, operation_type="SELECT"
|
|
282
|
-
)
|
|
283
|
-
|
|
284
|
-
def _wrap_execute_result(
|
|
285
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
286
|
-
) -> SQLResult[RowT]:
|
|
287
|
-
operation_type = "UNKNOWN"
|
|
288
|
-
if statement.expression:
|
|
289
|
-
operation_type = str(statement.expression.key).upper()
|
|
290
|
-
|
|
291
|
-
if "statements_executed" in result:
|
|
292
|
-
script_result = cast("ScriptResultDict", result)
|
|
293
|
-
return SQLResult[RowT](
|
|
294
|
-
statement=statement,
|
|
295
|
-
data=[],
|
|
296
|
-
rows_affected=0,
|
|
297
|
-
operation_type="SCRIPT",
|
|
298
|
-
metadata={
|
|
299
|
-
"status_message": script_result.get("status_message", ""),
|
|
300
|
-
"statements_executed": script_result.get("statements_executed", -1),
|
|
301
|
-
},
|
|
302
|
-
)
|
|
303
|
-
|
|
304
|
-
dml_result = cast("DMLResultDict", result)
|
|
305
|
-
rows_affected = dml_result.get("rows_affected", -1)
|
|
306
|
-
status_message = dml_result.get("status_message", "")
|
|
307
|
-
return SQLResult[RowT](
|
|
308
|
-
statement=statement,
|
|
309
|
-
data=[],
|
|
310
|
-
rows_affected=rows_affected,
|
|
311
|
-
operation_type=operation_type,
|
|
312
|
-
metadata={"status_message": status_message},
|
|
313
|
-
)
|
|
302
|
+
def _connection(self, connection: Optional[OracleSyncConnection] = None) -> OracleSyncConnection:
|
|
303
|
+
"""Get the connection to use for the operation."""
|
|
304
|
+
return connection or self.connection
|
|
314
305
|
|
|
315
306
|
|
|
316
307
|
class OracleAsyncDriver(
|
|
@@ -362,22 +353,24 @@ class OracleAsyncDriver(
|
|
|
362
353
|
|
|
363
354
|
async def _execute_statement(
|
|
364
355
|
self, statement: SQL, connection: Optional[OracleAsyncConnection] = None, **kwargs: Any
|
|
365
|
-
) ->
|
|
356
|
+
) -> SQLResult[RowT]:
|
|
366
357
|
if statement.is_script:
|
|
367
358
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
368
359
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
369
360
|
|
|
370
|
-
|
|
371
|
-
|
|
361
|
+
detected_styles = set()
|
|
362
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
363
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
364
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
365
|
+
if param_infos:
|
|
366
|
+
detected_styles = {p.style for p in param_infos}
|
|
367
|
+
|
|
372
368
|
target_style = self.default_parameter_style
|
|
373
369
|
|
|
374
|
-
# Check if any detected style is not supported
|
|
375
370
|
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
376
371
|
if unsupported_styles:
|
|
377
|
-
# Convert to default style if we have unsupported styles
|
|
378
372
|
target_style = self.default_parameter_style
|
|
379
373
|
elif detected_styles:
|
|
380
|
-
# Use the first detected style if all are supported
|
|
381
374
|
# Prefer the first supported style found
|
|
382
375
|
for style in detected_styles:
|
|
383
376
|
if style in self.supported_parameter_styles:
|
|
@@ -386,32 +379,20 @@ class OracleAsyncDriver(
|
|
|
386
379
|
|
|
387
380
|
if statement.is_many:
|
|
388
381
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
389
|
-
# Process parameters to convert tuples to lists for Oracle
|
|
390
382
|
params = self._process_parameters(params)
|
|
391
383
|
# Oracle doesn't like underscores in bind parameter names
|
|
392
384
|
if isinstance(params, list) and params and isinstance(params[0], dict):
|
|
393
385
|
# Fix the SQL and parameters
|
|
394
386
|
for key in list(params[0].keys()):
|
|
395
387
|
if key.startswith("_arg_"):
|
|
396
|
-
# Remove leading underscore: _arg_0 -> arg0
|
|
397
388
|
new_key = key[1:].replace("_", "")
|
|
398
389
|
sql = sql.replace(f":{key}", f":{new_key}")
|
|
399
|
-
# Update all parameter sets
|
|
400
390
|
for param_set in params:
|
|
401
391
|
if isinstance(param_set, dict) and key in param_set:
|
|
402
392
|
param_set[new_key] = param_set.pop(key)
|
|
403
393
|
return await self._execute_many(sql, params, connection=connection, **kwargs)
|
|
404
394
|
|
|
405
395
|
sql, params = statement.compile(placeholder_style=target_style)
|
|
406
|
-
# Oracle doesn't like underscores in bind parameter names
|
|
407
|
-
if isinstance(params, dict):
|
|
408
|
-
# Fix the SQL and parameters
|
|
409
|
-
for key in list(params.keys()):
|
|
410
|
-
if key.startswith("_arg_"):
|
|
411
|
-
# Remove leading underscore: _arg_0 -> arg0
|
|
412
|
-
new_key = key[1:].replace("_", "")
|
|
413
|
-
sql = sql.replace(f":{key}", f":{new_key}")
|
|
414
|
-
params[new_key] = params.pop(key)
|
|
415
396
|
return await self._execute(sql, params, statement, connection=connection, **kwargs)
|
|
416
397
|
|
|
417
398
|
async def _execute(
|
|
@@ -421,77 +402,132 @@ class OracleAsyncDriver(
|
|
|
421
402
|
statement: SQL,
|
|
422
403
|
connection: Optional[OracleAsyncConnection] = None,
|
|
423
404
|
**kwargs: Any,
|
|
424
|
-
) ->
|
|
405
|
+
) -> SQLResult[RowT]:
|
|
406
|
+
# Use provided connection or driver's default connection
|
|
425
407
|
conn = self._connection(connection)
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
else
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
408
|
+
|
|
409
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
410
|
+
# Oracle requires special parameter handling
|
|
411
|
+
processed_params = self._process_parameters(parameters) if parameters is not None else []
|
|
412
|
+
|
|
413
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
414
|
+
if parameters is None:
|
|
415
|
+
await cursor.execute(sql)
|
|
416
|
+
else:
|
|
417
|
+
await cursor.execute(sql, processed_params)
|
|
418
|
+
|
|
419
|
+
# For SELECT statements, extract data while cursor is open
|
|
420
|
+
if self.returns_rows(statement.expression):
|
|
421
|
+
fetched_data = await cursor.fetchall()
|
|
422
|
+
column_names = [col[0] for col in cursor.description or []]
|
|
423
|
+
|
|
424
|
+
# Convert to dict if default_row_type is dict
|
|
425
|
+
if self.default_row_type == DictRow or issubclass(self.default_row_type, dict):
|
|
426
|
+
data = cast("list[RowT]", [dict(zip(column_names, row)) for row in fetched_data])
|
|
427
|
+
else:
|
|
428
|
+
data = cast("list[RowT]", fetched_data)
|
|
429
|
+
|
|
430
|
+
return SQLResult(
|
|
431
|
+
statement=statement,
|
|
432
|
+
data=data,
|
|
433
|
+
column_names=column_names,
|
|
434
|
+
rows_affected=cursor.rowcount,
|
|
435
|
+
operation_type="SELECT",
|
|
436
|
+
)
|
|
437
|
+
|
|
438
|
+
return SQLResult(
|
|
439
|
+
statement=statement,
|
|
440
|
+
data=[],
|
|
441
|
+
rows_affected=cursor.rowcount,
|
|
442
|
+
operation_type=self._determine_operation_type(statement),
|
|
443
|
+
metadata={"status_message": "OK"},
|
|
444
|
+
)
|
|
446
445
|
|
|
447
446
|
async def _execute_many(
|
|
448
447
|
self, sql: str, param_list: Any, connection: Optional[OracleAsyncConnection] = None, **kwargs: Any
|
|
449
|
-
) ->
|
|
448
|
+
) -> SQLResult[RowT]:
|
|
449
|
+
# Use provided connection or driver's default connection
|
|
450
450
|
conn = self._connection(connection)
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
451
|
+
|
|
452
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
453
|
+
# Normalize parameter list using consolidated utility
|
|
454
|
+
normalized_param_list = normalize_parameter_sequence(param_list)
|
|
455
|
+
|
|
456
|
+
# Process parameters for Oracle
|
|
457
|
+
if normalized_param_list is None:
|
|
458
|
+
processed_param_list = []
|
|
459
|
+
elif normalized_param_list and not isinstance(normalized_param_list, list):
|
|
457
460
|
# Single parameter set, wrap it
|
|
458
|
-
|
|
459
|
-
elif
|
|
461
|
+
processed_param_list = [normalized_param_list]
|
|
462
|
+
elif normalized_param_list and not isinstance(normalized_param_list[0], (list, tuple, dict)):
|
|
460
463
|
# Already a flat list, likely from incorrect usage
|
|
461
|
-
|
|
464
|
+
processed_param_list = [normalized_param_list]
|
|
465
|
+
else:
|
|
466
|
+
processed_param_list = normalized_param_list
|
|
467
|
+
|
|
462
468
|
# Parameters have already been processed in _execute_statement
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
469
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
470
|
+
await cursor.executemany(sql, processed_param_list or [])
|
|
471
|
+
return SQLResult(
|
|
472
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
473
|
+
data=[],
|
|
474
|
+
rows_affected=cursor.rowcount,
|
|
475
|
+
operation_type="EXECUTE",
|
|
476
|
+
metadata={"status_message": "OK"},
|
|
477
|
+
)
|
|
466
478
|
|
|
467
479
|
async def _execute_script(
|
|
468
480
|
self, script: str, connection: Optional[OracleAsyncConnection] = None, **kwargs: Any
|
|
469
|
-
) ->
|
|
481
|
+
) -> SQLResult[RowT]:
|
|
482
|
+
# Use provided connection or driver's default connection
|
|
470
483
|
conn = self._connection(connection)
|
|
471
|
-
# Oracle doesn't support multi-statement scripts in a single execute
|
|
472
|
-
# The splitter now handles PL/SQL blocks correctly when strip_trailing_semicolon=True
|
|
473
|
-
statements = self._split_script_statements(script, strip_trailing_semicolon=True)
|
|
474
484
|
|
|
475
|
-
async with
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
485
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
486
|
+
# Oracle doesn't support multi-statement scripts in a single execute
|
|
487
|
+
# The splitter now handles PL/SQL blocks correctly when strip_trailing_semicolon=True
|
|
488
|
+
statements = self._split_script_statements(script, strip_trailing_semicolon=True)
|
|
479
489
|
|
|
480
|
-
|
|
481
|
-
|
|
490
|
+
async with self._get_cursor(txn_conn) as cursor:
|
|
491
|
+
for statement in statements:
|
|
492
|
+
if statement and statement.strip():
|
|
493
|
+
await cursor.execute(statement.strip())
|
|
494
|
+
|
|
495
|
+
return SQLResult(
|
|
496
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
497
|
+
data=[],
|
|
498
|
+
rows_affected=0,
|
|
499
|
+
operation_type="SCRIPT",
|
|
500
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
501
|
+
total_statements=len(statements),
|
|
502
|
+
successful_statements=len(statements),
|
|
503
|
+
)
|
|
482
504
|
|
|
483
505
|
async def _fetch_arrow_table(self, sql: SQL, connection: "Optional[Any]" = None, **kwargs: Any) -> "ArrowResult":
|
|
484
506
|
self._ensure_pyarrow_installed()
|
|
485
507
|
conn = self._connection(connection)
|
|
486
508
|
|
|
487
|
-
#
|
|
488
|
-
|
|
489
|
-
sql_str
|
|
490
|
-
if
|
|
491
|
-
|
|
509
|
+
# Use the exact same parameter style detection logic as _execute_statement
|
|
510
|
+
detected_styles = set()
|
|
511
|
+
sql_str = sql.to_sql(placeholder_style=None) # Get raw SQL
|
|
512
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
513
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
514
|
+
if param_infos:
|
|
515
|
+
detected_styles = {p.style for p in param_infos}
|
|
492
516
|
|
|
493
|
-
|
|
494
|
-
|
|
517
|
+
target_style = self.default_parameter_style
|
|
518
|
+
|
|
519
|
+
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
520
|
+
if unsupported_styles:
|
|
521
|
+
target_style = self.default_parameter_style
|
|
522
|
+
elif detected_styles:
|
|
523
|
+
# Prefer the first supported style found
|
|
524
|
+
for style in detected_styles:
|
|
525
|
+
if style in self.supported_parameter_styles:
|
|
526
|
+
target_style = style
|
|
527
|
+
break
|
|
528
|
+
|
|
529
|
+
sql_str, params = sql.compile(placeholder_style=target_style)
|
|
530
|
+
processed_params = self._process_parameters(params) if params is not None else []
|
|
495
531
|
|
|
496
532
|
oracle_df = await conn.fetch_df_all(sql_str, processed_params)
|
|
497
533
|
from pyarrow.interchange.from_dataframe import from_dataframe
|
|
@@ -504,7 +540,8 @@ class OracleAsyncDriver(
|
|
|
504
540
|
self._ensure_pyarrow_installed()
|
|
505
541
|
conn = self._connection(None)
|
|
506
542
|
|
|
507
|
-
|
|
543
|
+
# Use proper transaction management like other methods
|
|
544
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn, self._get_cursor(txn_conn) as cursor:
|
|
508
545
|
if mode == "replace":
|
|
509
546
|
await cursor.execute(f"TRUNCATE TABLE {table_name}")
|
|
510
547
|
elif mode == "create":
|
|
@@ -522,60 +559,6 @@ class OracleAsyncDriver(
|
|
|
522
559
|
await cursor.executemany(sql, data_for_ingest)
|
|
523
560
|
return cursor.rowcount
|
|
524
561
|
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
result: SelectResultDict,
|
|
529
|
-
schema_type: Optional[type[ModelDTOT]] = None,
|
|
530
|
-
**kwargs: Any, # pyright: ignore[reportUnusedParameter]
|
|
531
|
-
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
532
|
-
fetched_tuples = result["data"]
|
|
533
|
-
column_names = result["column_names"]
|
|
534
|
-
|
|
535
|
-
if not fetched_tuples:
|
|
536
|
-
return SQLResult[RowT](statement=statement, data=[], column_names=column_names, operation_type="SELECT")
|
|
537
|
-
|
|
538
|
-
rows_as_dicts: list[dict[str, Any]] = [dict(zip(column_names, row_tuple)) for row_tuple in fetched_tuples]
|
|
539
|
-
|
|
540
|
-
if schema_type:
|
|
541
|
-
converted_data = self.to_schema(rows_as_dicts, schema_type=schema_type)
|
|
542
|
-
return SQLResult[ModelDTOT](
|
|
543
|
-
statement=statement, data=list(converted_data), column_names=column_names, operation_type="SELECT"
|
|
544
|
-
)
|
|
545
|
-
return SQLResult[RowT](
|
|
546
|
-
statement=statement, data=rows_as_dicts, column_names=column_names, operation_type="SELECT"
|
|
547
|
-
)
|
|
548
|
-
|
|
549
|
-
async def _wrap_execute_result(
|
|
550
|
-
self,
|
|
551
|
-
statement: SQL,
|
|
552
|
-
result: Union[DMLResultDict, ScriptResultDict],
|
|
553
|
-
**kwargs: Any, # pyright: ignore[reportUnusedParameter]
|
|
554
|
-
) -> SQLResult[RowT]:
|
|
555
|
-
operation_type = "UNKNOWN"
|
|
556
|
-
if statement.expression:
|
|
557
|
-
operation_type = str(statement.expression.key).upper()
|
|
558
|
-
|
|
559
|
-
if "statements_executed" in result:
|
|
560
|
-
script_result = cast("ScriptResultDict", result)
|
|
561
|
-
return SQLResult[RowT](
|
|
562
|
-
statement=statement,
|
|
563
|
-
data=[],
|
|
564
|
-
rows_affected=0,
|
|
565
|
-
operation_type="SCRIPT",
|
|
566
|
-
metadata={
|
|
567
|
-
"status_message": script_result.get("status_message", ""),
|
|
568
|
-
"statements_executed": script_result.get("statements_executed", -1),
|
|
569
|
-
},
|
|
570
|
-
)
|
|
571
|
-
|
|
572
|
-
dml_result = cast("DMLResultDict", result)
|
|
573
|
-
rows_affected = dml_result.get("rows_affected", -1)
|
|
574
|
-
status_message = dml_result.get("status_message", "")
|
|
575
|
-
return SQLResult[RowT](
|
|
576
|
-
statement=statement,
|
|
577
|
-
data=[],
|
|
578
|
-
rows_affected=rows_affected,
|
|
579
|
-
operation_type=operation_type,
|
|
580
|
-
metadata={"status_message": status_message},
|
|
581
|
-
)
|
|
562
|
+
def _connection(self, connection: Optional[OracleAsyncConnection] = None) -> OracleAsyncConnection:
|
|
563
|
+
"""Get the connection to use for the operation."""
|
|
564
|
+
return connection or self.connection
|