sqlspec 0.11.1__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (155)
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -621
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -431
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +218 -436
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +417 -487
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +600 -553
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +392 -406
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +548 -921
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -533
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +725 -0
  31. sqlspec/adapters/psycopg/driver.py +734 -694
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +242 -405
  35. sqlspec/base.py +220 -784
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/METADATA +97 -26
  137. sqlspec-0.12.1.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -331
  150. sqlspec/mixins.py +0 -305
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.1.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/driver/mixins/_storage.py (new file)
@@ -0,0 +1,993 @@
+"""Unified storage operations for database drivers.
+
+This module provides the new simplified storage architecture that replaces
+the complex web of Arrow, Export, Copy, and ResultConverter mixins with
+just two comprehensive mixins: SyncStorageMixin and AsyncStorageMixin.
+
+These mixins provide intelligent routing between native database capabilities
+and storage backend operations for optimal performance.
+"""
+
+# pyright: reportCallIssue=false, reportAttributeAccessIssue=false, reportArgumentType=false
+import csv
+import json
+import logging
+import tempfile
+from abc import ABC
+from dataclasses import replace
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
+from urllib.parse import urlparse
+
+from sqlspec.exceptions import MissingDependencyError
+from sqlspec.statement import SQL, ArrowResult, StatementFilter
+from sqlspec.statement.sql import SQLConfig
+from sqlspec.storage import storage_registry
+from sqlspec.typing import ArrowTable, RowT, StatementParameters
+from sqlspec.utils.sync_tools import async_
+
+if TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+
+    from sqlspec.statement import SQLResult, Statement
+    from sqlspec.storage.protocol import ObjectStoreProtocol
+    from sqlspec.typing import ConnectionT
+
+__all__ = ("AsyncStorageMixin", "SyncStorageMixin")
+
+logger = logging.getLogger(__name__)
+
+# Constants
+WINDOWS_PATH_MIN_LENGTH = 3
+
+
+def _separate_filters_from_parameters(
+    parameters: "tuple[Any, ...]",
+) -> "tuple[list[StatementFilter], Optional[StatementParameters]]":
+    """Separate filters from parameters in positional args."""
+    filters: list[StatementFilter] = []
+    params: list[Any] = []
+
+    for arg in parameters:
+        if isinstance(arg, StatementFilter):
+            filters.append(arg)
+        else:
+            # Everything else is treated as parameters
+            params.append(arg)
+
+    # Convert to appropriate parameter format
+    if len(params) == 0:
+        return filters, None
+    if len(params) == 1:
+        return filters, params[0]
+    return filters, params
+
+
+class StorageMixinBase(ABC):
+    """Base class with common storage functionality."""
+
+    __slots__ = ()
+
+    # These attributes are expected to be provided by the driver class
+    config: Any  # Driver config - drivers use 'config' not '_config'
+    _connection: Any  # Database connection
+    dialect: "DialectType"
+    supports_native_parquet_export: "ClassVar[bool]"
+    supports_native_parquet_import: "ClassVar[bool]"
+
+    @staticmethod
+    def _ensure_pyarrow_installed() -> None:
+        """Ensure PyArrow is installed for Arrow operations."""
+        from sqlspec.typing import PYARROW_INSTALLED
+
+        if not PYARROW_INSTALLED:
+            msg = "pyarrow is required for Arrow operations. Install with: pip install pyarrow"
+            raise MissingDependencyError(msg)
+
+    @staticmethod
+    def _get_storage_backend(uri_or_key: str) -> "ObjectStoreProtocol":
+        """Get storage backend by URI or key with intelligent routing."""
+        return storage_registry.get(uri_or_key)
+
+    @staticmethod
+    def _is_uri(path_or_uri: str) -> bool:
+        """Check if input is a URI rather than a relative path."""
+        schemes = {"s3", "gs", "gcs", "az", "azure", "abfs", "abfss", "file", "http", "https"}
+        if "://" in path_or_uri:
+            scheme = path_or_uri.split("://", maxsplit=1)[0].lower()
+            return scheme in schemes
+        if len(path_or_uri) >= WINDOWS_PATH_MIN_LENGTH and path_or_uri[1:3] == ":\\":
+            return True
+        return bool(path_or_uri.startswith("/"))
+
+    @staticmethod
+    def _detect_format(uri: str) -> str:
+        """Detect file format from URI extension."""
+        parsed = urlparse(uri)
+        path = Path(parsed.path)
+        extension = path.suffix.lower().lstrip(".")
+
+        format_map = {
+            "csv": "csv",
+            "tsv": "csv",
+            "txt": "csv",
+            "parquet": "parquet",
+            "pq": "parquet",
+            "json": "json",
+            "jsonl": "jsonl",
+            "ndjson": "jsonl",
+        }
+
+        return format_map.get(extension, "csv")
+
+    def _resolve_backend_and_path(self, uri: str) -> "tuple[ObjectStoreProtocol, str]":
+        """Resolve backend and path from URI with Phase 3 URI-first routing.
+
+        Args:
+            uri: URI to resolve (e.g., "s3://bucket/path", "file:///local/path")
+
+        Returns:
+            Tuple of (backend, path) where path is relative to the backend's base path
+        """
+        # Convert Path objects to string
+        uri = str(uri)
+        original_path = uri
+
+        # Convert absolute paths to file:// URIs if needed
+        if self._is_uri(uri) and "://" not in uri:
+            # It's an absolute path without scheme
+            uri = f"file://{uri}"
+
+        backend = self._get_storage_backend(uri)
+
+        # For file:// URIs, return just the path part for the backend
+        path = uri[7:] if uri.startswith("file://") else original_path
+
+        return backend, path
+
+    @staticmethod
+    def _rows_to_arrow_table(rows: "list[RowT]", columns: "list[str]") -> ArrowTable:
+        """Convert rows to Arrow table."""
+        import pyarrow as pa
+
+        if not rows:
+            # Empty table with column names
+            # Create empty arrays for each column
+            empty_data: dict[str, list[Any]] = {col: [] for col in columns}
+            return pa.table(empty_data)
+
+        # Convert rows to columnar format
+        if isinstance(rows[0], dict):
+            # Dict rows
+            data = {col: [cast("dict[str, Any]", row).get(col) for row in rows] for col in columns}
+        else:
+            # Tuple/list rows
+            data = {col: [cast("tuple[Any, ...]", row)[i] for row in rows] for i, col in enumerate(columns)}
+
+        return pa.table(data)
+
+
+class SyncStorageMixin(StorageMixinBase):
+    """Unified storage operations for synchronous drivers."""
+
+    __slots__ = ()
+
+    def ingest_arrow_table(self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any) -> int:
+        """Ingest an Arrow table into the database.
+
+        This public method provides a consistent entry point and can be used for
+        instrumentation, logging, etc., while delegating the actual work to the
+        driver-specific `_ingest_arrow_table` implementation.
+        """
+        return self._ingest_arrow_table(table, table_name, mode, **options)
+
+    def _ingest_arrow_table(self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any) -> int:
+        """Generic fallback for ingesting an Arrow table.
+
+        This implementation writes the Arrow table to a temporary Parquet file
+        and then uses the driver's generic `_bulk_load_file` capability.
+        Drivers with more efficient, native Arrow ingestion methods should override this.
+        """
+        import pyarrow.parquet as pq
+
+        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
+            tmp_path = Path(tmp.name)
+            pq.write_table(table, tmp_path)  # pyright: ignore
+
+        try:
+            # Use database's bulk load capabilities for Parquet
+            return self._bulk_load_file(tmp_path, table_name, "parquet", mode, **options)
+        finally:
+            tmp_path.unlink(missing_ok=True)
+
+    # ============================================================================
+    # Core Arrow Operations
+    # ============================================================================
+
+    def fetch_arrow_table(
+        self,
+        statement: "Statement",
+        /,
+        *parameters: "Union[StatementParameters, StatementFilter]",
+        _connection: "Optional[ConnectionT]" = None,
+        _config: "Optional[SQLConfig]" = None,
+        **kwargs: Any,
+    ) -> "ArrowResult":
+        """Fetch query results as Arrow table with intelligent routing.
+
+        Args:
+            statement: SQL statement (string, SQL object, or sqlglot Expression)
+            *parameters: Mixed parameters and filters
+            _connection: Optional connection override
+            _config: Optional SQL config override
+            **kwargs: Additional options
+
+        Returns:
+            ArrowResult wrapping the Arrow table
+        """
+        self._ensure_pyarrow_installed()
+
+        filters, params = _separate_filters_from_parameters(parameters)
+        # Convert to SQL object for processing
+        # Use a custom config if transformations will add parameters
+        if _config is None:
+            _config = self.config
+
+        # If no parameters provided but we have transformations enabled,
+        # disable parameter validation entirely to allow transformer-added parameters
+        if params is None and _config and _config.enable_transformations:
+            # Disable validation entirely for transformer-generated parameters
+            _config = replace(_config, strict_mode=False, enable_validation=False)
+
+        # Only pass params if it's not None to avoid adding None as a parameter
+        if params is not None:
+            sql = SQL(statement, params, *filters, _config=_config, _dialect=self.dialect, **kwargs)
+        else:
+            sql = SQL(statement, *filters, _config=_config, _dialect=self.dialect, **kwargs)
+
+        return self._fetch_arrow_table(sql, connection=_connection, **kwargs)
+
+    def _fetch_arrow_table(self, sql: SQL, connection: "Optional[ConnectionT]" = None, **kwargs: Any) -> "ArrowResult":
+        """Generic fallback for Arrow table fetching.
+
+        This method executes a regular query and converts the results to Arrow format.
+        Drivers can call this method when they don't have native Arrow support.
+
+        Args:
+            sql: SQL object to execute
+            connection: Optional connection override
+            **kwargs: Additional options (unused in fallback)
+
+        Returns:
+            ArrowResult with converted data
+        """
+        # Check if this SQL object has validation issues due to transformer-generated parameters
+        try:
+            result = cast("SQLResult", self.execute(sql, _connection=connection))  # type: ignore[attr-defined]
+        except Exception:
+            # Get the compiled SQL and parameters
+            compiled_sql, compiled_params = sql.compile("qmark")
+
+            # Execute directly via the driver's _execute method
+            driver_result = self._execute(compiled_sql, compiled_params, sql, connection=connection)  # type: ignore[attr-defined]
+
+            # Wrap the result as a SQLResult
+            if "data" in driver_result:
+                # It's a SELECT result
+                result = self._wrap_select_result(sql, driver_result)  # type: ignore[attr-defined]
+            else:
+                # It's a DML result
+                result = self._wrap_execute_result(sql, driver_result)  # type: ignore[attr-defined]
+
+        data = result.data or []
+        columns = result.column_names or []
+        arrow_table = self._rows_to_arrow_table(data, columns)
+        return ArrowResult(statement=sql, data=arrow_table)
+
+    # ============================================================================
+    # Storage Integration Operations
+    # ============================================================================
+
+    def export_to_storage(
+        self,
+        statement: "Statement",
+        /,
+        *parameters: "Union[StatementParameters, StatementFilter]",
+        destination_uri: str,
+        format: "Optional[str]" = None,
+        _connection: "Optional[ConnectionT]" = None,
+        _config: "Optional[SQLConfig]" = None,
+        **options: Any,
+    ) -> int:
+        """Export query results to storage with intelligent routing.
+
+        Provides instrumentation and delegates to _export_to_storage() for consistent operation.
+
+        Args:
+            statement: SQL query to execute and export
+            *parameters: Mixed parameters and filters
+            destination_uri: URI to export data to
+            format: Optional format override (auto-detected from URI if not provided)
+            _connection: Optional connection override
+            _config: Optional SQL config override
+            **options: Additional export options AND named parameters for query
+
+        Returns:
+            Number of rows exported
+        """
+        # Create SQL object with proper parameter handling
+        filters, params = _separate_filters_from_parameters(parameters)
+
+        # For storage operations, disable transformations that might add unwanted parameters
+        if _config is None:
+            _config = self.config
+        if _config and _config.enable_transformations:
+            from dataclasses import replace
+
+            _config = replace(_config, enable_transformations=False)
+
+        if params is not None:
+            sql = SQL(statement, params, *filters, _config=_config, _dialect=self.dialect)
+        else:
+            sql = SQL(statement, *filters, _config=_config, _dialect=self.dialect)
+
+        return self._export_to_storage(
+            sql, destination_uri=destination_uri, format=format, _connection=_connection, **options
+        )
+
+    def _export_to_storage(
+        self,
+        statement: "Statement",
+        /,
+        *parameters: "Union[StatementParameters, StatementFilter]",
+        destination_uri: str,
+        format: "Optional[str]" = None,
+        _connection: "Optional[ConnectionT]" = None,
+        _config: "Optional[SQLConfig]" = None,
+        **kwargs: Any,
+    ) -> int:
+        # Convert query to string for format detection
+        if hasattr(statement, "to_sql"):  # SQL object
+            query_str = cast("SQL", statement).to_sql()
+        elif isinstance(statement, str):
+            query_str = statement
+        else:  # sqlglot Expression
+            query_str = str(statement)
+
+        # Auto-detect format if not provided
+        # If no format is specified and detection fails (returns "csv" as default),
+        # default to "parquet" for export operations as it's the most common use case
+        detected_format = self._detect_format(destination_uri)
+        if format:
+            file_format = format
+        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
+            # Detection returned default "csv" but file doesn't actually have CSV extension
+            # Default to parquet for better compatibility with tests and common usage
+            file_format = "parquet"
+        else:
+            file_format = detected_format
+
+        # Special handling for parquet format - if we're exporting to parquet but the
+        # destination doesn't have .parquet extension, add it to ensure compatibility
+        # with pyarrow.parquet.read_table() which requires the extension
+        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
+            destination_uri = f"{destination_uri}.parquet"
+
+        # Use storage backend - resolve AFTER modifying destination_uri
+        backend, path = self._resolve_backend_and_path(destination_uri)
+
+        # Try native database export first
+        if file_format == "parquet" and self.supports_native_parquet_export:
+            # If we have a SQL object with parameters, compile it first
+            if hasattr(statement, "compile") and hasattr(statement, "parameters") and statement.parameters:
+                _compiled_sql, _compiled_params = statement.compile(placeholder_style=self.default_parameter_style)  # type: ignore[attr-defined]
+            else:
+                try:
+                    return self._export_native(query_str, destination_uri, file_format, **kwargs)
+                except NotImplementedError:
+                    # Fall through to use storage backend
+                    pass
+
+        if file_format == "parquet":
+            # Use Arrow for efficient transfer - if statement is already a SQL object, use it directly
+            if hasattr(statement, "compile"):  # It's already a SQL object from export_to_storage
+                # For parquet export via Arrow, just use the SQL object directly
+                sql_obj = cast("SQL", statement)
+                # Pass connection parameter correctly
+                arrow_result = self._fetch_arrow_table(sql_obj, connection=_connection, **kwargs)
+            else:
+                # Create SQL object if it's still a string
+                arrow_result = self.fetch_arrow_table(statement, *parameters, _connection=_connection, _config=_config)
+
+            # ArrowResult.data is never None according to the type definition
+            arrow_table = arrow_result.data
+            num_rows = arrow_table.num_rows
+            backend.write_arrow(path, arrow_table, **kwargs)
+            return num_rows
+        # Pass the SQL object if available, otherwise create one
+        if isinstance(statement, str):
+            sql_obj = SQL(statement, _config=_config, _dialect=self.dialect)
+        else:
+            sql_obj = cast("SQL", statement)
+        return self._export_via_backend(sql_obj, backend, path, file_format, **kwargs)
+
+    def import_from_storage(
+        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+    ) -> int:
+        """Import data from storage with intelligent routing.
+
+        Provides instrumentation and delegates to _import_from_storage() for consistent operation.
+
+        Args:
+            source_uri: URI to import data from
+            table_name: Target table name
+            format: Optional format override (auto-detected from URI if not provided)
+            mode: Import mode ('create', 'append', 'replace')
+            **options: Additional import options
+
+        Returns:
+            Number of rows imported
+        """
+        return self._import_from_storage(source_uri, table_name, format, mode, **options)
+
+    def _import_from_storage(
+        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+    ) -> int:
+        """Protected method for import operation implementation.
+
+        Args:
+            source_uri: URI to import data from
+            table_name: Target table name
+            format: Optional format override (auto-detected from URI if not provided)
+            mode: Import mode ('create', 'append', 'replace')
+            **options: Additional import options
+
+        Returns:
+            Number of rows imported
+        """
+        # Auto-detect format if not provided
+        file_format = format or self._detect_format(source_uri)
+
+        # Try native database import first
+        if file_format == "parquet" and self.supports_native_parquet_import:
+            return self._import_native(source_uri, table_name, file_format, mode, **options)
+
+        # Use storage backend
+        backend, path = self._resolve_backend_and_path(source_uri)
+
+        if file_format == "parquet":
+            try:
+                # Use Arrow for efficient transfer
+                arrow_table = backend.read_arrow(path, **options)
+                return self.ingest_arrow_table(arrow_table, table_name, mode=mode)
+            except AttributeError:
+                pass
+
+        # Use traditional import through temporary file
+        return self._import_via_backend(backend, path, table_name, file_format, mode, **options)
+
+    # ============================================================================
+    # Database-Specific Implementation Hooks
+    # ============================================================================
+
+    def _read_parquet_native(
+        self, source_uri: str, columns: "Optional[list[str]]" = None, **options: Any
+    ) -> "SQLResult":
+        """Database-specific native Parquet reading. Override in drivers."""
+        msg = "Driver should implement _read_parquet_native"
+        raise NotImplementedError(msg)
+
+    def _write_parquet_native(self, data: Union[str, ArrowTable], destination_uri: str, **options: Any) -> None:
+        """Database-specific native Parquet writing. Override in drivers."""
+        msg = "Driver should implement _write_parquet_native"
+        raise NotImplementedError(msg)
+
+    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
+        """Database-specific native export. Override in drivers."""
+        msg = "Driver should implement _export_native"
+        raise NotImplementedError(msg)
+
+    def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
+        """Database-specific native import. Override in drivers."""
+        msg = "Driver should implement _import_native"
+        raise NotImplementedError(msg)
+
+    def _export_via_backend(
+        self, sql_obj: "SQL", backend: "ObjectStoreProtocol", path: str, format: str, **options: Any
+    ) -> int:
+        """Export via storage backend using temporary file."""
+
+        # Execute query and get results - use the SQL object directly
+        try:
+            result = cast("SQLResult", self.execute(sql_obj))  # type: ignore[attr-defined]
+        except Exception:
+            # Fall back to direct execution
+            compiled_sql, compiled_params = sql_obj.compile("qmark")
+            driver_result = self._execute(compiled_sql, compiled_params, sql_obj)  # type: ignore[attr-defined]
+            if "data" in driver_result:
+                result = self._wrap_select_result(sql_obj, driver_result)  # type: ignore[attr-defined]
+            else:
+                result = self._wrap_execute_result(sql_obj, driver_result)  # type: ignore[attr-defined]
+
+        # For parquet format, convert through Arrow
+        if format == "parquet":
+            arrow_table = self._rows_to_arrow_table(result.data or [], result.column_names or [])
+            backend.write_arrow(path, arrow_table, **options)
+            return len(result.data or [])
+
+        # Convert to appropriate format and write to backend
+        compression = options.get("compression")
+
+        # Create temp file with appropriate suffix
+        suffix = f".{format}"
+        if compression == "gzip":
+            suffix += ".gz"
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding="utf-8") as tmp:
+            tmp_path = Path(tmp.name)
+
+        # Handle compression and writing
+        if compression == "gzip":
+            import gzip
+
+            with gzip.open(tmp_path, "wt", encoding="utf-8") as file_to_write:
+                if format == "csv":
+                    self._write_csv(result, file_to_write, **options)
+                elif format == "json":
+                    self._write_json(result, file_to_write, **options)
+                else:
+                    msg = f"Unsupported format for backend export: {format}"
+                    raise ValueError(msg)
+        else:
+            with tmp_path.open("w", encoding="utf-8") as file_to_write:
+                if format == "csv":
+                    self._write_csv(result, file_to_write, **options)
+                elif format == "json":
+                    self._write_json(result, file_to_write, **options)
+                else:
+                    msg = f"Unsupported format for backend export: {format}"
+                    raise ValueError(msg)
+
+        try:
+            # Upload to storage backend
+            # Adjust path if compression was used
+            final_path = path
+            if compression == "gzip" and not path.endswith(".gz"):
+                final_path = path + ".gz"
+
+            backend.write_bytes(final_path, tmp_path.read_bytes())
+            return result.rows_affected or len(result.data or [])
+        finally:
+            tmp_path.unlink(missing_ok=True)
+
+    def _import_via_backend(
+        self, backend: "ObjectStoreProtocol", path: str, table_name: str, format: str, mode: str, **options: Any
+    ) -> int:
+        """Import via storage backend using temporary file."""
+        # Download from storage backend
+        data = backend.read_bytes(path)
+
+        with tempfile.NamedTemporaryFile(mode="wb", suffix=f".{format}", delete=False) as tmp:
+            tmp.write(data)
+            tmp_path = Path(tmp.name)
+
+        try:
+            # Use database's bulk load capabilities
+            return self._bulk_load_file(tmp_path, table_name, format, mode, **options)
+        finally:
+            tmp_path.unlink(missing_ok=True)
+
+    @staticmethod
+    def _write_csv(result: "SQLResult", file: Any, **options: Any) -> None:
+        """Write result to CSV file."""
+        # Remove options that csv.writer doesn't understand
+        csv_options = options.copy()
+        csv_options.pop("compression", None)  # Handle compression separately
+        csv_options.pop("partition_by", None)  # Not applicable to CSV
+
+        writer = csv.writer(file, **csv_options)  # TODO: anything better?
+        if result.column_names:
+            writer.writerow(result.column_names)
+        if result.data:
+            # Handle dict rows by extracting values in column order
+            if result.data and isinstance(result.data[0], dict):
+                rows = []
+                for row_dict in result.data:
+                    # Extract values in the same order as column_names
+                    row_values = [row_dict.get(col) for col in result.column_names or []]
+                    rows.append(row_values)
+                writer.writerows(rows)
+            else:
+                writer.writerows(result.data)
+
+    @staticmethod
+    def _write_json(result: "SQLResult", file: Any, **options: Any) -> None:
+        """Write result to JSON file."""
+
+        if result.data and result.column_names:
+            # Check if data is already in dict format
+            if result.data and isinstance(result.data[0], dict):
+                # Data is already dictionaries, use as-is
+                rows = result.data
+            else:
+                # Convert tuples/lists to list of dicts
+                rows = [dict(zip(result.column_names, row)) for row in result.data]
+            json.dump(rows, file, **options)  # TODO: use sqlspec.utils.serializer
+        else:
+            json.dump([], file)  # TODO: use sqlspec.utils.serializer
+
+    def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
+        """Database-specific bulk load implementation. Override in drivers."""
+        msg = "Driver should implement _bulk_load_file"
+        raise NotImplementedError(msg)
+
+
+class AsyncStorageMixin(StorageMixinBase):
+    """Unified storage operations for asynchronous drivers."""
+
+    __slots__ = ()
+
+    async def ingest_arrow_table(
+        self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any
+    ) -> int:
+        """Ingest an Arrow table into the database asynchronously.
+
+        This public method provides a consistent entry point and can be used for
+        instrumentation, logging, etc., while delegating the actual work to the
+        driver-specific `_ingest_arrow_table` implementation.
+        """
+        self._ensure_pyarrow_installed()
+        return await self._ingest_arrow_table(table, table_name, mode, **options)
+
+    async def _ingest_arrow_table(
+        self, table: "ArrowTable", table_name: str, mode: str = "create", **options: Any
+    ) -> int:
+        """Generic async fallback for ingesting an Arrow table.
+
+        This implementation writes the Arrow table to a temporary Parquet file
+        and then uses the driver's generic `_bulk_load_file` capability.
+        Drivers with more efficient, native Arrow ingestion methods should override this.
+        """
+        import pyarrow.parquet as pq
+
+        # Use an async-friendly way to handle the temporary file if possible,
+        # but for simplicity, standard tempfile is acceptable here as it's a fallback.
+        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp:
+            tmp_path = Path(tmp.name)
+            await async_(pq.write_table)(table, tmp_path)  # pyright: ignore
+
+        try:
+            # Use database's async bulk load capabilities for Parquet
+            return await self._bulk_load_file(tmp_path, table_name, "parquet", mode, **options)
+        finally:
+            tmp_path.unlink(missing_ok=True)
+
+    # ============================================================================
+    # Core Arrow Operations (Async)
+    # ============================================================================
+
+    async def fetch_arrow_table(
+        self,
+        statement: "Statement",
+        /,
+        *parameters: "Union[StatementParameters, StatementFilter]",
+        _connection: "Optional[ConnectionT]" = None,
+        _config: "Optional[SQLConfig]" = None,
+        **kwargs: Any,
+    ) -> "ArrowResult":
+        """Async fetch query results as Arrow table with intelligent routing.
+
+        Args:
+            statement: SQL statement (string, SQL object, or sqlglot Expression)
+            *parameters: Mixed parameters and filters
+            _connection: Optional connection override
+            _config: Optional SQL config override
+            **kwargs: Additional options
+
+        Returns:
+            ArrowResult wrapping the Arrow table
+        """
+        self._ensure_pyarrow_installed()
+
+        filters, params = _separate_filters_from_parameters(parameters)
+        # Convert to SQL object for processing
+        # Use a custom config if transformations will add parameters
+        if _config is None:
+            _config = self.config
+
+        # If no parameters provided but we have transformations enabled,
+        # disable parameter validation entirely to allow transformer-added parameters
+        if params is None and _config and _config.enable_transformations:
+            from dataclasses import replace
+
+            # Disable validation entirely for transformer-generated parameters
+            _config = replace(_config, strict_mode=False, enable_validation=False)
+
+        # Only pass params if it's not None to avoid adding None as a parameter
+        if params is not None:
+            sql = SQL(statement, params, *filters, _config=_config, _dialect=self.dialect, **kwargs)
+        else:
+            sql = SQL(statement, *filters, _config=_config, _dialect=self.dialect, **kwargs)
+
+        # Delegate to protected method that drivers can override
+        return await self._fetch_arrow_table(sql, connection=_connection, **kwargs)
+
+    async def _fetch_arrow_table(
+        self, sql: SQL, connection: "Optional[ConnectionT]" = None, **kwargs: Any
+    ) -> "ArrowResult":
+        """Generic async fallback for Arrow table fetching.
+
+        This method executes a regular query and converts the results to Arrow format.
+        Drivers should override this method to provide native Arrow support if available.
+        If a driver has partial native support, it can call `super()._fetch_arrow_table(...)`
+        to use this fallback implementation.
+
+        Args:
+            sql: SQL object to execute
+            connection: Optional connection override
+            **kwargs: Additional options (unused in fallback)
+
+        Returns:
+            ArrowResult with converted data
+        """
+        # Execute regular query
+        result = await self.execute(sql, _connection=connection)  # type: ignore[attr-defined]
+
+        # Convert to Arrow table
+        arrow_table = self._rows_to_arrow_table(result.data or [], result.column_names or [])
+
+        return ArrowResult(statement=sql, data=arrow_table)
+
+    async def export_to_storage(
+        self,
+        statement: "Statement",
+        /,
+        *parameters: "Union[StatementParameters, StatementFilter]",
+        destination_uri: str,
+        format: "Optional[str]" = None,
+        _connection: "Optional[ConnectionT]" = None,
+        _config: "Optional[SQLConfig]" = None,
+        **options: Any,
+    ) -> int:
+        # Create SQL object with proper parameter handling
+        filters, params = _separate_filters_from_parameters(parameters)
+
+        # For storage operations, disable transformations that might add unwanted parameters
+        if _config is None:
+            _config = self.config
+        if _config and _config.enable_transformations:
+            from dataclasses import replace
+
+            _config = replace(_config, enable_transformations=False)
+
+        if params is not None:
+            sql = SQL(statement, params, *filters, _config=_config, _dialect=self.dialect, **options)
+        else:
+            sql = SQL(statement, *filters, _config=_config, _dialect=self.dialect, **options)
+
+        return await self._export_to_storage(sql, destination_uri, format, connection=_connection, **options)
+
+    async def _export_to_storage(
+        self,
+        query: "SQL",
+        destination_uri: str,
+        format: "Optional[str]" = None,
+        connection: "Optional[ConnectionT]" = None,
+        **options: Any,
+    ) -> int:
+        """Protected async method for export operation implementation.
+
+        Args:
+            query: SQL query to execute and export
+            destination_uri: URI to export data to
+            format: Optional format override (auto-detected from URI if not provided)
+            connection: Optional connection override
+            **options: Additional export options
+
+        Returns:
+            Number of rows exported
+        """
+        # Auto-detect format if not provided
+        # If no format is specified and detection fails (returns "csv" as default),
+        # default to "parquet" for export operations as it's the most common use case
+        detected_format = self._detect_format(destination_uri)
+        if format:
+            file_format = format
+        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
+            # Detection returned default "csv" but file doesn't actually have CSV extension
+            # Default to parquet for better compatibility with tests and common usage
+            file_format = "parquet"
+        else:
+            file_format = detected_format
+
+        # Special handling for parquet format - if we're exporting to parquet but the
+        # destination doesn't have .parquet extension, add it to ensure compatibility
+        # with pyarrow.parquet.read_table() which requires the extension
+        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
+            destination_uri = f"{destination_uri}.parquet"
+
+        # Use storage backend - resolve AFTER modifying destination_uri
+        backend, path = self._resolve_backend_and_path(destination_uri)
+
+        # Try native database export first
+        if file_format == "parquet" and self.supports_native_parquet_export:
+            return await self._export_native(query.as_script().sql, destination_uri, file_format, **options)
+
+        if file_format == "parquet":
+            # For parquet export via Arrow, we need to ensure no unwanted parameter transformations
+            # If the query already has parameters from transformations, create a fresh SQL object
+            if hasattr(query, "parameters") and query.parameters and hasattr(query, "_raw_sql"):
+                # Create fresh SQL object from raw SQL without transformations
+                fresh_sql = SQL(
+                    query._raw_sql,
+                    _config=replace(self.config, enable_transformations=False)
+                    if self.config
+                    else SQLConfig(enable_transformations=False),
+                    _dialect=self.dialect,
+                )
+                arrow_result = await self._fetch_arrow_table(fresh_sql, connection=connection, **options)
+            else:
+                # query is already a SQL object, call _fetch_arrow_table directly
+                arrow_result = await self._fetch_arrow_table(query, connection=connection, **options)
+            arrow_table = arrow_result.data
+            if arrow_table is not None:
+                await backend.write_arrow_async(path, arrow_table, **options)
+                return arrow_table.num_rows
+            return 0
+
+        return await self._export_via_backend(query, backend, path, file_format, **options)
+
+    async def import_from_storage(
+        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+    ) -> int:
+        """Async import data from storage with intelligent routing.
+
+        Provides instrumentation and delegates to _import_from_storage() for consistent operation.
+
+        Args:
+            source_uri: URI to import data from
+            table_name: Target table name
+            format: Optional format override (auto-detected from URI if not provided)
+            mode: Import mode ('create', 'append', 'replace')
+            **options: Additional import options
+
+        Returns:
+            Number of rows imported
+        """
+        return await self._import_from_storage(source_uri, table_name, format, mode, **options)
+
+    async def _import_from_storage(
+        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+    ) -> int:
+        """Protected async method for import operation implementation.
+
+        Args:
+            source_uri: URI to import data from
+            table_name: Target table name
+            format: Optional format override (auto-detected from URI if not provided)
+            mode: Import mode ('create', 'append', 'replace')
+            **options: Additional import options
+
+        Returns:
+            Number of rows imported
+        """
+        file_format = format or self._detect_format(source_uri)
+        backend, path = self._resolve_backend_and_path(source_uri)
+
+        if file_format == "parquet":
+            arrow_table = await backend.read_arrow_async(path, **options)
+            return await self.ingest_arrow_table(arrow_table, table_name, mode=mode)
+
+        return await self._import_via_backend(backend, path, table_name, file_format, mode, **options)
+
+    # ============================================================================
+    # Async Database-Specific Implementation Hooks
+    # ============================================================================
+
+    async def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
+        """Async database-specific native export."""
+        msg = "Driver should implement _export_native"
+        raise NotImplementedError(msg)
+
+    async def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
+        """Async database-specific native import."""
+        msg = "Driver should implement _import_native"
+        raise NotImplementedError(msg)
+
+    async def _export_via_backend(
+        self, sql_obj: "SQL", backend: "ObjectStoreProtocol", path: str, format: str, **options: Any
+    ) -> int:
+        """Async export via storage backend."""
+
+        # Execute query and get results - use the SQL object directly
+        try:
+            result = await self.execute(sql_obj)  # type: ignore[attr-defined]
+        except Exception:
+            # Fall back to direct execution
+            compiled_sql, compiled_params = sql_obj.compile("qmark")
+            driver_result = await self._execute(compiled_sql, compiled_params, sql_obj)  # type: ignore[attr-defined]
+            if "data" in driver_result:
+                result = self._wrap_select_result(sql_obj, driver_result)  # type: ignore[attr-defined]
+            else:
+                result = self._wrap_execute_result(sql_obj, driver_result)  # type: ignore[attr-defined]
+
+        # For parquet format, convert through Arrow
+        if format == "parquet":
+            arrow_table = self._rows_to_arrow_table(result.data or [], result.column_names or [])
+            await backend.write_arrow_async(path, arrow_table, **options)
+            return len(result.data or [])
+
+        # Convert to appropriate format and write to backend
+        with tempfile.NamedTemporaryFile(mode="w", suffix=f".{format}", delete=False, encoding="utf-8") as tmp:
+            if format == "csv":
+                self._write_csv(result, tmp, **options)
+            elif format == "json":
+                self._write_json(result, tmp, **options)
+            else:
+                msg = f"Unsupported format for backend export: {format}"
+                raise ValueError(msg)
+
+            tmp_path = Path(tmp.name)
+
+        try:
+            # Upload to storage backend (async if supported)
+            await backend.write_bytes_async(path, tmp_path.read_bytes())
+            return result.rows_affected or len(result.data or [])
+        finally:
+            tmp_path.unlink(missing_ok=True)
+
+    async def _import_via_backend(
+        self, backend: "ObjectStoreProtocol", path: str, table_name: str, format: str, mode: str, **options: Any
+    ) -> int:
+        """Async import via storage backend."""
+        # Download from storage backend (async if supported)
+        data = await backend.read_bytes_async(path)
+
+        with tempfile.NamedTemporaryFile(mode="wb", suffix=f".{format}", delete=False) as tmp:
+            tmp.write(data)
+            tmp_path = Path(tmp.name)
+
+        try:
+            return await self._bulk_load_file(tmp_path, table_name, format, mode, **options)
+        finally:
+            tmp_path.unlink(missing_ok=True)
+
+    @staticmethod
+    def _write_csv(result: "SQLResult", file: Any, **options: Any) -> None:
+        """Reuse sync implementation."""
+
+        writer = csv.writer(file, **options)
+        if result.column_names:
+            writer.writerow(result.column_names)
+        if result.data:
+            # Handle dict rows by extracting values in column order
+            if result.data and isinstance(result.data[0], dict):
+                rows = []
+                for row_dict in result.data:
+                    # Extract values in the same order as column_names
+                    row_values = [row_dict.get(col) for col in result.column_names or []]
+                    rows.append(row_values)
+                writer.writerows(rows)
+            else:
+                writer.writerows(result.data)
+
+    @staticmethod
+    def _write_json(result: "SQLResult", file: Any, **options: Any) -> None:
+        """Reuse sync implementation."""
+
+        if result.data and result.column_names:
+            # Check if data is already in dict format
+            if result.data and isinstance(result.data[0], dict):
+                # Data is already dictionaries, use as-is
+                rows = result.data
+            else:
+                # Convert tuples/lists to list of dicts
+                rows = [dict(zip(result.column_names, row)) for row in result.data]
+            json.dump(rows, file, **options)
+        else:
+            json.dump([], file)
+
+    async def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
+        """Async database-specific bulk load implementation."""
+        msg = "Driver should implement _bulk_load_file"
+        raise NotImplementedError(msg)
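
For orientation, a minimal sketch of the public surface this file adds, using the signatures shown in the hunk above. The `driver` object is a stand-in for any sqlspec driver that mixes in SyncStorageMixin; how such an instance is constructed is adapter-specific and not part of this diff.

    # Hypothetical driver instance; any adapter mixing in SyncStorageMixin.
    arrow_result = driver.fetch_arrow_table("SELECT id, name FROM users WHERE active = ?", True)
    print(arrow_result.data.num_rows)

    # Export straight to object storage; format is auto-detected from the
    # URI extension (with a parquet fallback for extension-less URIs).
    rows_exported = driver.export_to_storage(
        "SELECT * FROM events",
        destination_uri="s3://bucket/exports/events.parquet",
    )

    # Round-trip: bulk-load the exported file into a new table.
    rows_imported = driver.import_from_storage(
        "s3://bucket/exports/events.parquet",
        "events_copy",
        mode="create",
    )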
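
The routing helpers in StorageMixinBase are small enough to characterize directly. The expected values below are read off the format_map and scheme set in the hunk; they are illustrative assertions, not tests shipped with the package.

    from sqlspec.driver.mixins._storage import StorageMixinBase

    # Format detection keys off the URI's file extension, defaulting to "csv".
    assert StorageMixinBase._detect_format("s3://bucket/data.ndjson") == "jsonl"
    assert StorageMixinBase._detect_format("/tmp/report.pq") == "parquet"
    assert StorageMixinBase._detect_format("gs://bucket/dump") == "csv"  # no extension -> default

    # URI detection accepts known schemes, Windows drive paths, and absolute POSIX paths.
    assert StorageMixinBase._is_uri("s3://bucket/key")
    assert StorageMixinBase._is_uri("C:\\data\\file.csv")
    assert StorageMixinBase._is_uri("/var/data/file.csv")
    assert not StorageMixinBase._is_uri("relative/file.csv")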
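
Drivers opt into the native fast paths by flipping the supports_native_* class flags and overriding the corresponding hook. A rough sketch, assuming a DuckDB-like engine whose COPY statement returns a row count; the class name and connection details here are invented for illustration and are not part of the released adapters:

    from typing import Any

    from sqlspec.driver.mixins._storage import SyncStorageMixin

    class ExampleDriver(SyncStorageMixin):  # hypothetical driver
        supports_native_parquet_export = True
        supports_native_parquet_import = False

        def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
            # DuckDB-style COPY; engines without an equivalent raise
            # NotImplementedError and fall back to the storage backend path.
            cursor = self._connection.execute(f"COPY ({query}) TO '{destination_uri}' (FORMAT {format.upper()})")
            row = cursor.fetchone()
            return int(row[0]) if row else 0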