sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Optimized CSV writing utilities."""
|
|
2
|
+
|
|
3
|
+
import csv
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from sqlspec.typing import PYARROW_INSTALLED
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from sqlspec.statement.result import SQLResult
|
|
10
|
+
|
|
11
|
+
# Names exported by a star-import of this module.
__all__ = ("write_csv", "write_csv_default", "write_csv_optimized")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _raise_no_column_names_error() -> None:
|
|
15
|
+
"""Raise error when no column names are available."""
|
|
16
|
+
msg = "No column names available"
|
|
17
|
+
raise ValueError(msg)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def write_csv(result: "SQLResult", file: Any, **options: Any) -> None:
    """Write a SQL result to ``file`` as CSV.

    When PyArrow is installed the PyArrow-backed writer is tried first; if it
    raises for any reason, the standard-library writer is used instead. Without
    PyArrow the standard-library writer is used directly.

    Args:
        result: SQL result to write
        file: File-like object to write to
        **options: CSV writer options
    """
    if not PYARROW_INSTALLED:
        write_csv_default(result, file, **options)
        return

    try:
        write_csv_optimized(result, file, **options)
    except Exception:
        # Best-effort fast path: any failure falls back to the stdlib writer.
        write_csv_default(result, file, **options)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def write_csv_default(result: "SQLResult", file: Any, **options: Any) -> None:
    """Write a result to a CSV file using the standard-library ``csv`` writer.

    A header row is written whenever column names are known. Rows whose first
    element is a ``dict`` are treated as mappings: values are looked up by
    column name and missing keys become empty fields. Any other rows are
    written as-is.

    Args:
        result: SQL result to write
        file: File-like object to write to
        **options: CSV writer options. ``compression`` and ``partition_by``
            are dropped because ``csv.writer`` does not accept them.
    """
    ignored_options = {"compression", "partition_by"}
    csv_options = {key: value for key, value in options.items() if key not in ignored_options}
    writer = csv.writer(file, **csv_options)

    rows = result.data
    has_dict_rows = bool(rows) and isinstance(rows[0], dict)

    column_names = list(result.column_names or [])
    if not column_names and has_dict_rows:
        # Without explicit column names every dict row would collapse to an
        # empty line; fall back to the keys of the first row instead.
        column_names = list(rows[0])

    if column_names:
        writer.writerow(column_names)
    if not rows:
        return

    if has_dict_rows:
        # Stream rows to the writer instead of materializing a second list.
        writer.writerows([row.get(name) for name in column_names] for row in rows)
    else:
        writer.writerows(rows)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def write_csv_optimized(result: "SQLResult", file: Any, **options: Any) -> None:
    """Write a result to CSV through PyArrow.

    The rows are converted to a PyArrow table and written to the path given by
    ``file.name``. Nothing is written when the result has no data.

    Args:
        result: SQL result to write
        file: File-like object to write to; must expose a ``name`` attribute
        **options: CSV writer options (accepted but not used by this writer)

    Raises:
        ValueError: If ``file`` has no ``name`` attribute, or if the rows are
            not dicts and the result has no column names.
    """
    del options  # accepted for signature parity with the other writers
    import pyarrow as pa
    import pyarrow.csv as pa_csv

    rows = result.data
    if not rows:
        return

    if not hasattr(file, "name"):
        msg = "PyArrow CSV writer requires a file with a 'name' attribute"
        raise ValueError(msg)

    table: Any
    if isinstance(rows[0], dict):
        table = pa.Table.from_pylist(rows)
    elif result.column_names:
        names = result.column_names
        # Pair each positional row with the column names so PyArrow can build
        # the table from mappings.
        table = pa.Table.from_pylist([dict(zip(names, row)) for row in rows])
    else:
        _raise_no_column_names_error()

    pa_csv.write_csv(table, file.name)  # pyright: ignore
|
|
@@ -16,6 +16,12 @@ from sqlspec.statement.filters import StatementFilter
|
|
|
16
16
|
from sqlspec.statement.result import SQLResult
|
|
17
17
|
from sqlspec.statement.sql import SQL
|
|
18
18
|
from sqlspec.utils.logging import get_logger
|
|
19
|
+
from sqlspec.utils.type_guards import (
|
|
20
|
+
is_async_pipeline_capable_driver,
|
|
21
|
+
is_async_transaction_state_capable,
|
|
22
|
+
is_sync_pipeline_capable_driver,
|
|
23
|
+
is_sync_transaction_state_capable,
|
|
24
|
+
)
|
|
19
25
|
|
|
20
26
|
if TYPE_CHECKING:
|
|
21
27
|
from typing import Literal
|
|
@@ -136,11 +142,11 @@ class Pipeline:
|
|
|
136
142
|
"""
|
|
137
143
|
self._operations.append(
|
|
138
144
|
PipelineOperation(
|
|
139
|
-
sql=SQL(statement,
|
|
145
|
+
sql=SQL(statement, parameters=parameters or None, config=self.driver.config, **kwargs),
|
|
146
|
+
operation_type="execute",
|
|
140
147
|
)
|
|
141
148
|
)
|
|
142
149
|
|
|
143
|
-
# Check for auto-flush
|
|
144
150
|
if len(self._operations) >= self.max_operations:
|
|
145
151
|
logger.warning("Pipeline auto-flushing at %s operations", len(self._operations))
|
|
146
152
|
self.process()
|
|
@@ -153,7 +159,8 @@ class Pipeline:
|
|
|
153
159
|
"""Add a select operation to the pipeline."""
|
|
154
160
|
self._operations.append(
|
|
155
161
|
PipelineOperation(
|
|
156
|
-
sql=SQL(statement,
|
|
162
|
+
sql=SQL(statement, parameters=parameters or None, config=self.driver.config, **kwargs),
|
|
163
|
+
operation_type="select",
|
|
157
164
|
)
|
|
158
165
|
)
|
|
159
166
|
return self
|
|
@@ -175,11 +182,11 @@ class Pipeline:
|
|
|
175
182
|
raise ValueError(msg)
|
|
176
183
|
|
|
177
184
|
batch_params = parameters[0]
|
|
178
|
-
# Convert tuple to list if needed
|
|
179
185
|
if isinstance(batch_params, tuple):
|
|
180
186
|
batch_params = list(batch_params)
|
|
181
|
-
|
|
182
|
-
|
|
187
|
+
sql_obj = SQL(
|
|
188
|
+
statement, parameters=parameters[1:] if len(parameters) > 1 else None, config=self.driver.config, **kwargs
|
|
189
|
+
).as_many(batch_params)
|
|
183
190
|
|
|
184
191
|
self._operations.append(PipelineOperation(sql=sql_obj, operation_type="execute_many"))
|
|
185
192
|
return self
|
|
@@ -189,7 +196,7 @@ class Pipeline:
|
|
|
189
196
|
if isinstance(script, SQL):
|
|
190
197
|
sql_obj = script.as_script()
|
|
191
198
|
else:
|
|
192
|
-
sql_obj = SQL(script,
|
|
199
|
+
sql_obj = SQL(script, parameters=filters or None, config=self.driver.config, **kwargs).as_script()
|
|
193
200
|
|
|
194
201
|
self._operations.append(PipelineOperation(sql=sql_obj, operation_type="execute_script"))
|
|
195
202
|
return self
|
|
@@ -206,13 +213,11 @@ class Pipeline:
|
|
|
206
213
|
if not self._operations:
|
|
207
214
|
return []
|
|
208
215
|
|
|
209
|
-
# Apply global filters
|
|
210
216
|
if filters:
|
|
211
217
|
self._apply_global_filters(filters)
|
|
212
218
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
results = self.driver._execute_pipeline_native(self._operations, **self.options) # pyright: ignore
|
|
219
|
+
if is_sync_pipeline_capable_driver(self.driver):
|
|
220
|
+
results = self.driver._execute_pipeline_native(self._operations, **self.options)
|
|
216
221
|
else:
|
|
217
222
|
results = self._execute_pipeline_simulated()
|
|
218
223
|
|
|
@@ -226,7 +231,6 @@ class Pipeline:
|
|
|
226
231
|
connection = None
|
|
227
232
|
auto_transaction = False
|
|
228
233
|
|
|
229
|
-
# Only log once per pipeline, not for each operation
|
|
230
234
|
if not self._simulation_logged:
|
|
231
235
|
logger.info(
|
|
232
236
|
"%s using simulated pipeline. Native support: %s",
|
|
@@ -236,29 +240,21 @@ class Pipeline:
|
|
|
236
240
|
self._simulation_logged = True
|
|
237
241
|
|
|
238
242
|
try:
|
|
239
|
-
# Get a connection for the entire pipeline
|
|
240
243
|
connection = self.driver._connection()
|
|
241
244
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
# Set isolation level if specified
|
|
245
|
-
pass # Driver-specific implementation
|
|
246
|
-
|
|
247
|
-
if hasattr(connection, "in_transaction") and not connection.in_transaction():
|
|
248
|
-
if hasattr(connection, "begin"):
|
|
249
|
-
connection.begin()
|
|
245
|
+
if is_sync_transaction_state_capable(connection) and not connection.in_transaction():
|
|
246
|
+
connection.begin()
|
|
250
247
|
auto_transaction = True
|
|
251
248
|
|
|
252
|
-
# Process each operation
|
|
253
249
|
for i, op in enumerate(self._operations):
|
|
254
250
|
self._execute_single_operation(i, op, results, connection, auto_transaction)
|
|
255
251
|
|
|
256
252
|
# Commit if we started the transaction
|
|
257
|
-
if auto_transaction and
|
|
253
|
+
if auto_transaction and is_sync_transaction_state_capable(connection):
|
|
258
254
|
connection.commit()
|
|
259
255
|
|
|
260
256
|
except Exception as e:
|
|
261
|
-
if connection and auto_transaction and
|
|
257
|
+
if connection and auto_transaction and is_sync_transaction_state_capable(connection):
|
|
262
258
|
connection.rollback()
|
|
263
259
|
if not isinstance(e, PipelineExecutionError):
|
|
264
260
|
msg = f"Pipeline execution failed: {e}"
|
|
@@ -281,20 +277,18 @@ class Pipeline:
|
|
|
281
277
|
else:
|
|
282
278
|
result = cast("SQLResult[Any]", self.driver.execute(op.sql, _connection=connection))
|
|
283
279
|
|
|
284
|
-
# Add operation context to result
|
|
285
280
|
result.operation_index = i
|
|
286
281
|
result.pipeline_sql = op.sql
|
|
287
282
|
results.append(result)
|
|
288
283
|
|
|
289
284
|
except Exception as e:
|
|
290
285
|
if self.continue_on_error:
|
|
291
|
-
# Create error result
|
|
292
286
|
error_result = SQLResult(
|
|
293
287
|
statement=op.sql, data=[], error=e, operation_index=i, parameters=op.sql.parameters
|
|
294
288
|
)
|
|
295
289
|
results.append(error_result)
|
|
296
290
|
else:
|
|
297
|
-
if auto_transaction and
|
|
291
|
+
if auto_transaction and is_sync_transaction_state_capable(connection):
|
|
298
292
|
connection.rollback()
|
|
299
293
|
msg = f"Pipeline failed at operation {i}: {e}"
|
|
300
294
|
raise PipelineExecutionError(
|
|
@@ -304,7 +298,6 @@ class Pipeline:
|
|
|
304
298
|
def _apply_global_filters(self, filters: "list[StatementFilter]") -> None:
|
|
305
299
|
"""Apply global filters to all operations."""
|
|
306
300
|
for operation in self._operations:
|
|
307
|
-
# Add filters to each operation
|
|
308
301
|
if operation.filters is None:
|
|
309
302
|
operation.filters = []
|
|
310
303
|
operation.filters.extend(filters)
|
|
@@ -313,13 +306,12 @@ class Pipeline:
|
|
|
313
306
|
"""Apply filters to a SQL object."""
|
|
314
307
|
result = sql
|
|
315
308
|
for filter_obj in filters:
|
|
316
|
-
|
|
317
|
-
result = cast("Any", filter_obj).apply(result)
|
|
309
|
+
result = filter_obj.append_to_statement(result)
|
|
318
310
|
return result
|
|
319
311
|
|
|
320
312
|
def _has_native_support(self) -> bool:
|
|
321
313
|
"""Check if driver has native pipeline support."""
|
|
322
|
-
return
|
|
314
|
+
return is_sync_pipeline_capable_driver(self.driver)
|
|
323
315
|
|
|
324
316
|
def _process_parameters(self, params: tuple[Any, ...]) -> tuple["list[StatementFilter]", "Optional[Any]"]:
|
|
325
317
|
"""Extract filters and parameters from mixed args.
|
|
@@ -336,7 +328,6 @@ class Pipeline:
|
|
|
336
328
|
else:
|
|
337
329
|
parameters.append(param)
|
|
338
330
|
|
|
339
|
-
# Return parameters based on count
|
|
340
331
|
if not parameters:
|
|
341
332
|
return filters, None
|
|
342
333
|
if len(parameters) == 1:
|
|
@@ -375,11 +366,11 @@ class AsyncPipeline:
|
|
|
375
366
|
"""Add an execute operation to the async pipeline."""
|
|
376
367
|
self._operations.append(
|
|
377
368
|
PipelineOperation(
|
|
378
|
-
sql=SQL(statement,
|
|
369
|
+
sql=SQL(statement, parameters=parameters or None, config=self.driver.config, **kwargs),
|
|
370
|
+
operation_type="execute",
|
|
379
371
|
)
|
|
380
372
|
)
|
|
381
373
|
|
|
382
|
-
# Check for auto-flush
|
|
383
374
|
if len(self._operations) >= self.max_operations:
|
|
384
375
|
logger.warning("Async pipeline auto-flushing at %s operations", len(self._operations))
|
|
385
376
|
await self.process()
|
|
@@ -392,7 +383,8 @@ class AsyncPipeline:
|
|
|
392
383
|
"""Add a select operation to the async pipeline."""
|
|
393
384
|
self._operations.append(
|
|
394
385
|
PipelineOperation(
|
|
395
|
-
sql=SQL(statement,
|
|
386
|
+
sql=SQL(statement, parameters=parameters or None, config=self.driver.config, **kwargs),
|
|
387
|
+
operation_type="select",
|
|
396
388
|
)
|
|
397
389
|
)
|
|
398
390
|
return self
|
|
@@ -407,11 +399,11 @@ class AsyncPipeline:
|
|
|
407
399
|
raise ValueError(msg)
|
|
408
400
|
|
|
409
401
|
batch_params = parameters[0]
|
|
410
|
-
# Convert tuple to list if needed
|
|
411
402
|
if isinstance(batch_params, tuple):
|
|
412
403
|
batch_params = list(batch_params)
|
|
413
|
-
|
|
414
|
-
|
|
404
|
+
sql_obj = SQL(
|
|
405
|
+
statement, parameters=parameters[1:] if len(parameters) > 1 else None, config=self.driver.config, **kwargs
|
|
406
|
+
).as_many(batch_params)
|
|
415
407
|
|
|
416
408
|
self._operations.append(PipelineOperation(sql=sql_obj, operation_type="execute_many"))
|
|
417
409
|
return self
|
|
@@ -423,7 +415,7 @@ class AsyncPipeline:
|
|
|
423
415
|
if isinstance(script, SQL):
|
|
424
416
|
sql_obj = script.as_script()
|
|
425
417
|
else:
|
|
426
|
-
sql_obj = SQL(script,
|
|
418
|
+
sql_obj = SQL(script, parameters=filters or None, config=self.driver.config, **kwargs).as_script()
|
|
427
419
|
|
|
428
420
|
self._operations.append(PipelineOperation(sql=sql_obj, operation_type="execute_script"))
|
|
429
421
|
return self
|
|
@@ -433,8 +425,7 @@ class AsyncPipeline:
|
|
|
433
425
|
if not self._operations:
|
|
434
426
|
return []
|
|
435
427
|
|
|
436
|
-
|
|
437
|
-
if hasattr(self.driver, "_execute_pipeline_native"):
|
|
428
|
+
if is_async_pipeline_capable_driver(self.driver):
|
|
438
429
|
results = await cast("Any", self.driver)._execute_pipeline_native(self._operations, **self.options)
|
|
439
430
|
else:
|
|
440
431
|
results = await self._execute_pipeline_simulated()
|
|
@@ -460,20 +451,18 @@ class AsyncPipeline:
|
|
|
460
451
|
try:
|
|
461
452
|
connection = self.driver._connection()
|
|
462
453
|
|
|
463
|
-
if
|
|
464
|
-
|
|
465
|
-
await connection.begin()
|
|
454
|
+
if is_async_transaction_state_capable(connection) and not connection.in_transaction():
|
|
455
|
+
await connection.begin()
|
|
466
456
|
auto_transaction = True
|
|
467
457
|
|
|
468
|
-
# Process each operation
|
|
469
458
|
for i, op in enumerate(self._operations):
|
|
470
459
|
await self._execute_single_operation_async(i, op, results, connection, auto_transaction)
|
|
471
460
|
|
|
472
|
-
if auto_transaction and
|
|
461
|
+
if auto_transaction and is_async_transaction_state_capable(connection):
|
|
473
462
|
await connection.commit()
|
|
474
463
|
|
|
475
464
|
except Exception as e:
|
|
476
|
-
if connection and auto_transaction and
|
|
465
|
+
if connection and auto_transaction and is_async_transaction_state_capable(connection):
|
|
477
466
|
await connection.rollback()
|
|
478
467
|
if not isinstance(e, PipelineExecutionError):
|
|
479
468
|
msg = f"Async pipeline execution failed: {e}"
|
|
@@ -506,7 +495,7 @@ class AsyncPipeline:
|
|
|
506
495
|
)
|
|
507
496
|
results.append(error_result)
|
|
508
497
|
else:
|
|
509
|
-
if auto_transaction and
|
|
498
|
+
if auto_transaction and is_async_transaction_state_capable(connection):
|
|
510
499
|
await connection.rollback()
|
|
511
500
|
msg = f"Async pipeline failed at operation {i}: {e}"
|
|
512
501
|
raise PipelineExecutionError(
|
|
@@ -515,7 +504,7 @@ class AsyncPipeline:
|
|
|
515
504
|
|
|
516
505
|
def _has_native_support(self) -> bool:
|
|
517
506
|
"""Check if driver has native pipeline support."""
|
|
518
|
-
return
|
|
507
|
+
return is_async_pipeline_capable_driver(self.driver)
|
|
519
508
|
|
|
520
509
|
@property
|
|
521
510
|
def operations(self) -> "list[PipelineOperation]":
|
|
@@ -13,15 +13,9 @@ from typing import Any, Callable, Optional, Union, cast, overload
|
|
|
13
13
|
from uuid import UUID
|
|
14
14
|
|
|
15
15
|
from sqlspec.exceptions import SQLSpecError, wrap_exceptions
|
|
16
|
-
from sqlspec.
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
convert,
|
|
20
|
-
get_type_adapter,
|
|
21
|
-
is_dataclass,
|
|
22
|
-
is_msgspec_struct,
|
|
23
|
-
is_pydantic_model,
|
|
24
|
-
)
|
|
16
|
+
from sqlspec.statement.result import OperationType
|
|
17
|
+
from sqlspec.typing import ModelDTOT, ModelT, convert, get_type_adapter
|
|
18
|
+
from sqlspec.utils.type_guards import is_dataclass, is_msgspec_struct, is_pydantic_model
|
|
25
19
|
|
|
26
20
|
__all__ = ("_DEFAULT_TYPE_DECODERS", "ToSchemaMixin", "_default_msgspec_deserializer")
|
|
27
21
|
|
|
@@ -60,6 +54,30 @@ def _default_msgspec_deserializer(
|
|
|
60
54
|
class ToSchemaMixin:
|
|
61
55
|
__slots__ = ()
|
|
62
56
|
|
|
57
|
+
@staticmethod
def _determine_operation_type(statement: "Any") -> OperationType:
    """Classify a statement by the class name of its parsed expression.

    Args:
        statement: SQL statement object with expression attribute

    Returns:
        ``"INSERT"``, ``"UPDATE"``, ``"DELETE"`` or ``"SELECT"`` when that word
        appears in the upper-cased expression class name (checked in that
        order); ``"EXECUTE"`` when the statement has no expression, the
        expression is falsy, or none of the words match.
    """
    expression = getattr(statement, "expression", None)
    if not expression:
        return "EXECUTE"

    class_name = type(expression).__name__.upper()
    known_operations: tuple[OperationType, ...] = ("INSERT", "UPDATE", "DELETE", "SELECT")
    for operation in known_operations:
        if operation in class_name:
            return operation
    return "EXECUTE"
|
|
80
|
+
|
|
63
81
|
@overload
|
|
64
82
|
@staticmethod
|
|
65
83
|
def to_schema(data: "ModelT", *, schema_type: None = None) -> "ModelT": ...
|