sqlspec 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -197,8 +197,41 @@ class SqliteDriver(
         result: ScriptResultDict = {"statements_executed": -1, "status_message": "SCRIPT EXECUTED"}
         return result
 
+    def _ingest_arrow_table(self, table: Any, table_name: str, mode: str = "create", **options: Any) -> int:
+        """SQLite-specific Arrow table ingestion using CSV conversion.
+
+        Since SQLite only supports CSV bulk loading, we convert the Arrow table
+        to CSV format first using the storage backend for efficient operations.
+        """
+        import io
+        import tempfile
+
+        import pyarrow.csv as pa_csv
+
+        # Convert Arrow table to CSV in memory
+        csv_buffer = io.BytesIO()
+        pa_csv.write_csv(table, csv_buffer)
+        csv_content = csv_buffer.getvalue()
+
+        # Create a temporary file path
+        temp_filename = f"sqlspec_temp_{table_name}_{id(self)}.csv"
+        temp_path = Path(tempfile.gettempdir()) / temp_filename
+
+        # Use storage backend to write the CSV content
+        backend = self._get_storage_backend(temp_path)
+        backend.write_bytes(str(temp_path), csv_content)
+
+        try:
+            # Use SQLite's CSV bulk load
+            return self._bulk_load_file(temp_path, table_name, "csv", mode, **options)
+        finally:
+            # Clean up using storage backend
+            with contextlib.suppress(Exception):
+                # Best effort cleanup
+                backend.delete(str(temp_path))
+
     def _bulk_load_file(self, file_path: Path, table_name: str, format: str, mode: str, **options: Any) -> int:
-        """Database-specific bulk load implementation."""
+        """Database-specific bulk load implementation using storage backend."""
         if format != "csv":
             msg = f"SQLite driver only supports CSV for bulk loading, not {format}."
             raise NotImplementedError(msg)
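
As a point of reference, the in-memory CSV conversion above can be reproduced standalone. A minimal sketch, assuming pyarrow is installed; the table contents are illustrative:

import io

import pyarrow as pa
import pyarrow.csv as pa_csv

table = pa.table({"id": [1, 2], "name": ["a", "b"]})

# Serialize the Arrow table to CSV bytes in memory (header row included),
# mirroring what _ingest_arrow_table writes through the storage backend.
buffer = io.BytesIO()
pa_csv.write_csv(table, buffer)
print(buffer.getvalue().decode("utf-8"))  # header row, then the two data rows
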
@@ -208,16 +241,23 @@ class SqliteDriver(
             if mode == "replace":
                 cursor.execute(f"DELETE FROM {table_name}")
 
-            with Path(file_path).open(encoding="utf-8") as f:
-                reader = csv.reader(f, **options)
-                header = next(reader)  # Skip header
-                placeholders = ", ".join("?" for _ in header)
-                sql = f"INSERT INTO {table_name} VALUES ({placeholders})"
+            # Use storage backend to read the file
+            backend = self._get_storage_backend(file_path)
+            content = backend.read_text(str(file_path), encoding="utf-8")
+
+            # Parse CSV content
+            import io
+
+            csv_file = io.StringIO(content)
+            reader = csv.reader(csv_file, **options)
+            header = next(reader)  # Skip header
+            placeholders = ", ".join("?" for _ in header)
+            sql = f"INSERT INTO {table_name} VALUES ({placeholders})"
 
-                # executemany is efficient for bulk inserts
-                data_iter = list(reader)  # Read all data into memory
-                cursor.executemany(sql, data_iter)
-                return cursor.rowcount
+            # executemany is efficient for bulk inserts
+            data_iter = list(reader)  # Read all data into memory
+            cursor.executemany(sql, data_iter)
+            return cursor.rowcount
 
     def _wrap_select_result(
         self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
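
The rewritten body keeps the same executemany pattern as before, only sourcing rows from a string buffer instead of an open file. A standalone sketch against the stdlib sqlite3 module (table name and CSV content are illustrative):

import csv
import io
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id TEXT, name TEXT)")

content = "id,name\n1,a\n2,b\n"  # stand-in for backend.read_text(...)
reader = csv.reader(io.StringIO(content))
header = next(reader)  # Skip header
placeholders = ", ".join("?" for _ in header)

# executemany binds each CSV row as one parameter set
conn.executemany(f"INSERT INTO users VALUES ({placeholders})", list(reader))
print(conn.execute("SELECT COUNT(*) FROM users").fetchone()[0])  # 2
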
@@ -85,25 +85,30 @@ class StorageMixinBase(ABC):
         raise MissingDependencyError(msg)
 
     @staticmethod
-    def _get_storage_backend(uri_or_key: str) -> "ObjectStoreProtocol":
+    def _get_storage_backend(uri_or_key: "Union[str, Path]") -> "ObjectStoreProtocol":
         """Get storage backend by URI or key with intelligent routing."""
-        return storage_registry.get(uri_or_key)
+        # Pass Path objects directly to storage registry for proper URI conversion
+        if isinstance(uri_or_key, Path):
+            return storage_registry.get(uri_or_key)
+        return storage_registry.get(str(uri_or_key))
 
     @staticmethod
-    def _is_uri(path_or_uri: str) -> bool:
+    def _is_uri(path_or_uri: "Union[str, Path]") -> bool:
         """Check if input is a URI rather than a relative path."""
+        path_str = str(path_or_uri)
         schemes = {"s3", "gs", "gcs", "az", "azure", "abfs", "abfss", "file", "http", "https"}
-        if "://" in path_or_uri:
-            scheme = path_or_uri.split("://", maxsplit=1)[0].lower()
+        if "://" in path_str:
+            scheme = path_str.split("://", maxsplit=1)[0].lower()
             return scheme in schemes
-        if len(path_or_uri) >= WINDOWS_PATH_MIN_LENGTH and path_or_uri[1:3] == ":\\":
+        if len(path_str) >= WINDOWS_PATH_MIN_LENGTH and path_str[1:3] == ":\\":
             return True
-        return bool(path_or_uri.startswith("/"))
+        return bool(path_str.startswith("/"))
 
     @staticmethod
-    def _detect_format(uri: str) -> str:
+    def _detect_format(uri: "Union[str, Path]") -> str:
         """Detect file format from URI extension."""
-        parsed = urlparse(uri)
+        uri_str = str(uri)
+        parsed = urlparse(uri_str)
         path = Path(parsed.path)
         extension = path.suffix.lower().lstrip(".")
 
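The practical effect of the new str() coercion is that Path objects take the same routing decisions as plain strings. A condensed sketch of the heuristic, with an abridged scheme set and WINDOWS_PATH_MIN_LENGTH assumed to be 3:

from pathlib import PurePosixPath, PureWindowsPath

WINDOWS_PATH_MIN_LENGTH = 3  # assumed value of the module constant

def is_uri(path_or_uri) -> bool:
    path_str = str(path_or_uri)  # Path and str now follow the same branches
    if "://" in path_str:
        return path_str.split("://", maxsplit=1)[0].lower() in {"s3", "gs", "file", "http", "https"}
    if len(path_str) >= WINDOWS_PATH_MIN_LENGTH and path_str[1:3] == ":\\":
        return True  # Windows drive path, e.g. C:\data
    return path_str.startswith("/")  # POSIX absolute path

print(is_uri("s3://bucket/key"))                     # True
print(is_uri(PurePosixPath("/var/data/file.csv")))   # True
print(is_uri(PureWindowsPath(r"C:\data\file.csv")))  # True
print(is_uri("relative/file.csv"))                   # False
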
@@ -120,28 +125,28 @@ class StorageMixinBase(ABC):
 
         return format_map.get(extension, "csv")
 
-    def _resolve_backend_and_path(self, uri: str) -> "tuple[ObjectStoreProtocol, str]":
+    def _resolve_backend_and_path(self, uri: "Union[str, Path]") -> "tuple[ObjectStoreProtocol, str]":
         """Resolve backend and path from URI with Phase 3 URI-first routing.
 
         Args:
-            uri: URI to resolve (e.g., "s3://bucket/path", "file:///local/path")
+            uri: URI to resolve (e.g., "s3://bucket/path", "file:///local/path", Path object)
 
         Returns:
             Tuple of (backend, path) where path is relative to the backend's base path
         """
         # Convert Path objects to string
-        uri = str(uri)
-        original_path = uri
+        uri_str = str(uri)
+        original_path = uri_str
 
         # Convert absolute paths to file:// URIs if needed
-        if self._is_uri(uri) and "://" not in uri:
+        if self._is_uri(uri_str) and "://" not in uri_str:
             # It's an absolute path without scheme
-            uri = f"file://{uri}"
+            uri_str = f"file://{uri_str}"
 
-        backend = self._get_storage_backend(uri)
+        backend = self._get_storage_backend(uri_str)
 
         # For file:// URIs, return just the path part for the backend
-        path = uri[7:] if uri.startswith("file://") else original_path
+        path = uri_str[7:] if uri_str.startswith("file://") else original_path
 
         return backend, path
 
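Traced by hand, the new variable separation works like this: the string is promoted to a file:// URI for backend lookup, while the path handed back keeps its plain form. An illustrative walk-through, not the library API:

from pathlib import Path

uri = Path("/data/report.parquet")
uri_str = str(uri)
original_path = uri_str

# An absolute path without a scheme is promoted to a file:// URI
if "://" not in uri_str:
    uri_str = f"file://{uri_str}"

# The backend is resolved from uri_str; the path handed to it has the
# 7-character "file://" prefix stripped again
path = uri_str[7:] if uri_str.startswith("file://") else original_path
print(uri_str)  # file:///data/report.parquet
print(path)     # /data/report.parquet
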
@@ -293,7 +298,7 @@ class SyncStorageMixin(StorageMixinBase):
         statement: "Statement",
         /,
         *parameters: "Union[StatementParameters, StatementFilter]",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
         _config: "Optional[SQLConfig]" = None,
@@ -340,7 +345,7 @@ class SyncStorageMixin(StorageMixinBase):
         statement: "Statement",
         /,
         *parameters: "Union[StatementParameters, StatementFilter]",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
         _config: "Optional[SQLConfig]" = None,
@@ -360,7 +365,7 @@ class SyncStorageMixin(StorageMixinBase):
         detected_format = self._detect_format(destination_uri)
         if format:
             file_format = format
-        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
+        elif detected_format == "csv" and not str(destination_uri).endswith((".csv", ".tsv", ".txt")):
             # Detection returned default "csv" but file doesn't actually have CSV extension
             # Default to parquet for better compatibility with tests and common usage
             file_format = "parquet"
@@ -370,7 +375,7 @@ class SyncStorageMixin(StorageMixinBase):
         # Special handling for parquet format - if we're exporting to parquet but the
         # destination doesn't have .parquet extension, add it to ensure compatibility
         # with pyarrow.parquet.read_table() which requires the extension
-        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
+        if file_format == "parquet" and not str(destination_uri).endswith(".parquet"):
             destination_uri = f"{destination_uri}.parquet"
 
         # Use storage backend - resolve AFTER modifying destination_uri
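
Together, the two str() guards implement one defaulting rule: an extension-less destination falls back to parquet and gains a .parquet suffix. A condensed sketch with the detected format hard-coded for illustration:

destination_uri = "s3://bucket/export/results"
format = None  # mirrors the method's parameter name

detected_format = "csv"  # _detect_format's fallback when no known extension is found
if format:
    file_format = format
elif detected_format == "csv" and not str(destination_uri).endswith((".csv", ".tsv", ".txt")):
    file_format = "parquet"  # no real CSV extension, so prefer parquet
else:
    file_format = detected_format

if file_format == "parquet" and not str(destination_uri).endswith(".parquet"):
    destination_uri = f"{destination_uri}.parquet"
print(destination_uri)  # s3://bucket/export/results.parquet
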
@@ -412,7 +417,12 @@
         return self._export_via_backend(sql_obj, backend, path, file_format, **kwargs)
 
     def import_from_storage(
-        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Import data from storage with intelligent routing.
 
@@ -431,7 +441,12 @@
         return self._import_from_storage(source_uri, table_name, format, mode, **options)
 
     def _import_from_storage(
-        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Protected method for import operation implementation.
 
@@ -461,7 +476,23 @@
                 arrow_table = backend.read_arrow(path, **options)
                 return self.ingest_arrow_table(arrow_table, table_name, mode=mode)
             except AttributeError:
-                pass
+                # Backend doesn't support read_arrow, try alternative approach
+                try:
+                    import pyarrow.parquet as pq
+
+                    # Read Parquet file directly
+                    with tempfile.NamedTemporaryFile(mode="wb", suffix=".parquet", delete=False) as tmp:
+                        tmp.write(backend.read_bytes(path))
+                        tmp_path = Path(tmp.name)
+                    try:
+                        arrow_table = pq.read_table(tmp_path)
+                        return self.ingest_arrow_table(arrow_table, table_name, mode=mode)
+                    finally:
+                        tmp_path.unlink(missing_ok=True)
+                except ImportError:
+                    # PyArrow not installed, cannot import Parquet
+                    msg = "PyArrow is required to import Parquet files. Install with: pip install pyarrow"
+                    raise ImportError(msg) from None
 
         # Use traditional import through temporary file
         return self._import_via_backend(backend, path, table_name, file_format, mode, **options)
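
The new fallback spools the backend's raw bytes into a temporary .parquet file so pyarrow.parquet.read_table can open it. A standalone sketch of that round trip, assuming pyarrow is installed; the in-memory buffer stands in for backend.read_bytes:

import tempfile
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"id": [1, 2, 3]})
sink = pa.BufferOutputStream()
pq.write_table(table, sink)
parquet_bytes = sink.getvalue().to_pybytes()  # stand-in for backend.read_bytes(path)

# Spool the bytes to a named temp file, read it back, then clean up
with tempfile.NamedTemporaryFile(mode="wb", suffix=".parquet", delete=False) as tmp:
    tmp.write(parquet_bytes)
    tmp_path = Path(tmp.name)
try:
    print(pq.read_table(tmp_path).num_rows)  # 3
finally:
    tmp_path.unlink(missing_ok=True)
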
@@ -471,23 +502,27 @@
     # ============================================================================
 
     def _read_parquet_native(
-        self, source_uri: str, columns: "Optional[list[str]]" = None, **options: Any
+        self, source_uri: "Union[str, Path]", columns: "Optional[list[str]]" = None, **options: Any
     ) -> "SQLResult":
         """Database-specific native Parquet reading. Override in drivers."""
         msg = "Driver should implement _read_parquet_native"
         raise NotImplementedError(msg)
 
-    def _write_parquet_native(self, data: Union[str, ArrowTable], destination_uri: str, **options: Any) -> None:
+    def _write_parquet_native(
+        self, data: Union[str, ArrowTable], destination_uri: "Union[str, Path]", **options: Any
+    ) -> None:
         """Database-specific native Parquet writing. Override in drivers."""
         msg = "Driver should implement _write_parquet_native"
         raise NotImplementedError(msg)
 
-    def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
+    def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
         """Database-specific native export. Override in drivers."""
         msg = "Driver should implement _export_native"
         raise NotImplementedError(msg)
 
-    def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
+    def _import_native(
+        self, source_uri: "Union[str, Path]", table_name: str, format: str, mode: str, **options: Any
+    ) -> int:
         """Database-specific native import. Override in drivers."""
         msg = "Driver should implement _import_native"
         raise NotImplementedError(msg)
@@ -743,7 +778,7 @@ class AsyncStorageMixin(StorageMixinBase):
         statement: "Statement",
         /,
         *parameters: "Union[StatementParameters, StatementFilter]",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         _connection: "Optional[ConnectionT]" = None,
         _config: "Optional[SQLConfig]" = None,
@@ -770,7 +805,7 @@ class AsyncStorageMixin(StorageMixinBase):
     async def _export_to_storage(
         self,
         query: "SQL",
-        destination_uri: str,
+        destination_uri: "Union[str, Path]",
         format: "Optional[str]" = None,
         connection: "Optional[ConnectionT]" = None,
         **options: Any,
@@ -793,7 +828,7 @@ class AsyncStorageMixin(StorageMixinBase):
         detected_format = self._detect_format(destination_uri)
         if format:
             file_format = format
-        elif detected_format == "csv" and not destination_uri.endswith((".csv", ".tsv", ".txt")):
+        elif detected_format == "csv" and not str(destination_uri).endswith((".csv", ".tsv", ".txt")):
             # Detection returned default "csv" but file doesn't actually have CSV extension
             # Default to parquet for better compatibility with tests and common usage
             file_format = "parquet"
@@ -803,7 +838,7 @@ class AsyncStorageMixin(StorageMixinBase):
         # Special handling for parquet format - if we're exporting to parquet but the
         # destination doesn't have .parquet extension, add it to ensure compatibility
         # with pyarrow.parquet.read_table() which requires the extension
-        if file_format == "parquet" and not destination_uri.endswith(".parquet"):
+        if file_format == "parquet" and not str(destination_uri).endswith(".parquet"):
             destination_uri = f"{destination_uri}.parquet"
 
         # Use storage backend - resolve AFTER modifying destination_uri
@@ -838,7 +873,12 @@ class AsyncStorageMixin(StorageMixinBase):
         return await self._export_via_backend(query, backend, path, file_format, **options)
 
     async def import_from_storage(
-        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Async import data from storage with intelligent routing.
 
@@ -857,7 +897,12 @@ class AsyncStorageMixin(StorageMixinBase):
         return await self._import_from_storage(source_uri, table_name, format, mode, **options)
 
     async def _import_from_storage(
-        self, source_uri: str, table_name: str, format: "Optional[str]" = None, mode: str = "create", **options: Any
+        self,
+        source_uri: "Union[str, Path]",
+        table_name: str,
+        format: "Optional[str]" = None,
+        mode: str = "create",
+        **options: Any,
     ) -> int:
         """Protected async method for import operation implementation.
 
@@ -884,12 +929,14 @@
     # Async Database-Specific Implementation Hooks
     # ============================================================================
 
-    async def _export_native(self, query: str, destination_uri: str, format: str, **options: Any) -> int:
+    async def _export_native(self, query: str, destination_uri: "Union[str, Path]", format: str, **options: Any) -> int:
         """Async database-specific native export."""
         msg = "Driver should implement _export_native"
         raise NotImplementedError(msg)
 
-    async def _import_native(self, source_uri: str, table_name: str, format: str, mode: str, **options: Any) -> int:
+    async def _import_native(
+        self, source_uri: "Union[str, Path]", table_name: str, format: str, mode: str, **options: Any
+    ) -> int:
         """Async database-specific native import."""
         msg = "Driver should implement _import_native"
         raise NotImplementedError(msg)
sqlspec/loader.py CHANGED
@@ -113,7 +113,7 @@ class SQLFileLoader:
         self._query_to_file: dict[str, str] = {}  # Maps query name to file path
 
     def _read_file_content(self, path: Union[str, Path]) -> str:
-        """Read file content using appropriate backend.
+        """Read file content using storage backend.
 
         Args:
             path: File path (can be local path or URI).
@@ -126,37 +126,15 @@ class SQLFileLoader:
         """
        path_str = str(path)
 
-        # Use storage backend for URIs (anything with a scheme)
-        if "://" in path_str:
-            try:
-                backend = self.storage_registry.get(path_str)
-                return backend.read_text(path_str, encoding=self.encoding)
-            except KeyError as e:
-                raise SQLFileNotFoundError(path_str) from e
-            except Exception as e:
-                raise SQLFileParseError(path_str, path_str, e) from e
-
-        # Handle local file paths
-        local_path = Path(path_str)
-        self._check_file_path(local_path)
-        content_bytes = self._read_file_content_bytes(local_path)
-        return content_bytes.decode(self.encoding)
-
-    @staticmethod
-    def _read_file_content_bytes(path: Path) -> bytes:
         try:
-            return path.read_bytes()
+            # Always use storage backend for consistent behavior
+            # Pass the original path object to allow storage registry to handle Path -> file:// conversion
+            backend = self.storage_registry.get(path)
+            return backend.read_text(path_str, encoding=self.encoding)
+        except KeyError as e:
+            raise SQLFileNotFoundError(path_str) from e
        except Exception as e:
-            raise SQLFileParseError(str(path), str(path), e) from e
-
-    @staticmethod
-    def _check_file_path(path: Union[str, Path]) -> None:
-        """Ensure the file exists and is a valid path."""
-        path_obj = Path(path).resolve()
-        if not path_obj.exists():
-            raise SQLFileNotFoundError(str(path_obj))
-        if not path_obj.is_file():
-            raise SQLFileParseError(str(path_obj), str(path_obj), ValueError("Path is not a file"))
+            raise SQLFileParseError(path_str, path_str, e) from e
 
     @staticmethod
     def _strip_leading_comments(sql_text: str) -> str:
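
With the local-path branch gone, every read goes through the storage registry and a single two-step error mapping. A stripped-down sketch of that mapping; the registry and exception classes here are stand-ins rather than sqlspec's own:

class SQLFileNotFoundError(Exception): ...
class SQLFileParseError(Exception): ...

registry: dict = {}  # imagine a URI/Path -> backend mapping

def read_file_content(path) -> str:
    path_str = str(path)
    try:
        backend = registry[path_str]  # raises KeyError when nothing resolves
        return backend.read_text(path_str, encoding="utf-8")
    except KeyError as e:
        # No backend could be resolved: report a missing file
        raise SQLFileNotFoundError(path_str) from e
    except Exception as e:
        # Any other failure surfaces as a parse/read error
        raise SQLFileParseError(path_str) from e
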
@@ -192,7 +192,9 @@ class QueryBuilder(ABC, Generic[RowT]):
             self._raise_sql_builder_error(msg)
             cte_select_expression = query._expression.copy()
             for p_name, p_value in query._parameters.items():
-                self.add_parameter(p_value, f"cte_{alias}_{p_name}")
+                # Try to preserve original parameter name, only rename if collision
+                unique_name = self._generate_unique_parameter_name(p_name)
+                self.add_parameter(p_value, unique_name)
 
         elif isinstance(query, str):
             try:
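
_generate_unique_parameter_name is internal to sqlspec, so the following stand-in only illustrates the intended behavior: keep the caller's parameter name when it is free, and suffix it only on collision.

def generate_unique_parameter_name(name: str, taken: set[str]) -> str:
    # Hypothetical stand-in for the builder's collision handling
    if name not in taken:
        return name  # preserve the original parameter name
    i = 1
    while f"{name}_{i}" in taken:
        i += 1
    return f"{name}_{i}"  # rename only when the name is already taken

taken = {"status"}
print(generate_unique_parameter_name("limit", taken))   # limit
print(generate_unique_parameter_name("status", taken))  # status_1
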
@@ -769,14 +769,27 @@ class CreateTableAsSelectBuilder(DDLBuilder):
             select_expr = self._select_query.expression
             select_params = getattr(self._select_query, "parameters", None)
         elif isinstance(self._select_query, SelectBuilder):
+            # Get the expression and parameters directly
             select_expr = getattr(self._select_query, "_expression", None)
             select_params = getattr(self._select_query, "_parameters", None)
+
+            # Apply CTEs if present
+            with_ctes = getattr(self._select_query, "_with_ctes", {})
+            if with_ctes and select_expr and isinstance(select_expr, exp.Select):
+                # Apply CTEs directly to the SELECT expression using sqlglot's with_ method
+                for alias, cte in with_ctes.items():
+                    if hasattr(select_expr, "with_"):
+                        select_expr = select_expr.with_(
+                            cte.this,  # The CTE's SELECT expression
+                            as_=alias,
+                            copy=False,
+                        )
         elif isinstance(self._select_query, str):
             select_expr = exp.maybe_parse(self._select_query)
             select_params = None
         else:
             self._raise_sql_builder_error("Unsupported type for SELECT query in CTAS.")
-        if select_expr is None or not isinstance(select_expr, exp.Select):
+        if select_expr is None:
             self._raise_sql_builder_error("SELECT query must be a valid SELECT expression.")
 
         # Merge parameters from SELECT if present
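
For context, sqlglot's with_ helper is the documented way to attach a CTE to a SELECT expression. A minimal sketch of that API with illustrative query text:

import sqlglot
from sqlglot import exp

select_expr = sqlglot.parse_one("SELECT * FROM recent", into=exp.Select)
cte_body = sqlglot.parse_one("SELECT id FROM events WHERE ts > '2024-01-01'")

# with_(alias, as_=expression) prepends a WITH clause to the SELECT
select_expr = select_expr.with_("recent", as_=cte_body, copy=False)
print(select_expr.sql())
# expected: WITH recent AS (SELECT id FROM events WHERE ts > '2024-01-01') SELECT * FROM recent
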
@@ -324,11 +324,7 @@ class StatementAnalyzer(ProcessorProtocol):
     def _analyze_subqueries(self, expression: exp.Expression, analysis: StatementAnalysis) -> None:
         """Analyze subquery complexity and nesting depth."""
         subqueries: list[exp.Expression] = list(expression.find_all(exp.Subquery))
-        subqueries.extend(
-            query
-            for in_clause in expression.find_all(exp.In)
-            if (query := in_clause.args.get("query")) and isinstance(query, exp.Select)
-        )
+        # Workaround for EXISTS clauses: sqlglot doesn't wrap EXISTS subqueries in Subquery nodes
        subqueries.extend(
             [
                 exists_clause.this
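
The comment's claim about sqlglot is easy to check: an EXISTS subquery surfaces as an exp.Exists node whose this is the inner SELECT, with no exp.Subquery wrapper. A quick verification sketch, assuming current sqlglot behavior matches the comment:

import sqlglot
from sqlglot import exp

tree = sqlglot.parse_one("SELECT 1 FROM t WHERE EXISTS (SELECT 1 FROM u)")
print(list(tree.find_all(exp.Subquery)))                  # expected: []
print([e.this.sql() for e in tree.find_all(exp.Exists)])  # expected: ['SELECT 1 FROM u']
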
@@ -34,7 +34,9 @@ class ParameterizationContext:
     in_case_when: bool = False
     in_array: bool = False
     in_in_clause: bool = False
+    in_recursive_cte: bool = False
     function_depth: int = 0
+    cte_depth: int = 0
 
 
 class ParameterizeLiterals(ProcessorProtocol):
@@ -53,6 +55,7 @@ class ParameterizeLiterals(ProcessorProtocol):
         preserve_boolean: Whether to preserve boolean literals as-is.
         preserve_numbers_in_limit: Whether to preserve numbers in LIMIT/OFFSET clauses.
         preserve_in_functions: List of function names where literals should be preserved.
+        preserve_in_recursive_cte: Whether to preserve literals in recursive CTEs (default True to avoid type inference issues).
         parameterize_arrays: Whether to parameterize array literals.
         parameterize_in_lists: Whether to parameterize IN clause lists.
         max_string_length: Maximum string length to parameterize.
@@ -68,6 +71,7 @@ class ParameterizeLiterals(ProcessorProtocol):
         preserve_boolean: bool = True,
         preserve_numbers_in_limit: bool = True,
         preserve_in_functions: Optional[list[str]] = None,
+        preserve_in_recursive_cte: bool = True,
         parameterize_arrays: bool = True,
         parameterize_in_lists: bool = True,
         max_string_length: int = DEFAULT_MAX_STRING_LENGTH,
@@ -79,7 +83,18 @@
         self.preserve_null = preserve_null
         self.preserve_boolean = preserve_boolean
         self.preserve_numbers_in_limit = preserve_numbers_in_limit
-        self.preserve_in_functions = preserve_in_functions or ["COALESCE", "IFNULL", "NVL", "ISNULL"]
+        self.preserve_in_recursive_cte = preserve_in_recursive_cte
+        self.preserve_in_functions = preserve_in_functions or [
+            "COALESCE",
+            "IFNULL",
+            "NVL",
+            "ISNULL",
+            # Array functions that take dimension arguments
+            "ARRAYSIZE",  # SQLglot converts array_length to ArraySize
+            "ARRAY_UPPER",
+            "ARRAY_LOWER",
+            "ARRAY_NDIMS",
+        ]
         self.parameterize_arrays = parameterize_arrays
         self.parameterize_in_lists = parameterize_in_lists
         self.max_string_length = max_string_length
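
The upper-cased entries matter because the list is matched against sqlglot's normalized function names; per the inline comment, array_length is parsed into an ArraySize node. A quick check of that normalization, hedged on sqlglot's current dialect mappings:

import sqlglot
from sqlglot import exp

tree = sqlglot.parse_one("SELECT array_length(tags, 1) FROM posts", read="postgres")
print(tree.find(exp.ArraySize) is not None)  # expected True per the comment above
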
@@ -162,6 +177,17 @@
                 context.in_array = True
             elif isinstance(node, exp.In):
                 context.in_in_clause = True
+            elif isinstance(node, exp.CTE):
+                context.cte_depth += 1
+                # Check if this CTE is recursive:
+                # 1. Parent WITH must be RECURSIVE
+                # 2. CTE must contain UNION (characteristic of recursive CTEs)
+                is_in_recursive_with = any(
+                    isinstance(parent, exp.With) and parent.args.get("recursive", False)
+                    for parent in reversed(context.parent_stack)
+                )
+                if is_in_recursive_with and self._contains_union(node):
+                    context.in_recursive_cte = True
         else:
             if context.parent_stack:
                 context.parent_stack.pop()
@@ -176,6 +202,10 @@
                 context.in_array = False
             elif isinstance(node, exp.In):
                 context.in_in_clause = False
+            elif isinstance(node, exp.CTE):
+                context.cte_depth -= 1
+                if context.cte_depth == 0:
+                    context.in_recursive_cte = False
 
     def _process_literal_with_context(
         self, literal: exp.Expression, context: ParameterizationContext
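
A query of the shape this detector targets: the enclosing WITH is RECURSIVE and the CTE body contains a UNION. A minimal sqlglot sketch; parameterizing the 1 and the 10 here is exactly the rewrite that can confuse type inference for the recursive column:

import sqlglot
from sqlglot import exp

sql = """
WITH RECURSIVE cnt(x) AS (
    SELECT 1
    UNION ALL
    SELECT x + 1 FROM cnt WHERE x < 10
)
SELECT x FROM cnt
"""
tree = sqlglot.parse_one(sql)
print(tree.find(exp.With).args.get("recursive"))            # True: parent WITH is RECURSIVE
print(tree.find(exp.CTE).this.find(exp.Union) is not None)  # True: the CTE body contains a UNION
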
@@ -206,7 +236,6 @@
                     "type": type_hint,
                     "semantic_name": semantic_name,
                     "context": self._get_context_description(context),
-                    # Note: We avoid calling literal.sql() for performance
                 }
             )
 
@@ -227,6 +256,21 @@
         if context.in_function_args:
             return True
 
+        # Preserve literals in recursive CTEs to avoid type inference issues
+        if self.preserve_in_recursive_cte and context.in_recursive_cte:
+            return True
+
+        # Check if this literal is being used as an alias value in SELECT
+        # e.g., 'computed' as process_status should be preserved
+        if hasattr(literal, "parent") and literal.parent:
+            parent = literal.parent
+            # Check if it's an Alias node and the literal is the expression (not the alias name)
+            if isinstance(parent, exp.Alias) and parent.this == literal:
+                # Check if this alias is in a SELECT clause
+                for ancestor in context.parent_stack:
+                    if isinstance(ancestor, exp.Select):
+                        return True
+
         # Check parent context more intelligently
         for parent in context.parent_stack:
             # Preserve in schema/DDL contexts
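
In sqlglot terms, the preserved alias case is a literal sitting in Alias.this (the aliased value, not the alias name) under a SELECT. A small sketch:

import sqlglot
from sqlglot import exp

tree = sqlglot.parse_one("SELECT 'computed' AS process_status, id FROM jobs")
alias = tree.find(exp.Alias)
print(isinstance(alias.this, exp.Literal))  # True: the literal is the aliased value
print(alias.alias)                          # process_status
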
@@ -616,6 +660,16 @@
         """
         return self._parameter_metadata.copy()
 
+    def _contains_union(self, cte_node: exp.CTE) -> bool:
+        """Check if a CTE contains a UNION (characteristic of recursive CTEs)."""
+
+        def has_union(node: exp.Expression) -> bool:
+            if isinstance(node, exp.Union):
+                return True
+            return any(has_union(child) for child in node.iter_expressions())
+
+        return cte_node.this and has_union(cte_node.this)
+
     def clear_parameters(self) -> None:
         """Clear the extracted parameters list."""
         self.extracted_parameters = []