sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/mixins/_storage.py

@@ -9,8 +9,6 @@ and storage backend operations for optimal performance.
 """

 # pyright: reportCallIssue=false, reportAttributeAccessIssue=false, reportArgumentType=false
-import csv
-import json
 import logging
 import tempfile
 from abc import ABC
@@ -19,58 +17,37 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
 from urllib.parse import urlparse

+from sqlspec.driver.mixins._csv_writer import write_csv
+from sqlspec.driver.parameters import separate_filters_and_parameters
 from sqlspec.exceptions import MissingDependencyError
 from sqlspec.statement import SQL, ArrowResult, StatementFilter
-from sqlspec.statement.sql import SQLConfig
 from sqlspec.storage import storage_registry
 from sqlspec.typing import ArrowTable, RowT, StatementParameters
+from sqlspec.utils.serializers import to_json
 from sqlspec.utils.sync_tools import async_

 if TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType

+    from sqlspec.protocols import ObjectStoreProtocol
     from sqlspec.statement import SQLResult, Statement
-    from sqlspec.
+    from sqlspec.statement.sql import SQLConfig
     from sqlspec.typing import ConnectionT

 __all__ = ("AsyncStorageMixin", "SyncStorageMixin")

 logger = logging.getLogger(__name__)

-# Constants
 WINDOWS_PATH_MIN_LENGTH = 3


-def _separate_filters_from_parameters(
-    parameters: "tuple[Any, ...]",
-) -> "tuple[list[StatementFilter], Optional[StatementParameters]]":
-    """Separate filters from parameters in positional args."""
-    filters: list[StatementFilter] = []
-    params: list[Any] = []
-
-    for arg in parameters:
-        if isinstance(arg, StatementFilter):
-            filters.append(arg)
-        else:
-            # Everything else is treated as parameters
-            params.append(arg)
-
-    # Convert to appropriate parameter format
-    if len(params) == 0:
-        return filters, None
-    if len(params) == 1:
-        return filters, params[0]
-    return filters, params
-
-
 class StorageMixinBase(ABC):
     """Base class with common storage functionality."""

     __slots__ = ()

-
-
-    _connection: Any  # Database connection
+    config: Any
+    _connection: Any
     dialect: "DialectType"
     supports_native_parquet_export: "ClassVar[bool]"
     supports_native_parquet_import: "ClassVar[bool]"
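The removed module-level helper above is superseded by the shared separate_filters_and_parameters imported from sqlspec.driver.parameters (a new file in this release, +138 lines). That module is not shown in this diff; the sketch below only restates the contract the deleted code documents, so any behavior beyond that is an assumption.

from sqlspec.statement import StatementFilter

def separate_filters_and_parameters(args):
    # Assumed contract, mirroring the deleted helper: pull out StatementFilter
    # instances and normalize the rest into None / a single value / a list.
    filters = [a for a in args if isinstance(a, StatementFilter)]
    params = [a for a in args if not isinstance(a, StatementFilter)]
    if not params:
        return filters, None
    return filters, params[0] if len(params) == 1 else params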
@@ -85,25 +62,29 @@ class StorageMixinBase(ABC):
             raise MissingDependencyError(msg)

     @staticmethod
-    def _get_storage_backend(uri_or_key: str) -> "ObjectStoreProtocol":
+    def _get_storage_backend(uri_or_key: "Union[str, Path]") -> "ObjectStoreProtocol":
         """Get storage backend by URI or key with intelligent routing."""
-
+        if isinstance(uri_or_key, Path):
+            return storage_registry.get(uri_or_key)
+        return storage_registry.get(str(uri_or_key))

     @staticmethod
-    def _is_uri(path_or_uri: str) -> bool:
+    def _is_uri(path_or_uri: "Union[str, Path]") -> bool:
         """Check if input is a URI rather than a relative path."""
+        path_str = str(path_or_uri)
         schemes = {"s3", "gs", "gcs", "az", "azure", "abfs", "abfss", "file", "http", "https"}
-        if "://" in
-            scheme =
+        if "://" in path_str:
+            scheme = path_str.split("://", maxsplit=1)[0].lower()
             return scheme in schemes
-        if len(
+        if len(path_str) >= WINDOWS_PATH_MIN_LENGTH and path_str[1:3] == ":\\":
             return True
-        return bool(
+        return bool(path_str.startswith("/"))

     @staticmethod
-    def _detect_format(uri: str) -> str:
+    def _detect_format(uri: "Union[str, Path]") -> str:
         """Detect file format from URI extension."""
-
+        uri_str = str(uri)
+        parsed = urlparse(uri_str)
         path = Path(parsed.path)
         extension = path.suffix.lower().lstrip(".")

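The URI helpers now stringify their input first, so str and Path arguments take the same branches. A minimal standalone sketch of the classification logic shown above (constants copied from the hunk; illustrative only, not sqlspec code):

from pathlib import Path

WINDOWS_PATH_MIN_LENGTH = 3
SCHEMES = {"s3", "gs", "gcs", "az", "azure", "abfs", "abfss", "file", "http", "https"}

def is_uri(path_or_uri) -> bool:
    # Mirrors the updated _is_uri logic from the hunk above.
    path_str = str(path_or_uri)
    if "://" in path_str:
        return path_str.split("://", maxsplit=1)[0].lower() in SCHEMES
    if len(path_str) >= WINDOWS_PATH_MIN_LENGTH and path_str[1:3] == ":\\":
        return True  # Windows drive path, e.g. C:\data
    return path_str.startswith("/")  # POSIX absolute path

assert is_uri("s3://bucket/data.parquet")
assert is_uri(Path("/tmp/out.parquet"))
assert not is_uri("relative/data.csv")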
@@ -120,28 +101,24 @@ class StorageMixinBase(ABC):

         return format_map.get(extension, "csv")

-    def _resolve_backend_and_path(self, uri: str) -> "tuple[ObjectStoreProtocol, str]":
+    def _resolve_backend_and_path(self, uri: "Union[str, Path]") -> "tuple[ObjectStoreProtocol, str]":
         """Resolve backend and path from URI with Phase 3 URI-first routing.

         Args:
-            uri: URI to resolve (e.g., "s3://bucket/path", "file:///local/path")
+            uri: URI to resolve (e.g., "s3://bucket/path", "file:///local/path", Path object)

         Returns:
             Tuple of (backend, path) where path is relative to the backend's base path
         """
-
-
-        original_path = uri
+        uri_str = str(uri)
+        original_path = uri_str

-
-
-        # It's an absolute path without scheme
-        uri = f"file://{uri}"
+        if self._is_uri(uri_str) and "://" not in uri_str:
+            uri_str = f"file://{uri_str}"

-        backend = self._get_storage_backend(
+        backend = self._get_storage_backend(uri_str)

-
-        path = uri[7:] if uri.startswith("file://") else original_path
+        path = uri_str[7:] if uri_str.startswith("file://") else original_path

         return backend, path

@@ -151,12 +128,9 @@ class StorageMixinBase(ABC):
         import pyarrow as pa

         if not rows:
-            # Empty table with column names
-            # Create empty arrays for each column
             empty_data: dict[str, list[Any]] = {col: [] for col in columns}
             return pa.table(empty_data)

-        # Convert rows to columnar format
         if isinstance(rows[0], dict):
             # Dict rows
             data = {col: [cast("dict[str, Any]", row).get(col) for row in rows] for col in columns}
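The dict-row branch above pivots row dictionaries into per-column lists before handing them to pyarrow. A runnable illustration of that conversion (toy data, not sqlspec API):

import pyarrow as pa

columns = ["id", "name"]
rows = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]

# Same columnar pivot the mixin performs for dict rows.
data = {col: [row.get(col) for row in rows] for col in columns}
table = pa.table(data)
assert table.num_rows == 2 and table.column_names == ["id", "name"]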
@@ -227,7 +201,7 @@ class SyncStorageMixin(StorageMixinBase):
         """
         self._ensure_pyarrow_installed()

-        filters, params =
+        filters, params = separate_filters_and_parameters(parameters)
         # Convert to SQL object for processing
         # Use a custom config if transformations will add parameters
         if _config is None:
@@ -241,9 +215,9 @@ class SyncStorageMixin(StorageMixinBase):

         # Only pass params if it's not None to avoid adding None as a parameter
         if params is not None:
-            sql = SQL(statement, params, *filters,
+            sql = SQL(statement, params, *filters, config=_config, **kwargs)
         else:
-            sql = SQL(statement, *filters,
+            sql = SQL(statement, *filters, config=_config, **kwargs)

         return self._fetch_arrow_table(sql, connection=_connection, **kwargs)

@@ -261,11 +235,9 @@ class SyncStorageMixin(StorageMixinBase):
         Returns:
             ArrowResult with converted data
         """
-        # Check if this SQL object has validation issues due to transformer-generated parameters
         try:
             result = cast("SQLResult", self.execute(sql, _connection=connection))  # type: ignore[attr-defined]
         except Exception:
-            # Get the compiled SQL and parameters
             compiled_sql, compiled_params = sql.compile("qmark")

             # Execute directly via the driver's _execute method
@@ -293,7 +265,7 @@ class SyncStorageMixin(StorageMixinBase):
         statement: "Statement",
         /,
         *parameters: "Union[StatementParameters, StatementFilter]",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
         _config: "Optional[SQLConfig]" = None,
@@ -315,21 +287,21 @@ class SyncStorageMixin(StorageMixinBase):
         Returns:
             Number of rows exported
         """
-
-        filters, params = _separate_filters_from_parameters(parameters)
+        filters, params = separate_filters_and_parameters(parameters)

         # For storage operations, disable transformations that might add unwanted parameters
         if _config is None:
             _config = self.config
+        if _config and not _config.dialect:
+            _config = replace(_config, dialect=self.dialect)
         if _config and _config.enable_transformations:
-            from dataclasses import replace
-
             _config = replace(_config, enable_transformations=False)

-
-
-
-
+        sql = (
+            SQL(statement, parameters=params, config=_config) if params is not None else SQL(statement, config=_config)
+        )
+        for filter_ in filters:
+            sql = sql.filter(filter_)

         return self._export_to_storage(
             sql, destination_uri=destination_uri, format=format, _connection=_connection, **options
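The new lines default the config's dialect to the driver's and switch transformations off through dataclasses.replace, which is now used at module scope instead of being imported inside the function. A standalone sketch of that adjustment with a stand-in config class (FakeSQLConfig is illustrative only, not sqlspec's SQLConfig):

from dataclasses import dataclass, replace
from typing import Optional

@dataclass
class FakeSQLConfig:
    dialect: Optional[str] = None
    enable_transformations: bool = True

cfg = FakeSQLConfig()
if not cfg.dialect:
    cfg = replace(cfg, dialect="sqlite")  # default the dialect to the driver's
if cfg.enable_transformations:
    cfg = replace(cfg, enable_transformations=False)  # storage export wants untransformed SQL
assert (cfg.dialect, cfg.enable_transformations) == ("sqlite", False)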
@@ -337,40 +309,25 @@ class SyncStorageMixin(StorageMixinBase):

     def _export_to_storage(
         self,
-
-
-        *parameters: "Union[StatementParameters, StatementFilter]",
-        destination_uri: str,
+        sql: "SQL",
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
-        _config: "Optional[SQLConfig]" = None,
         **kwargs: Any,
     ) -> int:
-
-        if hasattr(statement, "to_sql"):  # SQL object
-            query_str = cast("SQL", statement).to_sql()
-        elif isinstance(statement, str):
-            query_str = statement
-        else:  # sqlglot Expression
-            query_str = str(statement)
-
-        # Auto-detect format if not provided
-        # If no format is specified and detection fails (returns "csv" as default),
-        # default to "parquet" for export operations as it's the most common use case
+        """Protected method for sync export operation implementation."""
         detected_format = self._detect_format(destination_uri)
         if format:
             file_format = format
-        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
+        elif detected_format == "csv" and not str(destination_uri).endswith((".csv", ".tsv", ".txt")):
             # Detection returned default "csv" but file doesn't actually have CSV extension
-            # Default to parquet for better compatibility with tests and common usage
             file_format = "parquet"
         else:
             file_format = detected_format

-        # Special handling for parquet format - if we're exporting to parquet but the
         # destination doesn't have .parquet extension, add it to ensure compatibility
         # with pyarrow.parquet.read_table() which requires the extension
-        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
+        if file_format == "parquet" and not str(destination_uri).endswith(".parquet"):
             destination_uri = f"{destination_uri}.parquet"

         # Use storage backend - resolve AFTER modifying destination_uri
@@ -378,41 +335,30 @@ class SyncStorageMixin(StorageMixinBase):

         # Try native database export first
         if file_format == "parquet" and self.supports_native_parquet_export:
-
-
-
-
-
-
-            except NotImplementedError:
-                # Fall through to use storage backend
-                pass
+            try:
+                compiled_sql, _ = sql.compile(placeholder_style="static")
+                return self._export_native(compiled_sql, destination_uri, file_format, **kwargs)
+            except NotImplementedError:
+                # Fall through to use storage backend
+                pass

         if file_format == "parquet":
-            # Use Arrow for efficient transfer
-
-            # For parquet export via Arrow, just use the SQL object directly
-            sql_obj = cast("SQL", statement)
-            # Pass connection parameter correctly
-            arrow_result = self._fetch_arrow_table(sql_obj, connection=_connection, **kwargs)
-        else:
-            # Create SQL object if it's still a string
-            arrow_result = self.fetch_arrow_table(statement, *parameters, _connection=_connection, _config=_config)
-
-        # ArrowResult.data is never None according to the type definition
+            # Use Arrow for efficient transfer
+            arrow_result = self._fetch_arrow_table(sql, connection=_connection, **kwargs)
             arrow_table = arrow_result.data
             num_rows = arrow_table.num_rows
             backend.write_arrow(path, arrow_table, **kwargs)
             return num_rows
-
-
-            sql_obj = SQL(statement, _config=_config, _dialect=self.dialect)
-        else:
-            sql_obj = cast("SQL", statement)
-        return self._export_via_backend(sql_obj, backend, path, file_format, **kwargs)
+
+        return self._export_via_backend(sql, backend, path, file_format, **kwargs)

     def import_from_storage(
-        self,
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Import data from storage with intelligent routing.

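The rewritten branch compiles the statement with placeholder_style="static" and hands the literal SQL to _export_native, falling back to the storage backend whenever the driver raises NotImplementedError. A hypothetical override sketch for a DuckDB-like driver (not the actual sqlspec adapter; DuckDB's COPY ... TO syntax is real, everything else is assumed):

class DuckDBLikeDriver:
    # Hypothetical driver; assumes a DB-API style connection with execute()/fetchone().
    def __init__(self, connection) -> None:
        self._connection = connection

    def _export_native(self, query: str, destination_uri, format: str, **options) -> int:
        if format != "parquet":
            raise NotImplementedError  # mixin then falls back to the storage backend
        # DuckDB can export a query result directly: COPY (<query>) TO '<file>' (FORMAT PARQUET)
        self._connection.execute(f"COPY ({query}) TO '{destination_uri}' (FORMAT PARQUET)")
        return self._connection.execute(f"SELECT COUNT(*) FROM ({query}) AS t").fetchone()[0]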
@@ -431,7 +377,12 @@ class SyncStorageMixin(StorageMixinBase):
         return self._import_from_storage(source_uri, table_name, format, mode, **options)

     def _import_from_storage(
-        self,
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Protected method for import operation implementation.

@@ -461,7 +412,23 @@ class SyncStorageMixin(StorageMixinBase):
                 arrow_table = backend.read_arrow(path, **options)
                 return self.ingest_arrow_table(arrow_table, table_name, mode=mode)
             except AttributeError:
-
+                # Backend doesn't support read_arrow, try alternative approach
+                try:
+                    import pyarrow.parquet as pq
+
+                    # Read Parquet file directly
+                    with tempfile.NamedTemporaryFile(mode="wb", suffix=".parquet", delete=False) as tmp:
+                        tmp.write(backend.read_bytes(path))
+                        tmp_path = Path(tmp.name)
+                    try:
+                        arrow_table = pq.read_table(tmp_path)
+                        return self.ingest_arrow_table(arrow_table, table_name, mode=mode)
+                    finally:
+                        tmp_path.unlink(missing_ok=True)
+                except ImportError:
+                    # PyArrow not installed, cannot import Parquet
+                    msg = "PyArrow is required to import Parquet files. Install with: pip install pyarrow"
+                    raise ImportError(msg) from None

         # Use traditional import through temporary file
         return self._import_via_backend(backend, path, table_name, file_format, mode, **options)
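The new fallback only needs read_bytes() from the backend: the Parquet payload is spooled to a temporary file and read back with pyarrow.parquet. A self-contained round-trip of that path (BytesOnlyBackend is a toy stand-in, not a sqlspec storage backend):

import tempfile
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

class BytesOnlyBackend:
    # Toy backend exposing only read_bytes(), the one method the fallback relies on.
    def __init__(self, blobs: dict) -> None:
        self._blobs = blobs
    def read_bytes(self, path: str) -> bytes:
        return self._blobs[path]

# Build a Parquet blob in memory, then round-trip it through the temp-file fallback.
buf = pa.BufferOutputStream()
pq.write_table(pa.table({"id": [1, 2]}), buf)
backend = BytesOnlyBackend({"data.parquet": buf.getvalue().to_pybytes()})

with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
    tmp.write(backend.read_bytes("data.parquet"))
    tmp_path = Path(tmp.name)
try:
    table = pq.read_table(tmp_path)
finally:
    tmp_path.unlink(missing_ok=True)
assert table.num_rows == 2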
@@ -471,23 +438,27 @@ class SyncStorageMixin(StorageMixinBase):
     # ============================================================================

     def _read_parquet_native(
-        self, source_uri: str, columns: "Optional[list[str]]" = None, **options: Any
+        self, source_uri: "Union[str, Path]", columns: "Optional[list[str]]" = None, **options: Any
     ) -> "SQLResult":
         """Database-specific native Parquet reading. Override in drivers."""
         msg = "Driver should implement _read_parquet_native"
         raise NotImplementedError(msg)

-    def _write_parquet_native(
+    def _write_parquet_native(
+        self, data: Union[str, ArrowTable], destination_uri: "Union[str, Path]", **options: Any
+    ) -> None:
         """Database-specific native Parquet writing. Override in drivers."""
         msg = "Driver should implement _write_parquet_native"
         raise NotImplementedError(msg)

-    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
+    def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
         """Database-specific native export. Override in drivers."""
         msg = "Driver should implement _export_native"
         raise NotImplementedError(msg)

-    def _import_native(
+    def _import_native(
+        self, source_uri: "Union[str, Path]", table_name: str, format: str, mode: str, **options: Any
+    ) -> int:
         """Database-specific native import. Override in drivers."""
         msg = "Driver should implement _import_native"
         raise NotImplementedError(msg)
@@ -515,10 +486,8 @@ class SyncStorageMixin(StorageMixinBase):
             backend.write_arrow(path, arrow_table, **options)
             return len(result.data or [])

-        # Convert to appropriate format and write to backend
         compression = options.get("compression")

-        # Create temp file with appropriate suffix
         suffix = f".{format}"
         if compression == "gzip":
             suffix += ".gz"
@@ -526,7 +495,6 @@ class SyncStorageMixin(StorageMixinBase):
         with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding="utf-8") as tmp:
             tmp_path = Path(tmp.name)

-            # Handle compression and writing
             if compression == "gzip":
                 import gzip

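Only the suffix bookkeeping for gzip output survives in this hunk. The sketch below shows one way the gzip branch can write and re-read a compressed CSV temp file, assuming the payload is CSV text (illustrative, not the mixin's exact code):

import csv
import gzip
import tempfile
from pathlib import Path

rows = [("id", "name"), (1, "ada"), (2, "grace")]
suffix = ".csv.gz"  # same scheme as above: ".{format}" plus ".gz" when gzip is requested

with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
    tmp_path = Path(tmp.name)
with gzip.open(tmp_path, "wt", encoding="utf-8", newline="") as fh:
    csv.writer(fh).writerows(rows)

with gzip.open(tmp_path, "rt", encoding="utf-8", newline="") as fh:
    assert next(csv.reader(fh)) == ["id", "name"]
tmp_path.unlink()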
@@ -580,41 +548,24 @@ class SyncStorageMixin(StorageMixinBase):
     @staticmethod
     def _write_csv(result: "SQLResult", file: Any, **options: Any) -> None:
         """Write result to CSV file."""
-
-        csv_options = options.copy()
-        csv_options.pop("compression", None)  # Handle compression separately
-        csv_options.pop("partition_by", None)  # Not applicable to CSV
-
-        writer = csv.writer(file, **csv_options)  # TODO: anything better?
-        if result.column_names:
-            writer.writerow(result.column_names)
-        if result.data:
-            # Handle dict rows by extracting values in column order
-            if result.data and isinstance(result.data[0], dict):
-                rows = []
-                for row_dict in result.data:
-                    # Extract values in the same order as column_names
-                    row_values = [row_dict.get(col) for col in result.column_names or []]
-                    rows.append(row_values)
-                writer.writerows(rows)
-            else:
-                writer.writerows(result.data)
+        write_csv(result, file, **options)

     @staticmethod
     def _write_json(result: "SQLResult", file: Any, **options: Any) -> None:
         """Write result to JSON file."""
+        _ = options

         if result.data and result.column_names:
-            # Check if data is already in dict format
             if result.data and isinstance(result.data[0], dict):
                 # Data is already dictionaries, use as-is
                 rows = result.data
             else:
-                # Convert tuples/lists to list of dicts
                 rows = [dict(zip(result.column_names, row)) for row in result.data]
-
+            json_str = to_json(rows)
+            file.write(json_str)
         else:
-
+            json_str = to_json([])
+            file.write(json_str)

     def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
         """Database-specific bulk load implementation. Override in drivers."""
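Both mixins now delegate CSV writing to write_csv from the new sqlspec/driver/mixins/_csv_writer.py (+91 lines). That module is not shown in this diff; the sketch below simply repackages the deleted inline logic behind the call signature used above, so treat it as a reading aid rather than the real helper:

import csv
from typing import Any

def write_csv(result, file: Any, **options: Any) -> None:
    # Assumed to preserve the removed behavior: strip non-CSV options, emit a header
    # row, and handle both dict rows and tuple/list rows.
    options.pop("compression", None)   # handled by the caller
    options.pop("partition_by", None)  # not applicable to CSV
    writer = csv.writer(file, **options)
    if result.column_names:
        writer.writerow(result.column_names)
    if result.data and isinstance(result.data[0], dict):
        writer.writerows([[row.get(col) for col in result.column_names or []] for row in result.data])
    elif result.data:
        writer.writerows(result.data)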
@@ -689,7 +640,7 @@ class AsyncStorageMixin(StorageMixinBase):
         """
         self._ensure_pyarrow_installed()

-        filters, params =
+        filters, params = separate_filters_and_parameters(parameters)
         # Convert to SQL object for processing
         # Use a custom config if transformations will add parameters
         if _config is None:
@@ -698,18 +649,15 @@ class AsyncStorageMixin(StorageMixinBase):
         # If no parameters provided but we have transformations enabled,
         # disable parameter validation entirely to allow transformer-added parameters
         if params is None and _config and _config.enable_transformations:
-            from dataclasses import replace
-
             # Disable validation entirely for transformer-generated parameters
             _config = replace(_config, strict_mode=False, enable_validation=False)

         # Only pass params if it's not None to avoid adding None as a parameter
         if params is not None:
-            sql = SQL(statement, params, *filters,
+            sql = SQL(statement, params, *filters, config=_config, **kwargs)
         else:
-            sql = SQL(statement, *filters,
+            sql = SQL(statement, *filters, config=_config, **kwargs)

-        # Delegate to protected method that drivers can override
         return await self._fetch_arrow_table(sql, connection=_connection, **kwargs)

     async def _fetch_arrow_table(
@@ -733,7 +681,6 @@ class AsyncStorageMixin(StorageMixinBase):
         # Execute regular query
         result = await self.execute(sql, _connection=connection)  # type: ignore[attr-defined]

-        # Convert to Arrow table
         arrow_table = self._rows_to_arrow_table(result.data or [], result.column_names or [])

         return ArrowResult(statement=sql, data=arrow_table)
@@ -743,37 +690,37 @@ class AsyncStorageMixin(StorageMixinBase):
         statement: "Statement",
         /,
         *parameters: "Union[StatementParameters, StatementFilter]",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
         _config: "Optional[SQLConfig]" = None,
-        **
+        **kwargs: Any,
     ) -> int:
-
-        filters, params = _separate_filters_from_parameters(parameters)
+        filters, params = separate_filters_and_parameters(parameters)

         # For storage operations, disable transformations that might add unwanted parameters
         if _config is None:
             _config = self.config
+        if _config and not _config.dialect:
+            _config = replace(_config, dialect=self.dialect)
         if _config and _config.enable_transformations:
-            from dataclasses import replace
-
             _config = replace(_config, enable_transformations=False)

-
-
-
-
+        sql = (
+            SQL(statement, parameters=params, config=_config) if params is not None else SQL(statement, config=_config)
+        )
+        for filter_ in filters:
+            sql = sql.filter(filter_)

-        return await self._export_to_storage(sql, destination_uri, format, connection=_connection, **
+        return await self._export_to_storage(sql, destination_uri, format, connection=_connection, **kwargs)

     async def _export_to_storage(
         self,
         query: "SQL",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         connection: "Optional[ConnectionT]" = None,
-        **
+        **kwargs: Any,
     ) -> int:
         """Protected async method for export operation implementation.

@@ -782,28 +729,24 @@ class AsyncStorageMixin(StorageMixinBase):
             destination_uri: URI to export data to
             format: Optional format override (auto-detected from URI if not provided)
             connection: Optional connection override
-            **
+            **kwargs: Additional export options

         Returns:
             Number of rows exported
         """
         # Auto-detect format if not provided
-        # If no format is specified and detection fails (returns "csv" as default),
-        # default to "parquet" for export operations as it's the most common use case
         detected_format = self._detect_format(destination_uri)
         if format:
             file_format = format
-        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
+        elif detected_format == "csv" and not str(destination_uri).endswith((".csv", ".tsv", ".txt")):
             # Detection returned default "csv" but file doesn't actually have CSV extension
-            # Default to parquet for better compatibility with tests and common usage
             file_format = "parquet"
         else:
             file_format = detected_format

-        # Special handling for parquet format - if we're exporting to parquet but the
         # destination doesn't have .parquet extension, add it to ensure compatibility
         # with pyarrow.parquet.read_table() which requires the extension
-        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
+        if file_format == "parquet" and not str(destination_uri).endswith(".parquet"):
             destination_uri = f"{destination_uri}.parquet"

         # Use storage backend - resolve AFTER modifying destination_uri
@@ -811,34 +754,31 @@ class AsyncStorageMixin(StorageMixinBase):

         # Try native database export first
         if file_format == "parquet" and self.supports_native_parquet_export:
-
+            try:
+                compiled_sql, _ = query.compile(placeholder_style="static")
+                return await self._export_native(compiled_sql, destination_uri, file_format, **kwargs)
+            except NotImplementedError:
+                # Fall through to use storage backend
+                pass

         if file_format == "parquet":
-            #
-
-            if hasattr(query, "parameters") and query.parameters and hasattr(query, "_raw_sql"):
-                # Create fresh SQL object from raw SQL without transformations
-                fresh_sql = SQL(
-                    query._raw_sql,
-                    _config=replace(self.config, enable_transformations=False)
-                    if self.config
-                    else SQLConfig(enable_transformations=False),
-                    _dialect=self.dialect,
-                )
-                arrow_result = await self._fetch_arrow_table(fresh_sql, connection=connection, **options)
-            else:
-                # query is already a SQL object, call _fetch_arrow_table directly
-                arrow_result = await self._fetch_arrow_table(query, connection=connection, **options)
+            # Use Arrow for efficient transfer
+            arrow_result = await self._fetch_arrow_table(query, connection=connection, **kwargs)
             arrow_table = arrow_result.data
             if arrow_table is not None:
-                await backend.write_arrow_async(path, arrow_table, **
+                await backend.write_arrow_async(path, arrow_table, **kwargs)
                 return arrow_table.num_rows
             return 0

-        return await self._export_via_backend(query, backend, path, file_format, **
+        return await self._export_via_backend(query, backend, path, file_format, **kwargs)

     async def import_from_storage(
-        self,
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Async import data from storage with intelligent routing.

@@ -857,7 +797,12 @@ class AsyncStorageMixin(StorageMixinBase):
         return await self._import_from_storage(source_uri, table_name, format, mode, **options)

     async def _import_from_storage(
-        self,
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Protected async method for import operation implementation.

@@ -884,12 +829,14 @@ class AsyncStorageMixin(StorageMixinBase):
     # Async Database-Specific Implementation Hooks
     # ============================================================================

-    async def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
+    async def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
         """Async database-specific native export."""
         msg = "Driver should implement _export_native"
         raise NotImplementedError(msg)

-    async def _import_native(
+    async def _import_native(
+        self, source_uri: "Union[str, Path]", table_name: str, format: str, mode: str, **options: Any
+    ) -> int:
         """Async database-specific native import."""
         msg = "Driver should implement _import_native"
         raise NotImplementedError(msg)
@@ -917,7 +864,6 @@ class AsyncStorageMixin(StorageMixinBase):
             await backend.write_arrow_async(path, arrow_table, **options)
             return len(result.data or [])

-        # Convert to appropriate format and write to backend
         with tempfile.NamedTemporaryFile(mode="w", suffix=f".{format}", delete=False, encoding="utf-8") as tmp:
             if format == "csv":
                 self._write_csv(result, tmp, **options)
@@ -955,37 +901,24 @@ class AsyncStorageMixin(StorageMixinBase):
     @staticmethod
     def _write_csv(result: "SQLResult", file: Any, **options: Any) -> None:
         """Reuse sync implementation."""
-
-        writer = csv.writer(file, **options)
-        if result.column_names:
-            writer.writerow(result.column_names)
-        if result.data:
-            # Handle dict rows by extracting values in column order
-            if result.data and isinstance(result.data[0], dict):
-                rows = []
-                for row_dict in result.data:
-                    # Extract values in the same order as column_names
-                    row_values = [row_dict.get(col) for col in result.column_names or []]
-                    rows.append(row_values)
-                writer.writerows(rows)
-            else:
-                writer.writerows(result.data)
+        write_csv(result, file, **options)

     @staticmethod
     def _write_json(result: "SQLResult", file: Any, **options: Any) -> None:
         """Reuse sync implementation."""
+        _ = options  # May be used in the future for JSON formatting options

         if result.data and result.column_names:
-            # Check if data is already in dict format
             if result.data and isinstance(result.data[0], dict):
                 # Data is already dictionaries, use as-is
                 rows = result.data
             else:
-                # Convert tuples/lists to list of dicts
                 rows = [dict(zip(result.column_names, row)) for row in result.data]
-
+            json_str = to_json(rows)
+            file.write(json_str)
         else:
-
+            json_str = to_json([])
+            file.write(json_str)

     async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
         """Async database-specific bulk load implementation."""