sqlspec 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of sqlspec might be problematic.
- sqlspec/adapters/aiosqlite/driver.py +16 -11
- sqlspec/adapters/bigquery/driver.py +113 -21
- sqlspec/adapters/duckdb/driver.py +18 -13
- sqlspec/adapters/psycopg/config.py +55 -54
- sqlspec/adapters/psycopg/driver.py +82 -1
- sqlspec/adapters/sqlite/driver.py +50 -10
- sqlspec/driver/mixins/_storage.py +83 -36
- sqlspec/loader.py +8 -30
- sqlspec/statement/builder/base.py +3 -1
- sqlspec/statement/builder/ddl.py +14 -1
- sqlspec/statement/pipelines/analyzers/_analyzer.py +1 -5
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +56 -2
- sqlspec/statement/sql.py +40 -6
- sqlspec/storage/backends/fsspec.py +29 -27
- sqlspec/storage/backends/obstore.py +55 -34
- sqlspec/storage/protocol.py +28 -25
- {sqlspec-0.12.0.dist-info → sqlspec-0.12.2.dist-info}/METADATA +1 -1
- {sqlspec-0.12.0.dist-info → sqlspec-0.12.2.dist-info}/RECORD +21 -21
- {sqlspec-0.12.0.dist-info → sqlspec-0.12.2.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.0.dist-info → sqlspec-0.12.2.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.0.dist-info → sqlspec-0.12.2.dist-info}/licenses/NOTICE +0 -0
sqlspec/adapters/sqlite/driver.py
CHANGED

@@ -197,8 +197,41 @@ class SqliteDriver(
         result: ScriptResultDict = {"statements_executed": -1, "status_message": "SCRIPT EXECUTED"}
         return result

+    def _ingest_arrow_table(self, table: Any, table_name: str, mode: str = "create", **options: Any) -> int:
+        """SQLite-specific Arrow table ingestion using CSV conversion.
+
+        Since SQLite only supports CSV bulk loading, we convert the Arrow table
+        to CSV format first using the storage backend for efficient operations.
+        """
+        import io
+        import tempfile
+
+        import pyarrow.csv as pa_csv
+
+        # Convert Arrow table to CSV in memory
+        csv_buffer = io.BytesIO()
+        pa_csv.write_csv(table, csv_buffer)
+        csv_content = csv_buffer.getvalue()
+
+        # Create a temporary file path
+        temp_filename = f"sqlspec_temp_{table_name}_{id(self)}.csv"
+        temp_path = Path(tempfile.gettempdir()) / temp_filename
+
+        # Use storage backend to write the CSV content
+        backend = self._get_storage_backend(temp_path)
+        backend.write_bytes(str(temp_path), csv_content)
+
+        try:
+            # Use SQLite's CSV bulk load
+            return self._bulk_load_file(temp_path, table_name, "csv", mode, **options)
+        finally:
+            # Clean up using storage backend
+            with contextlib.suppress(Exception):
+                # Best effort cleanup
+                backend.delete(str(temp_path))
+
     def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
-        """Database-specific bulk load implementation."""
+        """Database-specific bulk load implementation using storage backend."""
         if format != "csv":
             msg = f"SQLite driver only supports CSV for bulk loading, not {format}."
             raise NotImplementedError(msg)
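For context, the Arrow-to-CSV conversion at the top of `_ingest_arrow_table` can be exercised on its own. A minimal sketch with made-up table contents (only `pyarrow` is assumed):

    import io

    import pyarrow as pa
    import pyarrow.csv as pa_csv

    table = pa.table({"id": [1, 2], "name": ["a", "b"]})
    csv_buffer = io.BytesIO()
    pa_csv.write_csv(table, csv_buffer)  # header row first, then one line per row
    print(csv_buffer.getvalue().decode())

The header row is what lets `_bulk_load_file` below derive the placeholder count before inserting the data rows.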
@@ -208,16 +241,23 @@ class SqliteDriver(
         if mode == "replace":
             cursor.execute(f"DELETE FROM {table_name}")

-        ...
+        # Use storage backend to read the file
+        backend = self._get_storage_backend(file_path)
+        content = backend.read_text(str(file_path), encoding="utf-8")
+
+        # Parse CSV content
+        import io
+
+        csv_file = io.StringIO(content)
+        reader = csv.reader(csv_file, **options)
+        header = next(reader)  # Skip header
+        placeholders = ", ".join("?" for _ in header)
+        sql = f"INSERT INTO {table_name} VALUES ({placeholders})"

-        ...
+        # executemany is efficient for bulk inserts
+        data_iter = list(reader)  # Read all data into memory
+        cursor.executemany(sql, data_iter)
+        return cursor.rowcount

     def _wrap_select_result(
         self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
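The insert path reduces to a standard `csv.reader` plus `executemany` pattern. A self-contained sketch against a throwaway in-memory database (table and column names are invented):

    import csv
    import io
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE demo (id INTEGER, name TEXT)")
    content = "id,name\n1,a\n2,b\n"

    reader = csv.reader(io.StringIO(content))
    header = next(reader)  # skip the header row
    placeholders = ", ".join("?" for _ in header)
    cursor = conn.cursor()
    cursor.executemany(f"INSERT INTO demo VALUES ({placeholders})", list(reader))
    print(cursor.rowcount)  # 2

Note that `list(reader)` materializes the whole file in memory, the same trade-off the driver makes.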
sqlspec/driver/mixins/_storage.py
CHANGED

@@ -85,25 +85,30 @@ class StorageMixinBase(ABC):
             raise MissingDependencyError(msg)

     @staticmethod
-    def _get_storage_backend(uri_or_key: str) -> "ObjectStoreProtocol":
+    def _get_storage_backend(uri_or_key: "Union[str, Path]") -> "ObjectStoreProtocol":
         """Get storage backend by URI or key with intelligent routing."""
-        return storage_registry.get(uri_or_key)
+        # Pass Path objects directly to storage registry for proper URI conversion
+        if isinstance(uri_or_key, Path):
+            return storage_registry.get(uri_or_key)
+        return storage_registry.get(str(uri_or_key))

     @staticmethod
-    def _is_uri(path_or_uri: str) -> bool:
+    def _is_uri(path_or_uri: "Union[str, Path]") -> bool:
         """Check if input is a URI rather than a relative path."""
+        path_str = str(path_or_uri)
         schemes = {"s3", "gs", "gcs", "az", "azure", "abfs", "abfss", "file", "http", "https"}
-        if "://" in path_or_uri:
-            scheme = path_or_uri.split("://", maxsplit=1)[0].lower()
+        if "://" in path_str:
+            scheme = path_str.split("://", maxsplit=1)[0].lower()
             return scheme in schemes
-        if len(path_or_uri) >= WINDOWS_PATH_MIN_LENGTH and path_or_uri[1:3] == ":\\":
+        if len(path_str) >= WINDOWS_PATH_MIN_LENGTH and path_str[1:3] == ":\\":
             return True
-        return bool(path_or_uri.startswith("/"))
+        return bool(path_str.startswith("/"))

     @staticmethod
-    def _detect_format(uri: str) -> str:
+    def _detect_format(uri: "Union[str, Path]") -> str:
         """Detect file format from URI extension."""
-        parsed = urlparse(uri)
+        uri_str = str(uri)
+        parsed = urlparse(uri_str)
         path = Path(parsed.path)
         extension = path.suffix.lower().lstrip(".")

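`_is_uri` now normalizes its input with `str()` before applying the same three checks as before. A rough standalone rendering of the heuristic (the value of `WINDOWS_PATH_MIN_LENGTH` is assumed here, not taken from the source):

    from pathlib import Path
    from typing import Union

    SCHEMES = {"s3", "gs", "gcs", "az", "azure", "abfs", "abfss", "file", "http", "https"}
    WINDOWS_PATH_MIN_LENGTH = 3  # assumed constant, enough for a drive prefix

    def is_uri(path_or_uri: Union[str, Path]) -> bool:
        path_str = str(path_or_uri)
        if "://" in path_str:
            return path_str.split("://", maxsplit=1)[0].lower() in SCHEMES
        if len(path_str) >= WINDOWS_PATH_MIN_LENGTH and path_str[1:3] == ":\\":
            return True  # Windows drive path, e.g. C:\data
        return path_str.startswith("/")  # absolute POSIX path

    assert is_uri("s3://bucket/key") and is_uri(Path("/tmp/x.csv"))
    assert not is_uri("relative/file.csv")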
@@ -120,28 +125,28 @@ class StorageMixinBase(ABC):

         return format_map.get(extension, "csv")

-    def _resolve_backend_and_path(self, uri: str) -> "tuple[ObjectStoreProtocol, str]":
+    def _resolve_backend_and_path(self, uri: "Union[str, Path]") -> "tuple[ObjectStoreProtocol, str]":
         """Resolve backend and path from URI with Phase 3 URI-first routing.

         Args:
-            uri: URI to resolve (e.g., "s3://bucket/path", "file:///local/path")
+            uri: URI to resolve (e.g., "s3://bucket/path", "file:///local/path", Path object)

         Returns:
             Tuple of (backend, path) where path is relative to the backend's base path
         """
         # Convert Path objects to string
-        uri = str(uri)
-        original_path = uri
+        uri_str = str(uri)
+        original_path = uri_str

         # Convert absolute paths to file:// URIs if needed
-        if self._is_uri(uri) and "://" not in uri:
+        if self._is_uri(uri_str) and "://" not in uri_str:
             # It's an absolute path without scheme
-            uri = f"file://{uri}"
+            uri_str = f"file://{uri_str}"

-        backend = self._get_storage_backend(uri)
+        backend = self._get_storage_backend(uri_str)

         # For file:// URIs, return just the path part for the backend
-        path = uri[7:] if uri.startswith("file://") else original_path
+        path = uri_str[7:] if uri_str.startswith("file://") else original_path

         return backend, path

@@ -293,7 +298,7 @@ class SyncStorageMixin(StorageMixinBase):
         statement: "Statement",
         /,
         *parameters: "Union[StatementParameters, StatementFilter]",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
        format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
         _config: "Optional[SQLConfig]" = None,

@@ -340,7 +345,7 @@ class SyncStorageMixin(StorageMixinBase):
         statement: "Statement",
         /,
         *parameters: "Union[StatementParameters, StatementFilter]",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
         _config: "Optional[SQLConfig]" = None,

@@ -360,7 +365,7 @@ class SyncStorageMixin(StorageMixinBase):
         detected_format = self._detect_format(destination_uri)
         if format:
             file_format = format
-        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
+        elif detected_format == "csv" and not str(destination_uri).endswith((".csv", ".tsv", ".txt")):
             # Detection returned default "csv" but file doesn't actually have CSV extension
             # Default to parquet for better compatibility with tests and common usage
             file_format = "parquet"

@@ -370,7 +375,7 @@ class SyncStorageMixin(StorageMixinBase):
         # Special handling for parquet format - if we're exporting to parquet but the
         # destination doesn't have .parquet extension, add it to ensure compatibility
         # with pyarrow.parquet.read_table() which requires the extension
-        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
+        if file_format == "parquet" and not str(destination_uri).endswith(".parquet"):
             destination_uri = f"{destination_uri}.parquet"

         # Use storage backend - resolve AFTER modifying destination_uri

@@ -412,7 +417,12 @@ class SyncStorageMixin(StorageMixinBase):
         return self._export_via_backend(sql_obj, backend, path, file_format, **kwargs)

     def import_from_storage(
-        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Import data from storage with intelligent routing.

@@ -431,7 +441,12 @@ class SyncStorageMixin(StorageMixinBase):
         return self._import_from_storage(source_uri, table_name, format, mode, **options)

     def _import_from_storage(
-        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Protected method for import operation implementation.

@@ -461,7 +476,23 @@ class SyncStorageMixin(StorageMixinBase):
                 arrow_table = backend.read_arrow(path, **options)
                 return self.ingest_arrow_table(arrow_table, table_name, mode=mode)
             except AttributeError:
-                ...
+                # Backend doesn't support read_arrow, try alternative approach
+                try:
+                    import pyarrow.parquet as pq
+
+                    # Read Parquet file directly
+                    with tempfile.NamedTemporaryFile(mode="wb", suffix=".parquet", delete=False) as tmp:
+                        tmp.write(backend.read_bytes(path))
+                        tmp_path = Path(tmp.name)
+                    try:
+                        arrow_table = pq.read_table(tmp_path)
+                        return self.ingest_arrow_table(arrow_table, table_name, mode=mode)
+                    finally:
+                        tmp_path.unlink(missing_ok=True)
+                except ImportError:
+                    # PyArrow not installed, cannot import Parquet
+                    msg = "PyArrow is required to import Parquet files. Install with: pip install pyarrow"
+                    raise ImportError(msg) from None

         # Use traditional import through temporary file
         return self._import_via_backend(backend, path, table_name, file_format, mode, **options)
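The new fallback spills the backend's raw bytes to a temporary `.parquet` file so `pyarrow.parquet` can open it. The same pattern as a standalone helper (`backend` is a stand-in for any object exposing `read_bytes`):

    import tempfile
    from pathlib import Path

    import pyarrow.parquet as pq

    def read_arrow_via_tempfile(backend, path: str):
        with tempfile.NamedTemporaryFile(mode="wb", suffix=".parquet", delete=False) as tmp:
            tmp.write(backend.read_bytes(path))
            tmp_path = Path(tmp.name)
        try:
            return pq.read_table(tmp_path)  # returns a pyarrow.Table
        finally:
            tmp_path.unlink(missing_ok=True)  # always remove the spill file

`delete=False` is what allows the file to be reopened by `pq.read_table` on platforms where an open temporary file cannot be read a second time.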
@@ -471,23 +502,27 @@ class SyncStorageMixin(StorageMixinBase):
     # ============================================================================

     def _read_parquet_native(
-        self, source_uri: str, columns: "Optional[list[str]]" = None, **options: Any
+        self, source_uri: "Union[str, Path]", columns: "Optional[list[str]]" = None, **options: Any
     ) -> "SQLResult":
         """Database-specific native Parquet reading. Override in drivers."""
         msg = "Driver should implement _read_parquet_native"
         raise NotImplementedError(msg)

-    def _write_parquet_native(self, data: Union[str, ArrowTable], destination_uri: str, **options: Any) -> None:
+    def _write_parquet_native(
+        self, data: Union[str, ArrowTable], destination_uri: "Union[str, Path]", **options: Any
+    ) -> None:
         """Database-specific native Parquet writing. Override in drivers."""
         msg = "Driver should implement _write_parquet_native"
         raise NotImplementedError(msg)

-    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
+    def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
         """Database-specific native export. Override in drivers."""
         msg = "Driver should implement _export_native"
         raise NotImplementedError(msg)

-    def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
+    def _import_native(
+        self, source_uri: "Union[str, Path]", table_name: str, format: str, mode: str, **options: Any
+    ) -> int:
         """Database-specific native import. Override in drivers."""
         msg = "Driver should implement _import_native"
         raise NotImplementedError(msg)

@@ -743,7 +778,7 @@ class AsyncStorageMixin(StorageMixinBase):
         statement: "Statement",
         /,
         *parameters: "Union[StatementParameters, StatementFilter]",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
         _config: "Optional[SQLConfig]" = None,

@@ -770,7 +805,7 @@ class AsyncStorageMixin(StorageMixinBase):
     async def _export_to_storage(
         self,
         query: "SQL",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         connection: "Optional[ConnectionT]" = None,
         **options: Any,

@@ -793,7 +828,7 @@ class AsyncStorageMixin(StorageMixinBase):
         detected_format = self._detect_format(destination_uri)
         if format:
             file_format = format
-        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
+        elif detected_format == "csv" and not str(destination_uri).endswith((".csv", ".tsv", ".txt")):
             # Detection returned default "csv" but file doesn't actually have CSV extension
             # Default to parquet for better compatibility with tests and common usage
             file_format = "parquet"

@@ -803,7 +838,7 @@ class AsyncStorageMixin(StorageMixinBase):
         # Special handling for parquet format - if we're exporting to parquet but the
         # destination doesn't have .parquet extension, add it to ensure compatibility
         # with pyarrow.parquet.read_table() which requires the extension
-        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
+        if file_format == "parquet" and not str(destination_uri).endswith(".parquet"):
             destination_uri = f"{destination_uri}.parquet"

         # Use storage backend - resolve AFTER modifying destination_uri

@@ -838,7 +873,12 @@ class AsyncStorageMixin(StorageMixinBase):
         return await self._export_via_backend(query, backend, path, file_format, **options)

     async def import_from_storage(
-        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Async import data from storage with intelligent routing.

@@ -857,7 +897,12 @@ class AsyncStorageMixin(StorageMixinBase):
         return await self._import_from_storage(source_uri, table_name, format, mode, **options)

     async def _import_from_storage(
-        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Protected async method for import operation implementation.

@@ -884,12 +929,14 @@ class AsyncStorageMixin(StorageMixinBase):
     # Async Database-Specific Implementation Hooks
     # ============================================================================

-    async def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
+    async def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
         """Async database-specific native export."""
         msg = "Driver should implement _export_native"
         raise NotImplementedError(msg)

-    async def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
+    async def _import_native(
+        self, source_uri: "Union[str, Path]", table_name: str, format: str, mode: str, **options: Any
+    ) -> int:
         """Async database-specific native import."""
         msg = "Driver should implement _import_native"
         raise NotImplementedError(msg)
sqlspec/loader.py
CHANGED
@@ -113,7 +113,7 @@ class SQLFileLoader:
         self._query_to_file: dict[str, str] = {}  # Maps query name to file path

     def _read_file_content(self, path: Union[str, Path]) -> str:
-        """Read file content using storage backend or local filesystem.
+        """Read file content using storage backend.

         Args:
             path: File path (can be local path or URI).

@@ -126,37 +126,15 @@ class SQLFileLoader:
         """
         path_str = str(path)

-        # Use storage backend for URIs (anything with a scheme)
-        if "://" in path_str:
-            try:
-                backend = self.storage_registry.get(path_str)
-                return backend.read_text(path_str, encoding=self.encoding)
-            except KeyError as e:
-                raise SQLFileNotFoundError(path_str) from e
-            except Exception as e:
-                raise SQLFileParseError(path_str, path_str, e) from e
-
-        # Handle local file paths
-        local_path = Path(path_str)
-        self._check_file_path(local_path)
-        content_bytes = self._read_file_content_bytes(local_path)
-        return content_bytes.decode(self.encoding)
-
-    @staticmethod
-    def _read_file_content_bytes(path: Path) -> bytes:
         try:
-            return path.read_bytes()
+            # Always use storage backend for consistent behavior
+            # Pass the original path object to allow storage registry to handle Path -> file:// conversion
+            backend = self.storage_registry.get(path)
+            return backend.read_text(path_str, encoding=self.encoding)
+        except KeyError as e:
+            raise SQLFileNotFoundError(path_str) from e
         except Exception as e:
-            raise SQLFileParseError(str(path), str(path), e) from e
-
-    @staticmethod
-    def _check_file_path(path: Union[str, Path]) -> None:
-        """Ensure the file exists and is a valid path."""
-        path_obj = Path(path).resolve()
-        if not path_obj.exists():
-            raise SQLFileNotFoundError(str(path_obj))
-        if not path_obj.is_file():
-            raise SQLFileParseError(str(path_obj), str(path_obj), ValueError("Path is not a file"))
+            raise SQLFileParseError(path_str, path_str, e) from e

     @staticmethod
     def _strip_leading_comments(sql_text: str) -> str:
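The practical consequence is that every input, `Path` object or URI string, now flows through the storage registry, which is expected to normalize plain paths to `file://` URIs. A hypothetical sketch of that normalization (helper name invented):

    from pathlib import Path

    def to_registry_key(path_or_uri) -> str:
        # Illustrates the Path -> file:// conversion the registry is trusted
        # to perform now that the local-file branch is gone.
        if isinstance(path_or_uri, Path):
            return path_or_uri.resolve().as_uri()
        return str(path_or_uri)

    print(to_registry_key(Path("/tmp/queries.sql")))   # file:///tmp/queries.sql
    print(to_registry_key("s3://bucket/queries.sql"))  # unchanged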
sqlspec/statement/builder/base.py
CHANGED

@@ -192,7 +192,9 @@ class QueryBuilder(ABC, Generic[RowT]):
             self._raise_sql_builder_error(msg)
         cte_select_expression = query._expression.copy()
         for p_name, p_value in query._parameters.items():
-            ...
+            # Try to preserve original parameter name, only rename if collision
+            unique_name = self._generate_unique_parameter_name(p_name)
+            self.add_parameter(p_value, unique_name)

         elif isinstance(query, str):
             try:
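The behavior described by the new comment, keep the caller's name and rename only on collision, can be illustrated with a toy renamer (the real `_generate_unique_parameter_name` is not shown in this diff, so this is an assumption about its contract):

    def generate_unique_parameter_name(existing: set, name: str) -> str:
        if name not in existing:
            return name  # preserve the original name when it is free
        i = 1
        while f"{name}_{i}" in existing:
            i += 1
        return f"{name}_{i}"

    taken = {"status"}
    print(generate_unique_parameter_name(taken, "limit"))   # limit
    print(generate_unique_parameter_name(taken, "status"))  # status_1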
sqlspec/statement/builder/ddl.py
CHANGED
@@ -769,14 +769,27 @@ class CreateTableAsSelectBuilder(DDLBuilder):
             select_expr = self._select_query.expression
             select_params = getattr(self._select_query, "parameters", None)
         elif isinstance(self._select_query, SelectBuilder):
+            # Get the expression and parameters directly
             select_expr = getattr(self._select_query, "_expression", None)
             select_params = getattr(self._select_query, "_parameters", None)
+
+            # Apply CTEs if present
+            with_ctes = getattr(self._select_query, "_with_ctes", {})
+            if with_ctes and select_expr and isinstance(select_expr, exp.Select):
+                # Apply CTEs directly to the SELECT expression using sqlglot's with_ method
+                for alias, cte in with_ctes.items():
+                    if hasattr(select_expr, "with_"):
+                        select_expr = select_expr.with_(
+                            cte.this,  # The CTE's SELECT expression
+                            as_=alias,
+                            copy=False,
+                        )
         elif isinstance(self._select_query, str):
             select_expr = exp.maybe_parse(self._select_query)
             select_params = None
         else:
             self._raise_sql_builder_error("Unsupported type for SELECT query in CTAS.")
-        if select_expr is None or not isinstance(select_expr, exp.Select):
+        if select_expr is None:
             self._raise_sql_builder_error("SELECT query must be a valid SELECT expression.")

         # Merge parameters from SELECT if present
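The fix leans on sqlglot's `with_` builder to fold the stored CTEs back into the SELECT before it is embedded in the CTAS. Canonical usage of that builder looks like this (alias and SQL text are invented):

    import sqlglot

    select_expr = sqlglot.parse_one("SELECT * FROM recent")
    cte_select = sqlglot.parse_one("SELECT id FROM events WHERE ts > '2024-01-01'")

    select_expr = select_expr.with_("recent", as_=cte_select, copy=False)
    print(select_expr.sql())
    # WITH recent AS (SELECT id FROM events WHERE ts > '2024-01-01') SELECT * FROM recent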
sqlspec/statement/pipelines/analyzers/_analyzer.py
CHANGED

@@ -324,11 +324,7 @@ class StatementAnalyzer(ProcessorProtocol):
     def _analyze_subqueries(self, expression: exp.Expression, analysis: StatementAnalysis) -> None:
         """Analyze subquery complexity and nesting depth."""
         subqueries: list[exp.Expression] = list(expression.find_all(exp.Subquery))
-        subqueries.extend(
-            query
-            for in_clause in expression.find_all(exp.In)
-            if (query := in_clause.args.get("query")) and isinstance(query, exp.Select)
-        )
+        # Workaround for EXISTS clauses: sqlglot doesn't wrap EXISTS subqueries in Subquery nodes
         subqueries.extend(
             [
                 exists_clause.this
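The workaround exists because, as the new comment notes, sqlglot exposes an EXISTS body as `Exists.this` rather than wrapping it in an `exp.Subquery` node. A quick check:

    import sqlglot
    from sqlglot import exp

    tree = sqlglot.parse_one(
        "SELECT * FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.a_id = a.id)"
    )
    print(list(tree.find_all(exp.Subquery)))           # [] -- the EXISTS body is not a Subquery
    exists_clause = tree.find(exp.Exists)
    print(isinstance(exists_clause.this, exp.Select))  # True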
sqlspec/statement/pipelines/transformers/_literal_parameterizer.py
CHANGED

@@ -34,7 +34,9 @@ class ParameterizationContext:
     in_case_when: bool = False
     in_array: bool = False
     in_in_clause: bool = False
+    in_recursive_cte: bool = False
     function_depth: int = 0
+    cte_depth: int = 0


 class ParameterizeLiterals(ProcessorProtocol):

@@ -53,6 +55,7 @@ class ParameterizeLiterals(ProcessorProtocol):
         preserve_boolean: Whether to preserve boolean literals as-is.
         preserve_numbers_in_limit: Whether to preserve numbers in LIMIT/OFFSET clauses.
         preserve_in_functions: List of function names where literals should be preserved.
+        preserve_in_recursive_cte: Whether to preserve literals in recursive CTEs (default True to avoid type inference issues).
         parameterize_arrays: Whether to parameterize array literals.
         parameterize_in_lists: Whether to parameterize IN clause lists.
         max_string_length: Maximum string length to parameterize.

@@ -68,6 +71,7 @@ class ParameterizeLiterals(ProcessorProtocol):
         preserve_boolean: bool = True,
         preserve_numbers_in_limit: bool = True,
         preserve_in_functions: Optional[list[str]] = None,
+        preserve_in_recursive_cte: bool = True,
         parameterize_arrays: bool = True,
         parameterize_in_lists: bool = True,
         max_string_length: int = DEFAULT_MAX_STRING_LENGTH,

@@ -79,7 +83,18 @@ class ParameterizeLiterals(ProcessorProtocol):
         self.preserve_null = preserve_null
         self.preserve_boolean = preserve_boolean
         self.preserve_numbers_in_limit = preserve_numbers_in_limit
-        self.preserve_in_functions = preserve_in_functions or ["COALESCE", "IFNULL", "NVL", "ISNULL"]
+        self.preserve_in_recursive_cte = preserve_in_recursive_cte
+        self.preserve_in_functions = preserve_in_functions or [
+            "COALESCE",
+            "IFNULL",
+            "NVL",
+            "ISNULL",
+            # Array functions that take dimension arguments
+            "ARRAYSIZE",  # SQLglot converts array_length to ArraySize
+            "ARRAY_UPPER",
+            "ARRAY_LOWER",
+            "ARRAY_NDIMS",
+        ]
         self.parameterize_arrays = parameterize_arrays
         self.parameterize_in_lists = parameterize_in_lists
         self.max_string_length = max_string_length
@@ -162,6 +177,17 @@ class ParameterizeLiterals(ProcessorProtocol):
             context.in_array = True
         elif isinstance(node, exp.In):
             context.in_in_clause = True
+        elif isinstance(node, exp.CTE):
+            context.cte_depth += 1
+            # Check if this CTE is recursive:
+            # 1. Parent WITH must be RECURSIVE
+            # 2. CTE must contain UNION (characteristic of recursive CTEs)
+            is_in_recursive_with = any(
+                isinstance(parent, exp.With) and parent.args.get("recursive", False)
+                for parent in reversed(context.parent_stack)
+            )
+            if is_in_recursive_with and self._contains_union(node):
+                context.in_recursive_cte = True
         else:
             if context.parent_stack:
                 context.parent_stack.pop()

@@ -176,6 +202,10 @@ class ParameterizeLiterals(ProcessorProtocol):
             context.in_array = False
         elif isinstance(node, exp.In):
             context.in_in_clause = False
+        elif isinstance(node, exp.CTE):
+            context.cte_depth -= 1
+            if context.cte_depth == 0:
+                context.in_recursive_cte = False

     def _process_literal_with_context(
         self, literal: exp.Expression, context: ParameterizationContext

@@ -206,7 +236,6 @@ class ParameterizeLiterals(ProcessorProtocol):
                 "type": type_hint,
                 "semantic_name": semantic_name,
                 "context": self._get_context_description(context),
-                # Note: We avoid calling literal.sql() for performance
             }
         )

@@ -227,6 +256,21 @@ class ParameterizeLiterals(ProcessorProtocol):
         if context.in_function_args:
             return True

+        # Preserve literals in recursive CTEs to avoid type inference issues
+        if self.preserve_in_recursive_cte and context.in_recursive_cte:
+            return True
+
+        # Check if this literal is being used as an alias value in SELECT
+        # e.g., 'computed' as process_status should be preserved
+        if hasattr(literal, "parent") and literal.parent:
+            parent = literal.parent
+            # Check if it's an Alias node and the literal is the expression (not the alias name)
+            if isinstance(parent, exp.Alias) and parent.this == literal:
+                # Check if this alias is in a SELECT clause
+                for ancestor in context.parent_stack:
+                    if isinstance(ancestor, exp.Select):
+                        return True
+
         # Check parent context more intelligently
         for parent in context.parent_stack:
             # Preserve in schema/DDL contexts
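The alias branch targets literals such as `'computed' AS process_status`, where replacing the literal with a placeholder would destroy the projected column's value. The parent-node shape it tests for, verified with sqlglot:

    import sqlglot
    from sqlglot import exp

    tree = sqlglot.parse_one("SELECT 'computed' AS process_status FROM jobs")
    literal = tree.find(exp.Literal)
    print(isinstance(literal.parent, exp.Alias))  # True
    print(literal.parent.this is literal)         # True -- the literal is the aliased value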
@@ -616,6 +660,16 @@ class ParameterizeLiterals(ProcessorProtocol):
         """
         return self._parameter_metadata.copy()

+    def _contains_union(self, cte_node: exp.CTE) -> bool:
+        """Check if a CTE contains a UNION (characteristic of recursive CTEs)."""
+
+        def has_union(node: exp.Expression) -> bool:
+            if isinstance(node, exp.Union):
+                return True
+            return any(has_union(child) for child in node.iter_expressions())
+
+        return cte_node.this and has_union(cte_node.this)
+
     def clear_parameters(self) -> None:
         """Clear the extracted parameters list."""
         self.extracted_parameters = []
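As to why recursive CTEs get this treatment: engines such as PostgreSQL infer the recursive member's column types from the anchor query, so replacing the anchor's literal with an untyped bind parameter can make the UNION fail type resolution. The two conditions the transformer now checks, demonstrated on a toy query with sqlglot:

    import sqlglot
    from sqlglot import exp

    sql = """
    WITH RECURSIVE seq(n) AS (
        SELECT 1
        UNION ALL
        SELECT n + 1 FROM seq WHERE n < 5
    )
    SELECT n FROM seq
    """
    tree = sqlglot.parse_one(sql)
    print(tree.find(exp.With).args.get("recursive"))            # True: parent WITH is RECURSIVE
    print(tree.find(exp.CTE).this.find(exp.Union) is not None)  # True: CTE body contains a UNION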
|