sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/mixins/_storage.py
@@ -9,8 +9,6 @@ and storage backend operations for optimal performance.
 """

 # pyright: reportCallIssue=false, reportAttributeAccessIssue=false, reportArgumentType=false
-import csv
-import json
 import logging
 import tempfile
 from abc import ABC
@@ -19,58 +17,37 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
 from urllib.parse import urlparse

+from sqlspec.driver.mixins._csv_writer import write_csv
+from sqlspec.driver.parameters import separate_filters_and_parameters
 from sqlspec.exceptions import MissingDependencyError
 from sqlspec.statement import SQL, ArrowResult, StatementFilter
-from sqlspec.statement.sql import SQLConfig
 from sqlspec.storage import storage_registry
 from sqlspec.typing import ArrowTable, RowT, StatementParameters
+from sqlspec.utils.serializers import to_json
 from sqlspec.utils.sync_tools import async_

 if TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType

+    from sqlspec.protocols import ObjectStoreProtocol
     from sqlspec.statement import SQLResult, Statement
-    from sqlspec.
+    from sqlspec.statement.sql import SQLConfig
     from sqlspec.typing import ConnectionT

 __all__ = ("AsyncStorageMixin", "SyncStorageMixin")

 logger = logging.getLogger(__name__)

-# Constants
 WINDOWS_PATH_MIN_LENGTH = 3


-def _separate_filters_from_parameters(
-    parameters: "tuple[Any, ...]",
-) -> "tuple[list[StatementFilter], Optional[StatementParameters]]":
-    """Separate filters from parameters in positional args."""
-    filters: list[StatementFilter] = []
-    params: list[Any] = []
-
-    for arg in parameters:
-        if isinstance(arg, StatementFilter):
-            filters.append(arg)
-        else:
-            # Everything else is treated as parameters
-            params.append(arg)
-
-    # Convert to appropriate parameter format
-    if len(params) == 0:
-        return filters, None
-    if len(params) == 1:
-        return filters, params[0]
-    return filters, params
-
-
 class StorageMixinBase(ABC):
     """Base class with common storage functionality."""

     __slots__ = ()

-
-
-    _connection: Any  # Database connection
+    config: Any
+    _connection: Any
     dialect: "DialectType"
     supports_native_parquet_export: "ClassVar[bool]"
     supports_native_parquet_import: "ClassVar[bool]"
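Note: the module-local _separate_filters_from_parameters helper removed above is replaced by the shared separate_filters_and_parameters from the new sqlspec/driver/parameters.py (+138). That module's body is not shown here; a minimal sketch of the separation semantics the removed code implemented, assuming the shared helper keeps the same contract:

from typing import Any, Optional

class StatementFilter:  # stand-in for sqlspec.statement.StatementFilter
    pass

def separate(args: "tuple[Any, ...]") -> "tuple[list[StatementFilter], Optional[Any]]":
    filters = [a for a in args if isinstance(a, StatementFilter)]
    params = [a for a in args if not isinstance(a, StatementFilter)]
    if not params:
        return filters, None       # nothing to bind
    if len(params) == 1:
        return filters, params[0]  # a single value passes through unchanged
    return filters, params         # several values stay a list

assert separate((1, 2))[1] == [1, 2]
assert separate(("only",))[1] == "only"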
@@ -87,7 +64,6 @@ class StorageMixinBase(ABC):
     @staticmethod
     def _get_storage_backend(uri_or_key: "Union[str, Path]") -> "ObjectStoreProtocol":
         """Get storage backend by URI or key with intelligent routing."""
-        # Pass Path objects directly to storage registry for proper URI conversion
         if isinstance(uri_or_key, Path):
             return storage_registry.get(uri_or_key)
         return storage_registry.get(str(uri_or_key))
@@ -134,18 +110,14 @@ class StorageMixinBase(ABC):
         Returns:
             Tuple of (backend, path) where path is relative to the backend's base path
         """
-        # Convert Path objects to string
         uri_str = str(uri)
         original_path = uri_str

-        # Convert absolute paths to file:// URIs if needed
         if self._is_uri(uri_str) and "://" not in uri_str:
-            # It's an absolute path without scheme
             uri_str = f"file://{uri_str}"

         backend = self._get_storage_backend(uri_str)

-        # For file:// URIs, return just the path part for the backend
         path = uri_str[7:] if uri_str.startswith("file://") else original_path

         return backend, path
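The uri_str[7:] slice above drops exactly len("file://") == 7 characters; a self-documenting equivalent, shown only for illustration:

uri = "file:///tmp/exports/out.parquet"
path = uri.removeprefix("file://")  # -> "/tmp/exports/out.parquet"
assert path == uri[7:]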
@@ -156,12 +128,9 @@ class StorageMixinBase(ABC):
         import pyarrow as pa

         if not rows:
-            # Empty table with column names
-            # Create empty arrays for each column
             empty_data: dict[str, list[Any]] = {col: [] for col in columns}
             return pa.table(empty_data)

-        # Convert rows to columnar format
         if isinstance(rows[0], dict):
             # Dict rows
             data = {col: [cast("dict[str, Any]", row).get(col) for row in rows] for col in columns}
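For reference, the columnar pivot _rows_to_arrow_table performs maps row dicts into per-column lists, which pyarrow.table accepts directly; a minimal standalone example:

import pyarrow as pa

rows = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
columns = ["id", "name"]
data = {col: [row.get(col) for row in rows] for col in columns}  # rows -> columns
table = pa.table(data)
assert table.num_rows == 2 and table.column_names == columns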
@@ -232,7 +201,7 @@ class SyncStorageMixin(StorageMixinBase):
         """
         self._ensure_pyarrow_installed()

-        filters, params = _separate_filters_from_parameters(parameters)
+        filters, params = separate_filters_and_parameters(parameters)
         # Convert to SQL object for processing
         # Use a custom config if transformations will add parameters
         if _config is None:
@@ -246,9 +215,9 @@ class SyncStorageMixin(StorageMixinBase):

         # Only pass params if it's not None to avoid adding None as a parameter
         if params is not None:
-            sql = SQL(statement, params, *filters,
+            sql = SQL(statement, params, *filters, config=_config, **kwargs)
         else:
-            sql = SQL(statement, *filters,
+            sql = SQL(statement, *filters, config=_config, **kwargs)

         return self._fetch_arrow_table(sql, connection=_connection, **kwargs)

@@ -266,11 +235,9 @@ class SyncStorageMixin(StorageMixinBase):
         Returns:
             ArrowResult with converted data
         """
-        # Check if this SQL object has validation issues due to transformer-generated parameters
         try:
             result = cast("SQLResult", self.execute(sql, _connection=connection))  # type: ignore[attr-defined]
         except Exception:
-            # Get the compiled SQL and parameters
             compiled_sql, compiled_params = sql.compile("qmark")

             # Execute directly via the driver's _execute method
@@ -320,21 +287,21 @@ class SyncStorageMixin(StorageMixinBase):
         Returns:
             Number of rows exported
         """
-
-        filters, params = _separate_filters_from_parameters(parameters)
+        filters, params = separate_filters_and_parameters(parameters)

         # For storage operations, disable transformations that might add unwanted parameters
         if _config is None:
             _config = self.config
+        if _config and not _config.dialect:
+            _config = replace(_config, dialect=self.dialect)
         if _config and _config.enable_transformations:
-            from dataclasses import replace
-
             _config = replace(_config, enable_transformations=False)

-
-
-
-
+        sql = (
+            SQL(statement, parameters=params, config=_config) if params is not None else SQL(statement, config=_config)
+        )
+        for filter_ in filters:
+            sql = sql.filter(filter_)

         return self._export_to_storage(
             sql, destination_uri=destination_uri, format=format, _connection=_connection, **options
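The config handling above leans on dataclasses.replace to derive modified copies rather than mutating a shared SQLConfig; a sketch with a stand-in config class (field names beyond those visible in this diff are assumptions):

from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class DemoConfig:  # stand-in; the real SQLConfig has more fields
    dialect: Optional[str] = None
    enable_transformations: bool = True

cfg = DemoConfig()
cfg = replace(cfg, dialect="duckdb")              # fill in the driver dialect
cfg = replace(cfg, enable_transformations=False)  # storage paths skip transforms
assert cfg == DemoConfig(dialect="duckdb", enable_transformations=False)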
@@ -342,37 +309,22 @@ class SyncStorageMixin(StorageMixinBase):

     def _export_to_storage(
         self,
-
-        /,
-        *parameters: "Union[StatementParameters, StatementFilter]",
+        sql: "SQL",
         destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
-        _config: "Optional[SQLConfig]" = None,
         **kwargs: Any,
     ) -> int:
-
-        if hasattr(statement, "to_sql"):  # SQL object
-            query_str = cast("SQL", statement).to_sql()
-        elif isinstance(statement, str):
-            query_str = statement
-        else:  # sqlglot Expression
-            query_str = str(statement)
-
-        # Auto-detect format if not provided
-        # If no format is specified and detection fails (returns "csv" as default),
-        # default to "parquet" for export operations as it's the most common use case
+        """Protected method for sync export operation implementation."""
         detected_format = self._detect_format(destination_uri)
         if format:
             file_format = format
         elif detected_format == "csv" and not str(destination_uri).endswith((".csv", ".tsv", ".txt")):
             # Detection returned default "csv" but file doesn't actually have CSV extension
-            # Default to parquet for better compatibility with tests and common usage
             file_format = "parquet"
         else:
             file_format = detected_format

-        # Special handling for parquet format - if we're exporting to parquet but the
         # destination doesn't have .parquet extension, add it to ensure compatibility
         # with pyarrow.parquet.read_table() which requires the extension
         if file_format == "parquet" and not str(destination_uri).endswith(".parquet"):
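The format negotiation above, condensed into a standalone function; this assumes _detect_format returns "csv" as its fallback when the URI gives no hint, as the inline comments state:

def choose_format(destination_uri: str, explicit: "str | None", detected: str) -> str:
    if explicit:
        return explicit
    if detected == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
        return "parquet"  # detection fell back to its default; prefer parquet
    return detected

assert choose_format("out.data", None, "csv") == "parquet"
assert choose_format("out.tsv", None, "csv") == "csv"
assert choose_format("out.data", "csv", "csv") == "csv"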
@@ -383,38 +335,22 @@ class SyncStorageMixin(StorageMixinBase):

         # Try native database export first
         if file_format == "parquet" and self.supports_native_parquet_export:
-
-
-
-
-
-
-            except NotImplementedError:
-                # Fall through to use storage backend
-                pass
+            try:
+                compiled_sql, _ = sql.compile(placeholder_style="static")
+                return self._export_native(compiled_sql, destination_uri, file_format, **kwargs)
+            except NotImplementedError:
+                # Fall through to use storage backend
+                pass

         if file_format == "parquet":
-            # Use Arrow for efficient transfer
-
-            # For parquet export via Arrow, just use the SQL object directly
-            sql_obj = cast("SQL", statement)
-            # Pass connection parameter correctly
-            arrow_result = self._fetch_arrow_table(sql_obj, connection=_connection, **kwargs)
-        else:
-            # Create SQL object if it's still a string
-            arrow_result = self.fetch_arrow_table(statement, *parameters, _connection=_connection, _config=_config)
-
-        # ArrowResult.data is never None according to the type definition
+            # Use Arrow for efficient transfer
+            arrow_result = self._fetch_arrow_table(sql, connection=_connection, **kwargs)
             arrow_table = arrow_result.data
             num_rows = arrow_table.num_rows
             backend.write_arrow(path, arrow_table, **kwargs)
             return num_rows
-
-
-            sql_obj = SQL(statement, _config=_config, _dialect=self.dialect)
-        else:
-            sql_obj = cast("SQL", statement)
-        return self._export_via_backend(sql_obj, backend, path, file_format, **kwargs)
+
+        return self._export_via_backend(sql, backend, path, file_format, **kwargs)

     def import_from_storage(
         self,
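compile(placeholder_style="static") inlines bound values as literals before the query reaches _export_native. That matters because native export paths typically wrap the query in a statement such as DuckDB's COPY, which cannot carry bound parameters; a purely illustrative sketch, not the sqlspec API:

query = "SELECT * FROM orders WHERE region = ?"
inlined = query.replace("?", "'emea'")  # what a "static" compile produces, with proper quoting
copy_stmt = f"COPY ({inlined}) TO 'orders.parquet' (FORMAT PARQUET)"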
@@ -550,10 +486,8 @@ class SyncStorageMixin(StorageMixinBase):
             backend.write_arrow(path, arrow_table, **options)
             return len(result.data or [])

-        # Convert to appropriate format and write to backend
         compression = options.get("compression")

-        # Create temp file with appropriate suffix
         suffix = f".{format}"
         if compression == "gzip":
             suffix += ".gz"
@@ -561,7 +495,6 @@ class SyncStorageMixin(StorageMixinBase):
         with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding="utf-8") as tmp:
             tmp_path = Path(tmp.name)

-        # Handle compression and writing
         if compression == "gzip":
             import gzip

@@ -615,41 +548,24 @@ class SyncStorageMixin(StorageMixinBase):
     @staticmethod
     def _write_csv(result: "SQLResult", file: Any, **options: Any) -> None:
         """Write result to CSV file."""
-
-        csv_options = options.copy()
-        csv_options.pop("compression", None)  # Handle compression separately
-        csv_options.pop("partition_by", None)  # Not applicable to CSV
-
-        writer = csv.writer(file, **csv_options)  # TODO: anything better?
-        if result.column_names:
-            writer.writerow(result.column_names)
-        if result.data:
-            # Handle dict rows by extracting values in column order
-            if result.data and isinstance(result.data[0], dict):
-                rows = []
-                for row_dict in result.data:
-                    # Extract values in the same order as column_names
-                    row_values = [row_dict.get(col) for col in result.column_names or []]
-                    rows.append(row_values)
-                writer.writerows(rows)
-            else:
-                writer.writerows(result.data)
+        write_csv(result, file, **options)

     @staticmethod
     def _write_json(result: "SQLResult", file: Any, **options: Any) -> None:
         """Write result to JSON file."""
+        _ = options

         if result.data and result.column_names:
-            # Check if data is already in dict format
             if result.data and isinstance(result.data[0], dict):
                 # Data is already dictionaries, use as-is
                 rows = result.data
             else:
-                # Convert tuples/lists to list of dicts
                 rows = [dict(zip(result.column_names, row)) for row in result.data]
-
+            json_str = to_json(rows)
+            file.write(json_str)
         else:
-
+            json_str = to_json([])
+            file.write(json_str)

     def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
         """Database-specific bulk load implementation. Override in drivers."""
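The tuple-to-dict pivot retained in _write_json relies on zip pairing each column name with its positional value:

columns = ["id", "name"]
data = [(1, "a"), (2, "b")]
rows = [dict(zip(columns, row)) for row in data]  # positional rows -> keyed rows
assert rows == [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]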
@@ -724,7 +640,7 @@ class AsyncStorageMixin(StorageMixinBase):
         """
         self._ensure_pyarrow_installed()

-        filters, params = _separate_filters_from_parameters(parameters)
+        filters, params = separate_filters_and_parameters(parameters)
         # Convert to SQL object for processing
         # Use a custom config if transformations will add parameters
         if _config is None:
@@ -733,18 +649,15 @@ class AsyncStorageMixin(StorageMixinBase):
         # If no parameters provided but we have transformations enabled,
         # disable parameter validation entirely to allow transformer-added parameters
         if params is None and _config and _config.enable_transformations:
-            from dataclasses import replace
-
             # Disable validation entirely for transformer-generated parameters
             _config = replace(_config, strict_mode=False, enable_validation=False)

         # Only pass params if it's not None to avoid adding None as a parameter
         if params is not None:
-            sql = SQL(statement, params, *filters,
+            sql = SQL(statement, params, *filters, config=_config, **kwargs)
         else:
-            sql = SQL(statement, *filters,
+            sql = SQL(statement, *filters, config=_config, **kwargs)

-        # Delegate to protected method that drivers can override
         return await self._fetch_arrow_table(sql, connection=_connection, **kwargs)

     async def _fetch_arrow_table(
@@ -768,7 +681,6 @@ class AsyncStorageMixin(StorageMixinBase):
         # Execute regular query
         result = await self.execute(sql, _connection=connection)  # type: ignore[attr-defined]

-        # Convert to Arrow table
         arrow_table = self._rows_to_arrow_table(result.data or [], result.column_names or [])

         return ArrowResult(statement=sql, data=arrow_table)
@@ -782,25 +694,25 @@ class AsyncStorageMixin(StorageMixinBase):
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
         _config: "Optional[SQLConfig]" = None,
-        **
+        **kwargs: Any,
     ) -> int:
-
-        filters, params = _separate_filters_from_parameters(parameters)
+        filters, params = separate_filters_and_parameters(parameters)

         # For storage operations, disable transformations that might add unwanted parameters
         if _config is None:
             _config = self.config
+        if _config and not _config.dialect:
+            _config = replace(_config, dialect=self.dialect)
         if _config and _config.enable_transformations:
-            from dataclasses import replace
-
             _config = replace(_config, enable_transformations=False)

-
-
-
-
+        sql = (
+            SQL(statement, parameters=params, config=_config) if params is not None else SQL(statement, config=_config)
+        )
+        for filter_ in filters:
+            sql = sql.filter(filter_)

-        return await self._export_to_storage(sql, destination_uri, format, connection=_connection, **
+        return await self._export_to_storage(sql, destination_uri, format, connection=_connection, **kwargs)

     async def _export_to_storage(
         self,
@@ -808,7 +720,7 @@ class AsyncStorageMixin(StorageMixinBase):
         destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         connection: "Optional[ConnectionT]" = None,
-        **
+        **kwargs: Any,
     ) -> int:
         """Protected async method for export operation implementation.

@@ -817,25 +729,21 @@ class AsyncStorageMixin(StorageMixinBase):
             destination_uri: URI to export data to
             format: Optional format override (auto-detected from URI if not provided)
             connection: Optional connection override
-            **
+            **kwargs: Additional export options

         Returns:
             Number of rows exported
         """
         # Auto-detect format if not provided
-        # If no format is specified and detection fails (returns "csv" as default),
-        # default to "parquet" for export operations as it's the most common use case
         detected_format = self._detect_format(destination_uri)
         if format:
             file_format = format
         elif detected_format == "csv" and not str(destination_uri).endswith((".csv", ".tsv", ".txt")):
             # Detection returned default "csv" but file doesn't actually have CSV extension
-            # Default to parquet for better compatibility with tests and common usage
             file_format = "parquet"
         else:
             file_format = detected_format

-        # Special handling for parquet format - if we're exporting to parquet but the
         # destination doesn't have .parquet extension, add it to ensure compatibility
         # with pyarrow.parquet.read_table() which requires the extension
         if file_format == "parquet" and not str(destination_uri).endswith(".parquet"):
@@ -846,31 +754,23 @@ class AsyncStorageMixin(StorageMixinBase):

         # Try native database export first
         if file_format == "parquet" and self.supports_native_parquet_export:
-
+            try:
+                compiled_sql, _ = query.compile(placeholder_style="static")
+                return await self._export_native(compiled_sql, destination_uri, file_format, **kwargs)
+            except NotImplementedError:
+                # Fall through to use storage backend
+                pass

         if file_format == "parquet":
-            #
-
-            if hasattr(query, "parameters") and query.parameters and hasattr(query, "_raw_sql"):
-                # Create fresh SQL object from raw SQL without transformations
-                fresh_sql = SQL(
-                    query._raw_sql,
-                    _config=replace(self.config, enable_transformations=False)
-                    if self.config
-                    else SQLConfig(enable_transformations=False),
-                    _dialect=self.dialect,
-                )
-                arrow_result = await self._fetch_arrow_table(fresh_sql, connection=connection, **options)
-            else:
-                # query is already a SQL object, call _fetch_arrow_table directly
-                arrow_result = await self._fetch_arrow_table(query, connection=connection, **options)
+            # Use Arrow for efficient transfer
+            arrow_result = await self._fetch_arrow_table(query, connection=connection, **kwargs)
             arrow_table = arrow_result.data
             if arrow_table is not None:
-                await backend.write_arrow_async(path, arrow_table, **
+                await backend.write_arrow_async(path, arrow_table, **kwargs)
                 return arrow_table.num_rows
             return 0

-        return await self._export_via_backend(query, backend, path, file_format, **
+        return await self._export_via_backend(query, backend, path, file_format, **kwargs)

     async def import_from_storage(
         self,
@@ -964,7 +864,6 @@ class AsyncStorageMixin(StorageMixinBase):
             await backend.write_arrow_async(path, arrow_table, **options)
             return len(result.data or [])

-        # Convert to appropriate format and write to backend
         with tempfile.NamedTemporaryFile(mode="w", suffix=f".{format}", delete=False, encoding="utf-8") as tmp:
             if format == "csv":
                 self._write_csv(result, tmp, **options)
@@ -1002,37 +901,24 @@ class AsyncStorageMixin(StorageMixinBase):
     @staticmethod
     def _write_csv(result: "SQLResult", file: Any, **options: Any) -> None:
         """Reuse sync implementation."""
-
-        writer = csv.writer(file, **options)
-        if result.column_names:
-            writer.writerow(result.column_names)
-        if result.data:
-            # Handle dict rows by extracting values in column order
-            if result.data and isinstance(result.data[0], dict):
-                rows = []
-                for row_dict in result.data:
-                    # Extract values in the same order as column_names
-                    row_values = [row_dict.get(col) for col in result.column_names or []]
-                    rows.append(row_values)
-                writer.writerows(rows)
-            else:
-                writer.writerows(result.data)
+        write_csv(result, file, **options)

     @staticmethod
     def _write_json(result: "SQLResult", file: Any, **options: Any) -> None:
         """Reuse sync implementation."""
+        _ = options  # May be used in the future for JSON formatting options

         if result.data and result.column_names:
-            # Check if data is already in dict format
             if result.data and isinstance(result.data[0], dict):
                 # Data is already dictionaries, use as-is
                 rows = result.data
             else:
-                # Convert tuples/lists to list of dicts
                 rows = [dict(zip(result.column_names, row)) for row in result.data]
-
+            json_str = to_json(rows)
+            file.write(json_str)
         else:
-
+            json_str = to_json([])
+            file.write(json_str)

     async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
         """Async database-specific bulk load implementation."""
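Both mixins now delegate CSV writing to the new sqlspec/driver/mixins/_csv_writer.py (+91). Its contents are not shown in this diff; judging from the removed inline code, it must cover roughly the following (signature and option handling here are assumptions):

import csv
from typing import Any

def write_csv_sketch(column_names: "list[str]", data: "list[Any]", file: Any, **options: Any) -> None:
    options.pop("compression", None)   # handled by the caller's temp-file logic
    options.pop("partition_by", None)  # not applicable to CSV
    writer = csv.writer(file, **options)
    writer.writerow(column_names)
    if data and isinstance(data[0], dict):
        # dict rows: emit values in column order
        writer.writerows([[row.get(col) for col in column_names] for row in data])
    else:
        writer.writerows(data)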
sqlspec/driver/mixins/_type_coercion.py
@@ -7,6 +7,8 @@ TypedParameter objects and perform appropriate type conversions.
 from decimal import Decimal
 from typing import TYPE_CHECKING, Any, Optional, Union

+from sqlspec.utils.type_guards import has_parameter_value
+
 if TYPE_CHECKING:
     from sqlspec.typing import SQLParameterType

@@ -68,13 +70,10 @@ class TypeCoercionMixin:
         Returns:
             Coerced parameter value suitable for the database
         """
-
-        if hasattr(param, "__class__") and param.__class__.__name__ == "TypedParameter":
-            # Extract value and type hint
+        if has_parameter_value(param):
            value = param.value
            type_hint = param.type_hint

-            # Apply driver-specific coercion based on type hint
            return self._apply_type_coercion(value, type_hint)
        # Regular parameter - apply default coercion
        return self._apply_type_coercion(param, None)