sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of sqlspec has been flagged as a potentially problematic release.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -644
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -462
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +217 -451
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +418 -498
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +592 -634
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +393 -436
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +549 -942
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -550
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +732 -733
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +243 -426
- sqlspec/base.py +220 -825
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -330
- sqlspec/mixins.py +0 -306
- sqlspec/statement.py +0 -378
- sqlspec-0.11.0.dist-info/RECORD +0 -69
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/mixins/_storage.py
@@ -0,0 +1,993 @@
"""Unified storage operations for database drivers.

This module provides the new simplified storage architecture that replaces
the complex web of Arrow, Export, Copy, and ResultConverter mixins with
just two comprehensive mixins: SyncStorageMixin and AsyncStorageMixin.

These mixins provide intelligent routing between native database capabilities
and storage backend operations for optimal performance.
"""

# pyright: reportCallIssue=false, reportAttributeAccessIssue=false, reportArgumentType=false
import csv
import json
import logging
import tempfile
from abc import ABC
from dataclasses import replace
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
from urllib.parse import urlparse

from sqlspec.exceptions import MissingDependencyError
from sqlspec.statement import SQL, ArrowResult, StatementFilter
from sqlspec.statement.sql import SQLConfig
from sqlspec.storage import storage_registry
from sqlspec.typing import ArrowTable, RowT, StatementParameters
from sqlspec.utils.sync_tools import async_

if TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    from sqlspec.statement import SQLResult, Statement
    from sqlspec.storage.protocol import ObjectStoreProtocol
    from sqlspec.typing import ConnectionT

__all__ = ("AsyncStorageMixin", "SyncStorageMixin")

logger = logging.getLogger(__name__)

# Constants
WINDOWS_PATH_MIN_LENGTH = 3


def _separate_filters_from_parameters(
    parameters: "tuple[Any, ...]",
) -> "tuple[list[StatementFilter], Optional[StatementParameters]]":
    """Separate filters from parameters in positional args."""
    filters: list[StatementFilter] = []
    params: list[Any] = []

    for arg in parameters:
        if isinstance(arg, StatementFilter):
            filters.append(arg)
        else:
            # Everything else is treated as parameters
            params.append(arg)

    # Convert to appropriate parameter format
    if len(params) == 0:
        return filters, None
    if len(params) == 1:
        return filters, params[0]
    return filters, params


class StorageMixinBase(ABC):
    """Base class with common storage functionality."""

    __slots__ = ()

    # These attributes are expected to be provided by the driver class
    config: Any  # Driver config - drivers use 'config' not '_config'
    _connection: Any  # Database connection
    dialect: "DialectType"
    supports_native_parquet_export: "ClassVar[bool]"
    supports_native_parquet_import: "ClassVar[bool]"

    @staticmethod
    def _ensure_pyarrow_installed() -> None:
        """Ensure PyArrow is installed for Arrow operations."""
        from sqlspec.typing import PYARROW_INSTALLED

        if not PYARROW_INSTALLED:
            msg = "pyarrow is required for Arrow operations. Install with: pip install pyarrow"
            raise MissingDependencyError(msg)

    @staticmethod
    def _get_storage_backend(uri_or_key: str) -> "ObjectStoreProtocol":
        """Get storage backend by URI or key with intelligent routing."""
        return storage_registry.get(uri_or_key)

    @staticmethod
    def _is_uri(path_or_uri: str) -> bool:
        """Check if input is a URI rather than a relative path."""
        schemes = {"s3", "gs", "gcs", "az", "azure", "abfs", "abfss", "file", "http", "https"}
        if "://" in path_or_uri:
            scheme = path_or_uri.split("://", maxsplit=1)[0].lower()
            return scheme in schemes
        if len(path_or_uri) >= WINDOWS_PATH_MIN_LENGTH and path_or_uri[1:3] == ":\\":
            return True
        return bool(path_or_uri.startswith("/"))

    @staticmethod
    def _detect_format(uri: str) -> str:
        """Detect file format from URI extension."""
        parsed = urlparse(uri)
        path = Path(parsed.path)
        extension = path.suffix.lower().lstrip(".")

        format_map = {
            "csv": "csv",
            "tsv": "csv",
            "txt": "csv",
            "parquet": "parquet",
            "pq": "parquet",
            "json": "json",
            "jsonl": "jsonl",
            "ndjson": "jsonl",
        }

        return format_map.get(extension, "csv")

    def _resolve_backend_and_path(self, uri: str) -> "tuple[ObjectStoreProtocol, str]":
        """Resolve backend and path from URI with Phase 3 URI-first routing.

        Args:
            uri: URI to resolve (e.g., "s3://bucket/path", "file:///local/path")

        Returns:
            Tuple of (backend, path) where path is relative to the backend's base path
        """
        # Convert Path objects to string
        uri = str(uri)
        original_path = uri

        # Convert absolute paths to file:// URIs if needed
        if self._is_uri(uri) and "://" not in uri:
            # It's an absolute path without scheme
            uri = f"file://{uri}"

        backend = self._get_storage_backend(uri)

        # For file:// URIs, return just the path part for the backend
        path = uri[7:] if uri.startswith("file://") else original_path

        return backend, path

    @staticmethod
    def _rows_to_arrow_table(rows: "list[RowT]", columns: "list[str]") -> ArrowTable:
        """Convert rows to Arrow table."""
        import pyarrow as pa

        if not rows:
            # Empty table with column names
            # Create empty arrays for each column
            empty_data: dict[str, list[Any]] = {col: [] for col in columns}
            return pa.table(empty_data)

        # Convert rows to columnar format
        if isinstance(rows[0], dict):
            # Dict rows
            data = {col: [cast("dict[str, Any]", row).get(col) for row in rows] for col in columns}
        else:
            # Tuple/list rows
            data = {col: [cast("tuple[Any, ...]", row)[i] for row in rows] for i, col in enumerate(columns)}

        return pa.table(data)


class SyncStorageMixin(StorageMixinBase):
    """Unified storage operations for synchronous drivers."""

    __slots__ = ()

    def ingest_arrow_table(self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any) -> int:
        """Ingest an Arrow table into the database.

        This public method provides a consistent entry point and can be used for
        instrumentation, logging, etc., while delegating the actual work to the
        driver-specific `_ingest_arrow_table` implementation.
        """
        return self._ingest_arrow_table(table, table_name, mode, **options)

    def _ingest_arrow_table(self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any) -> int:
        """Generic fallback for ingesting an Arrow table.

        This implementation writes the Arrow table to a temporary Parquet file
        and then uses the driver's generic `_bulk_load_file` capability.
        Drivers with more efficient, native Arrow ingestion methods should override this.
        """
        import pyarrow.parquet as pq

        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
            tmp_path = Path(tmp.name)
            pq.write_table(table, tmp_path)  # pyright: ignore

        try:
            # Use database's bulk load capabilities for Parquet
            return self._bulk_load_file(tmp_path, table_name, "parquet", mode, **options)
        finally:
            tmp_path.unlink(missing_ok=True)

    # ============================================================================
    # Core Arrow Operations
    # ============================================================================

    def fetch_arrow_table(
        self,
        statement: "Statement",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "ArrowResult":
        """Fetch query results as Arrow table with intelligent routing.

        Args:
            statement: SQL statement (string, SQL object, or sqlglot Expression)
            *parameters: Mixed parameters and filters
            _connection: Optional connection override
            _config: Optional SQL config override
            **kwargs: Additional options

        Returns:
            ArrowResult wrapping the Arrow table
        """
        self._ensure_pyarrow_installed()

        filters, params = _separate_filters_from_parameters(parameters)
        # Convert to SQL object for processing
        # Use a custom config if transformations will add parameters
        if _config is None:
            _config = self.config

        # If no parameters provided but we have transformations enabled,
        # disable parameter validation entirely to allow transformer-added parameters
        if params is None and _config and _config.enable_transformations:
            # Disable validation entirely for transformer-generated parameters
            _config = replace(_config, strict_mode=False, enable_validation=False)

        # Only pass params if it's not None to avoid adding None as a parameter
        if params is not None:
            sql = SQL(statement, params, *filters, _config=_config, _dialect=self.dialect, **kwargs)
        else:
            sql = SQL(statement, *filters, _config=_config, _dialect=self.dialect, **kwargs)

        return self._fetch_arrow_table(sql, connection=_connection, **kwargs)

    def _fetch_arrow_table(self, sql: SQL, connection: "Optional[ConnectionT]" = None, **kwargs: Any) -> "ArrowResult":
        """Generic fallback for Arrow table fetching.

        This method executes a regular query and converts the results to Arrow format.
        Drivers can call this method when they don't have native Arrow support.

        Args:
            sql: SQL object to execute
            connection: Optional connection override
            **kwargs: Additional options (unused in fallback)

        Returns:
            ArrowResult with converted data
        """
        # Check if this SQL object has validation issues due to transformer-generated parameters
        try:
            result = cast("SQLResult", self.execute(sql, _connection=connection))  # type: ignore[attr-defined]
        except Exception:
            # Get the compiled SQL and parameters
            compiled_sql, compiled_params = sql.compile("qmark")

            # Execute directly via the driver's _execute method
            driver_result = self._execute(compiled_sql, compiled_params, sql, connection=connection)  # type: ignore[attr-defined]

            # Wrap the result as a SQLResult
            if "data" in driver_result:
                # It's a SELECT result
                result = self._wrap_select_result(sql, driver_result)  # type: ignore[attr-defined]
            else:
                # It's a DML result
                result = self._wrap_execute_result(sql, driver_result)  # type: ignore[attr-defined]

        data = result.data or []
        columns = result.column_names or []
        arrow_table = self._rows_to_arrow_table(data, columns)
        return ArrowResult(statement=sql, data=arrow_table)

    # ============================================================================
    # Storage Integration Operations
    # ============================================================================

    def export_to_storage(
        self,
        statement: "Statement",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        destination_uri: str,
        format: "Optional[str]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **options: Any,
    ) -> int:
        """Export query results to storage with intelligent routing.

        Provides instrumentation and delegates to _export_to_storage() for consistent operation.

        Args:
            statement: SQL query to execute and export
            *parameters: Mixed parameters and filters
            destination_uri: URI to export data to
            format: Optional format override (auto-detected from URI if not provided)
            _connection: Optional connection override
            _config: Optional SQL config override
            **options: Additional export options AND named parameters for query

        Returns:
            Number of rows exported
        """
        # Create SQL object with proper parameter handling
        filters, params = _separate_filters_from_parameters(parameters)

        # For storage operations, disable transformations that might add unwanted parameters
        if _config is None:
            _config = self.config
        if _config and _config.enable_transformations:
            from dataclasses import replace

            _config = replace(_config, enable_transformations=False)

        if params is not None:
            sql = SQL(statement, params, *filters, _config=_config, _dialect=self.dialect)
        else:
            sql = SQL(statement, *filters, _config=_config, _dialect=self.dialect)

        return self._export_to_storage(
            sql, destination_uri=destination_uri, format=format, _connection=_connection, **options
        )

    def _export_to_storage(
        self,
        statement: "Statement",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        destination_uri: str,
        format: "Optional[str]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> int:
        # Convert query to string for format detection
        if hasattr(statement, "to_sql"):  # SQL object
            query_str = cast("SQL", statement).to_sql()
        elif isinstance(statement, str):
            query_str = statement
        else:  # sqlglot Expression
            query_str = str(statement)

        # Auto-detect format if not provided
        # If no format is specified and detection fails (returns "csv" as default),
        # default to "parquet" for export operations as it's the most common use case
        detected_format = self._detect_format(destination_uri)
        if format:
            file_format = format
        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
            # Detection returned default "csv" but file doesn't actually have CSV extension
            # Default to parquet for better compatibility with tests and common usage
            file_format = "parquet"
        else:
            file_format = detected_format

        # Special handling for parquet format - if we're exporting to parquet but the
        # destination doesn't have .parquet extension, add it to ensure compatibility
        # with pyarrow.parquet.read_table() which requires the extension
        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
            destination_uri = f"{destination_uri}.parquet"

        # Use storage backend - resolve AFTER modifying destination_uri
        backend, path = self._resolve_backend_and_path(destination_uri)

        # Try native database export first
        if file_format == "parquet" and self.supports_native_parquet_export:
            # If we have a SQL object with parameters, compile it first
            if hasattr(statement, "compile") and hasattr(statement, "parameters") and statement.parameters:
                _compiled_sql, _compiled_params = statement.compile(placeholder_style=self.default_parameter_style)  # type: ignore[attr-defined]
            else:
                try:
                    return self._export_native(query_str, destination_uri, file_format, **kwargs)
                except NotImplementedError:
                    # Fall through to use storage backend
                    pass

        if file_format == "parquet":
            # Use Arrow for efficient transfer - if statement is already a SQL object, use it directly
            if hasattr(statement, "compile"):  # It's already a SQL object from export_to_storage
                # For parquet export via Arrow, just use the SQL object directly
                sql_obj = cast("SQL", statement)
                # Pass connection parameter correctly
                arrow_result = self._fetch_arrow_table(sql_obj, connection=_connection, **kwargs)
            else:
                # Create SQL object if it's still a string
                arrow_result = self.fetch_arrow_table(statement, *parameters, _connection=_connection, _config=_config)

            # ArrowResult.data is never None according to the type definition
            arrow_table = arrow_result.data
            num_rows = arrow_table.num_rows
            backend.write_arrow(path, arrow_table, **kwargs)
            return num_rows
        # Pass the SQL object if available, otherwise create one
        if isinstance(statement, str):
            sql_obj = SQL(statement, _config=_config, _dialect=self.dialect)
        else:
            sql_obj = cast("SQL", statement)
        return self._export_via_backend(sql_obj, backend, path, file_format, **kwargs)

    def import_from_storage(
        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
    ) -> int:
        """Import data from storage with intelligent routing.

        Provides instrumentation and delegates to _import_from_storage() for consistent operation.

        Args:
            source_uri: URI to import data from
            table_name: Target table name
            format: Optional format override (auto-detected from URI if not provided)
            mode: Import mode ('create', 'append', 'replace')
            **options: Additional import options

        Returns:
            Number of rows imported
        """
        return self._import_from_storage(source_uri, table_name, format, mode, **options)

    def _import_from_storage(
        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
    ) -> int:
        """Protected method for import operation implementation.

        Args:
            source_uri: URI to import data from
            table_name: Target table name
            format: Optional format override (auto-detected from URI if not provided)
            mode: Import mode ('create', 'append', 'replace')
            **options: Additional import options

        Returns:
            Number of rows imported
        """
        # Auto-detect format if not provided
        file_format = format or self._detect_format(source_uri)

        # Try native database import first
        if file_format == "parquet" and self.supports_native_parquet_import:
            return self._import_native(source_uri, table_name, file_format, mode, **options)

        # Use storage backend
        backend, path = self._resolve_backend_and_path(source_uri)

        if file_format == "parquet":
            try:
                # Use Arrow for efficient transfer
                arrow_table = backend.read_arrow(path, **options)
                return self.ingest_arrow_table(arrow_table, table_name, mode=mode)
            except AttributeError:
                pass

        # Use traditional import through temporary file
        return self._import_via_backend(backend, path, table_name, file_format, mode, **options)

    # ============================================================================
    # Database-Specific Implementation Hooks
    # ============================================================================

    def _read_parquet_native(
        self, source_uri: str, columns: "Optional[list[str]]" = None, **options: Any
    ) -> "SQLResult":
        """Database-specific native Parquet reading. Override in drivers."""
        msg = "Driver should implement _read_parquet_native"
        raise NotImplementedError(msg)

    def _write_parquet_native(self, data: Union[str, ArrowTable], destination_uri: str, **options: Any) -> None:
        """Database-specific native Parquet writing. Override in drivers."""
        msg = "Driver should implement _write_parquet_native"
        raise NotImplementedError(msg)

    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
        """Database-specific native export. Override in drivers."""
        msg = "Driver should implement _export_native"
        raise NotImplementedError(msg)

    def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
        """Database-specific native import. Override in drivers."""
        msg = "Driver should implement _import_native"
        raise NotImplementedError(msg)

    def _export_via_backend(
        self, sql_obj: "SQL", backend: "ObjectStoreProtocol", path: str, format: str, **options: Any
    ) -> int:
        """Export via storage backend using temporary file."""

        # Execute query and get results - use the SQL object directly
        try:
            result = cast("SQLResult", self.execute(sql_obj))  # type: ignore[attr-defined]
        except Exception:
            # Fall back to direct execution
            compiled_sql, compiled_params = sql_obj.compile("qmark")
            driver_result = self._execute(compiled_sql, compiled_params, sql_obj)  # type: ignore[attr-defined]
            if "data" in driver_result:
                result = self._wrap_select_result(sql_obj, driver_result)  # type: ignore[attr-defined]
            else:
                result = self._wrap_execute_result(sql_obj, driver_result)  # type: ignore[attr-defined]

        # For parquet format, convert through Arrow
        if format == "parquet":
            arrow_table = self._rows_to_arrow_table(result.data or [], result.column_names or [])
            backend.write_arrow(path, arrow_table, **options)
            return len(result.data or [])

        # Convert to appropriate format and write to backend
        compression = options.get("compression")

        # Create temp file with appropriate suffix
        suffix = f".{format}"
        if compression == "gzip":
            suffix += ".gz"

        with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding="utf-8") as tmp:
            tmp_path = Path(tmp.name)

        # Handle compression and writing
        if compression == "gzip":
            import gzip

            with gzip.open(tmp_path, "wt", encoding="utf-8") as file_to_write:
                if format == "csv":
                    self._write_csv(result, file_to_write, **options)
                elif format == "json":
                    self._write_json(result, file_to_write, **options)
                else:
                    msg = f"Unsupported format for backend export: {format}"
                    raise ValueError(msg)
        else:
            with tmp_path.open("w", encoding="utf-8") as file_to_write:
                if format == "csv":
                    self._write_csv(result, file_to_write, **options)
                elif format == "json":
                    self._write_json(result, file_to_write, **options)
                else:
                    msg = f"Unsupported format for backend export: {format}"
                    raise ValueError(msg)

        try:
            # Upload to storage backend
            # Adjust path if compression was used
            final_path = path
            if compression == "gzip" and not path.endswith(".gz"):
                final_path = path + ".gz"

            backend.write_bytes(final_path, tmp_path.read_bytes())
            return result.rows_affected or len(result.data or [])
        finally:
            tmp_path.unlink(missing_ok=True)

    def _import_via_backend(
        self, backend: "ObjectStoreProtocol", path: str, table_name: str, format: str, mode: str, **options: Any
    ) -> int:
        """Import via storage backend using temporary file."""
        # Download from storage backend
        data = backend.read_bytes(path)

        with tempfile.NamedTemporaryFile(mode="wb", suffix=f".{format}", delete=False) as tmp:
            tmp.write(data)
            tmp_path = Path(tmp.name)

        try:
            # Use database's bulk load capabilities
            return self._bulk_load_file(tmp_path, table_name, format, mode, **options)
        finally:
            tmp_path.unlink(missing_ok=True)

    @staticmethod
    def _write_csv(result: "SQLResult", file: Any, **options: Any) -> None:
        """Write result to CSV file."""
        # Remove options that csv.writer doesn't understand
        csv_options = options.copy()
        csv_options.pop("compression", None)  # Handle compression separately
        csv_options.pop("partition_by", None)  # Not applicable to CSV

        writer = csv.writer(file, **csv_options)  # TODO: anything better?
        if result.column_names:
            writer.writerow(result.column_names)
        if result.data:
            # Handle dict rows by extracting values in column order
            if result.data and isinstance(result.data[0], dict):
                rows = []
                for row_dict in result.data:
                    # Extract values in the same order as column_names
                    row_values = [row_dict.get(col) for col in result.column_names or []]
                    rows.append(row_values)
                writer.writerows(rows)
            else:
                writer.writerows(result.data)

    @staticmethod
    def _write_json(result: "SQLResult", file: Any, **options: Any) -> None:
        """Write result to JSON file."""

        if result.data and result.column_names:
            # Check if data is already in dict format
            if result.data and isinstance(result.data[0], dict):
                # Data is already dictionaries, use as-is
                rows = result.data
            else:
                # Convert tuples/lists to list of dicts
                rows = [dict(zip(result.column_names, row)) for row in result.data]
            json.dump(rows, file, **options)  # TODO: use sqlspec.utils.serializer
        else:
            json.dump([], file)  # TODO: use sqlspec.utils.serializer

    def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
        """Database-specific bulk load implementation. Override in drivers."""
        msg = "Driver should implement _bulk_load_file"
        raise NotImplementedError(msg)


class AsyncStorageMixin(StorageMixinBase):
    """Unified storage operations for asynchronous drivers."""

    __slots__ = ()

    async def ingest_arrow_table(
        self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any
    ) -> int:
        """Ingest an Arrow table into the database asynchronously.

        This public method provides a consistent entry point and can be used for
        instrumentation, logging, etc., while delegating the actual work to the
        driver-specific `_ingest_arrow_table` implementation.
        """
        self._ensure_pyarrow_installed()
        return await self._ingest_arrow_table(table, table_name, mode, **options)

    async def _ingest_arrow_table(
        self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any
    ) -> int:
        """Generic async fallback for ingesting an Arrow table.

        This implementation writes the Arrow table to a temporary Parquet file
        and then uses the driver's generic `_bulk_load_file` capability.
        Drivers with more efficient, native Arrow ingestion methods should override this.
        """
        import pyarrow.parquet as pq

        # Use an async-friendly way to handle the temporary file if possible,
        # but for simplicity, standard tempfile is acceptable here as it's a fallback.
        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
            tmp_path = Path(tmp.name)
            await async_(pq.write_table)(table, tmp_path)  # pyright: ignore

        try:
            # Use database's async bulk load capabilities for Parquet
            return await self._bulk_load_file(tmp_path, table_name, "parquet", mode, **options)
        finally:
            tmp_path.unlink(missing_ok=True)

    # ============================================================================
    # Core Arrow Operations (Async)
    # ============================================================================

    async def fetch_arrow_table(
        self,
        statement: "Statement",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "ArrowResult":
        """Async fetch query results as Arrow table with intelligent routing.

        Args:
            statement: SQL statement (string, SQL object, or sqlglot Expression)
            *parameters: Mixed parameters and filters
            _connection: Optional connection override
            _config: Optional SQL config override
            **kwargs: Additional options

        Returns:
            ArrowResult wrapping the Arrow table
        """
        self._ensure_pyarrow_installed()

        filters, params = _separate_filters_from_parameters(parameters)
        # Convert to SQL object for processing
        # Use a custom config if transformations will add parameters
        if _config is None:
            _config = self.config

        # If no parameters provided but we have transformations enabled,
        # disable parameter validation entirely to allow transformer-added parameters
        if params is None and _config and _config.enable_transformations:
            from dataclasses import replace

            # Disable validation entirely for transformer-generated parameters
            _config = replace(_config, strict_mode=False, enable_validation=False)

        # Only pass params if it's not None to avoid adding None as a parameter
        if params is not None:
            sql = SQL(statement, params, *filters, _config=_config, _dialect=self.dialect, **kwargs)
        else:
            sql = SQL(statement, *filters, _config=_config, _dialect=self.dialect, **kwargs)

        # Delegate to protected method that drivers can override
        return await self._fetch_arrow_table(sql, connection=_connection, **kwargs)

    async def _fetch_arrow_table(
        self, sql: SQL, connection: "Optional[ConnectionT]" = None, **kwargs: Any
    ) -> "ArrowResult":
        """Generic async fallback for Arrow table fetching.

        This method executes a regular query and converts the results to Arrow format.
        Drivers should override this method to provide native Arrow support if available.
        If a driver has partial native support, it can call `super()._fetch_arrow_table(...)`
        to use this fallback implementation.

        Args:
            sql: SQL object to execute
            connection: Optional connection override
            **kwargs: Additional options (unused in fallback)

        Returns:
            ArrowResult with converted data
        """
        # Execute regular query
        result = await self.execute(sql, _connection=connection)  # type: ignore[attr-defined]

        # Convert to Arrow table
        arrow_table = self._rows_to_arrow_table(result.data or [], result.column_names or [])

        return ArrowResult(statement=sql, data=arrow_table)

    async def export_to_storage(
        self,
        statement: "Statement",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        destination_uri: str,
        format: "Optional[str]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **options: Any,
    ) -> int:
        # Create SQL object with proper parameter handling
        filters, params = _separate_filters_from_parameters(parameters)

        # For storage operations, disable transformations that might add unwanted parameters
        if _config is None:
            _config = self.config
        if _config and _config.enable_transformations:
            from dataclasses import replace

            _config = replace(_config, enable_transformations=False)

        if params is not None:
            sql = SQL(statement, params, *filters, _config=_config, _dialect=self.dialect, **options)
        else:
            sql = SQL(statement, *filters, _config=_config, _dialect=self.dialect, **options)

        return await self._export_to_storage(sql, destination_uri, format, connection=_connection, **options)

    async def _export_to_storage(
        self,
        query: "SQL",
        destination_uri: str,
        format: "Optional[str]" = None,
        connection: "Optional[ConnectionT]" = None,
        **options: Any,
    ) -> int:
        """Protected async method for export operation implementation.

        Args:
            query: SQL query to execute and export
            destination_uri: URI to export data to
            format: Optional format override (auto-detected from URI if not provided)
            connection: Optional connection override
            **options: Additional export options

        Returns:
            Number of rows exported
        """
        # Auto-detect format if not provided
        # If no format is specified and detection fails (returns "csv" as default),
        # default to "parquet" for export operations as it's the most common use case
        detected_format = self._detect_format(destination_uri)
        if format:
            file_format = format
        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
            # Detection returned default "csv" but file doesn't actually have CSV extension
            # Default to parquet for better compatibility with tests and common usage
            file_format = "parquet"
        else:
            file_format = detected_format

        # Special handling for parquet format - if we're exporting to parquet but the
        # destination doesn't have .parquet extension, add it to ensure compatibility
        # with pyarrow.parquet.read_table() which requires the extension
        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
            destination_uri = f"{destination_uri}.parquet"

        # Use storage backend - resolve AFTER modifying destination_uri
        backend, path = self._resolve_backend_and_path(destination_uri)

        # Try native database export first
        if file_format == "parquet" and self.supports_native_parquet_export:
            return await self._export_native(query.as_script().sql, destination_uri, file_format, **options)

        if file_format == "parquet":
            # For parquet export via Arrow, we need to ensure no unwanted parameter transformations
            # If the query already has parameters from transformations, create a fresh SQL object
            if hasattr(query, "parameters") and query.parameters and hasattr(query, "_raw_sql"):
                # Create fresh SQL object from raw SQL without transformations
                fresh_sql = SQL(
                    query._raw_sql,
                    _config=replace(self.config, enable_transformations=False)
                    if self.config
                    else SQLConfig(enable_transformations=False),
                    _dialect=self.dialect,
                )
                arrow_result = await self._fetch_arrow_table(fresh_sql, connection=connection, **options)
            else:
                # query is already a SQL object, call _fetch_arrow_table directly
                arrow_result = await self._fetch_arrow_table(query, connection=connection, **options)
            arrow_table = arrow_result.data
            if arrow_table is not None:
                await backend.write_arrow_async(path, arrow_table, **options)
                return arrow_table.num_rows
            return 0

        return await self._export_via_backend(query, backend, path, file_format, **options)

    async def import_from_storage(
        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
    ) -> int:
        """Async import data from storage with intelligent routing.

        Provides instrumentation and delegates to _import_from_storage() for consistent operation.

        Args:
            source_uri: URI to import data from
            table_name: Target table name
            format: Optional format override (auto-detected from URI if not provided)
            mode: Import mode ('create', 'append', 'replace')
            **options: Additional import options

        Returns:
            Number of rows imported
        """
        return await self._import_from_storage(source_uri, table_name, format, mode, **options)

    async def _import_from_storage(
        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
    ) -> int:
        """Protected async method for import operation implementation.

        Args:
            source_uri: URI to import data from
            table_name: Target table name
            format: Optional format override (auto-detected from URI if not provided)
            mode: Import mode ('create', 'append', 'replace')
            **options: Additional import options

        Returns:
            Number of rows imported
        """
        file_format = format or self._detect_format(source_uri)
        backend, path = self._resolve_backend_and_path(source_uri)

        if file_format == "parquet":
            arrow_table = await backend.read_arrow_async(path, **options)
            return await self.ingest_arrow_table(arrow_table, table_name, mode=mode)

        return await self._import_via_backend(backend, path, table_name, file_format, mode, **options)

    # ============================================================================
    # Async Database-Specific Implementation Hooks
    # ============================================================================

    async def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
        """Async database-specific native export."""
        msg = "Driver should implement _export_native"
        raise NotImplementedError(msg)

    async def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
        """Async database-specific native import."""
        msg = "Driver should implement _import_native"
        raise NotImplementedError(msg)

    async def _export_via_backend(
        self, sql_obj: "SQL", backend: "ObjectStoreProtocol", path: str, format: str, **options: Any
    ) -> int:
        """Async export via storage backend."""

        # Execute query and get results - use the SQL object directly
        try:
            result = await self.execute(sql_obj)  # type: ignore[attr-defined]
        except Exception:
            # Fall back to direct execution
            compiled_sql, compiled_params = sql_obj.compile("qmark")
            driver_result = await self._execute(compiled_sql, compiled_params, sql_obj)  # type: ignore[attr-defined]
            if "data" in driver_result:
                result = self._wrap_select_result(sql_obj, driver_result)  # type: ignore[attr-defined]
            else:
                result = self._wrap_execute_result(sql_obj, driver_result)  # type: ignore[attr-defined]

        # For parquet format, convert through Arrow
        if format == "parquet":
            arrow_table = self._rows_to_arrow_table(result.data or [], result.column_names or [])
            await backend.write_arrow_async(path, arrow_table, **options)
            return len(result.data or [])

        # Convert to appropriate format and write to backend
        with tempfile.NamedTemporaryFile(mode="w", suffix=f".{format}", delete=False, encoding="utf-8") as tmp:
            if format == "csv":
                self._write_csv(result, tmp, **options)
            elif format == "json":
                self._write_json(result, tmp, **options)
            else:
                msg = f"Unsupported format for backend export: {format}"
                raise ValueError(msg)

            tmp_path = Path(tmp.name)

        try:
            # Upload to storage backend (async if supported)
            await backend.write_bytes_async(path, tmp_path.read_bytes())
            return result.rows_affected or len(result.data or [])
        finally:
            tmp_path.unlink(missing_ok=True)

    async def _import_via_backend(
        self, backend: "ObjectStoreProtocol", path: str, table_name: str, format: str, mode: str, **options: Any
    ) -> int:
        """Async import via storage backend."""
        # Download from storage backend (async if supported)
        data = await backend.read_bytes_async(path)

        with tempfile.NamedTemporaryFile(mode="wb", suffix=f".{format}", delete=False) as tmp:
            tmp.write(data)
            tmp_path = Path(tmp.name)

        try:
            return await self._bulk_load_file(tmp_path, table_name, format, mode, **options)
        finally:
            tmp_path.unlink(missing_ok=True)

    @staticmethod
    def _write_csv(result: "SQLResult", file: Any, **options: Any) -> None:
        """Reuse sync implementation."""

        writer = csv.writer(file, **options)
        if result.column_names:
            writer.writerow(result.column_names)
        if result.data:
            # Handle dict rows by extracting values in column order
            if result.data and isinstance(result.data[0], dict):
                rows = []
                for row_dict in result.data:
                    # Extract values in the same order as column_names
                    row_values = [row_dict.get(col) for col in result.column_names or []]
                    rows.append(row_values)
                writer.writerows(rows)
            else:
                writer.writerows(result.data)

    @staticmethod
    def _write_json(result: "SQLResult", file: Any, **options: Any) -> None:
        """Reuse sync implementation."""

        if result.data and result.column_names:
            # Check if data is already in dict format
            if result.data and isinstance(result.data[0], dict):
                # Data is already dictionaries, use as-is
                rows = result.data
            else:
                # Convert tuples/lists to list of dicts
                rows = [dict(zip(result.column_names, row)) for row in result.data]
            json.dump(rows, file, **options)
        else:
            json.dump([], file)

    async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
        """Async database-specific bulk load implementation."""
        msg = "Driver should implement _bulk_load_file"
        raise NotImplementedError(msg)
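
For context, a minimal usage sketch of the public storage API added in this file (SyncStorageMixin's fetch_arrow_table, export_to_storage, and import_from_storage). The sketch is illustrative only: the `driver` variable, the query text, the bucket URI, and the qmark placeholder style are assumptions, and obtaining a driver instance is adapter-specific rather than shown in this diff.

# Assumed: `driver` is a synchronous sqlspec driver whose class mixes in SyncStorageMixin.

# Fetch query results as an Arrow table; ArrowResult.data is a pyarrow table.
arrow_result = driver.fetch_arrow_table("SELECT id, name FROM users WHERE active = ?", 1)
table = arrow_result.data

# Export a query directly to object storage; the format is auto-detected from the
# URI extension (parquet here) and the number of exported rows is returned.
rows_exported = driver.export_to_storage(
    "SELECT * FROM users",
    destination_uri="s3://example-bucket/exports/users.parquet",
)

# Load a Parquet object from storage into a new table.
rows_imported = driver.import_from_storage(
    "s3://example-bucket/exports/users.parquet",
    table_name="users_copy",
    mode="create",
)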