sqlspec 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -23
- sqlspec/_typing.py +2 -0
- sqlspec/adapters/adbc/driver.py +2 -2
- sqlspec/adapters/oracledb/driver.py +5 -0
- sqlspec/adapters/psycopg/config.py +2 -4
- sqlspec/base.py +3 -4
- sqlspec/builder/_base.py +55 -13
- sqlspec/builder/_column.py +9 -0
- sqlspec/builder/_ddl.py +7 -7
- sqlspec/builder/_insert.py +10 -6
- sqlspec/builder/_parsing_utils.py +23 -4
- sqlspec/builder/_update.py +1 -1
- sqlspec/builder/mixins/_cte_and_set_ops.py +31 -22
- sqlspec/builder/mixins/_delete_operations.py +12 -7
- sqlspec/builder/mixins/_insert_operations.py +50 -36
- sqlspec/builder/mixins/_join_operations.py +1 -0
- sqlspec/builder/mixins/_merge_operations.py +54 -28
- sqlspec/builder/mixins/_order_limit_operations.py +1 -0
- sqlspec/builder/mixins/_pivot_operations.py +1 -0
- sqlspec/builder/mixins/_select_operations.py +42 -14
- sqlspec/builder/mixins/_update_operations.py +30 -18
- sqlspec/builder/mixins/_where_clause.py +48 -60
- sqlspec/core/__init__.py +3 -2
- sqlspec/core/cache.py +297 -351
- sqlspec/core/compiler.py +5 -3
- sqlspec/core/filters.py +246 -213
- sqlspec/core/hashing.py +9 -11
- sqlspec/core/parameters.py +20 -7
- sqlspec/core/statement.py +67 -12
- sqlspec/driver/_async.py +2 -2
- sqlspec/driver/_common.py +31 -14
- sqlspec/driver/_sync.py +2 -2
- sqlspec/driver/mixins/_result_tools.py +60 -7
- sqlspec/loader.py +8 -9
- sqlspec/storage/backends/fsspec.py +1 -0
- sqlspec/typing.py +2 -0
- {sqlspec-0.24.0.dist-info → sqlspec-0.25.0.dist-info}/METADATA +1 -1
- {sqlspec-0.24.0.dist-info → sqlspec-0.25.0.dist-info}/RECORD +42 -42
- {sqlspec-0.24.0.dist-info → sqlspec-0.25.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.24.0.dist-info → sqlspec-0.25.0.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.24.0.dist-info → sqlspec-0.25.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.24.0.dist-info → sqlspec-0.25.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/_sql.py
CHANGED
|
@@ -137,21 +137,18 @@ class SQLFactory:
|
|
|
137
137
|
self.dialect = dialect
|
|
138
138
|
|
|
139
139
|
def __call__(self, statement: str, dialect: DialectType = None) -> "Any":
|
|
140
|
-
"""Create a SelectBuilder from a SQL string,
|
|
140
|
+
"""Create a SelectBuilder from a SQL string, or SQL object for DML with RETURNING.
|
|
141
141
|
|
|
142
142
|
Args:
|
|
143
143
|
statement: The SQL statement string.
|
|
144
|
-
parameters: Optional parameters for the query.
|
|
145
|
-
*filters: Optional filters.
|
|
146
|
-
config: Optional config.
|
|
147
144
|
dialect: Optional SQL dialect.
|
|
148
|
-
**kwargs: Additional parameters.
|
|
149
145
|
|
|
150
146
|
Returns:
|
|
151
|
-
SelectBuilder instance
|
|
147
|
+
SelectBuilder instance for SELECT/WITH statements,
|
|
148
|
+
SQL object for DML statements with RETURNING clause.
|
|
152
149
|
|
|
153
150
|
Raises:
|
|
154
|
-
SQLBuilderError: If the SQL is not a SELECT/CTE statement.
|
|
151
|
+
SQLBuilderError: If the SQL is not a SELECT/CTE/DML+RETURNING statement.
|
|
155
152
|
"""
|
|
156
153
|
|
|
157
154
|
try:
|
|
@@ -173,10 +170,15 @@ class SQLFactory:
|
|
|
173
170
|
actual_type_str == "WITH" and parsed_expr.this and isinstance(parsed_expr.this, exp.Select)
|
|
174
171
|
):
|
|
175
172
|
builder = Select(dialect=dialect or self.dialect)
|
|
176
|
-
builder.
|
|
173
|
+
builder.set_expression(parsed_expr)
|
|
177
174
|
return builder
|
|
175
|
+
|
|
176
|
+
if actual_type_str in {"INSERT", "UPDATE", "DELETE"} and parsed_expr.args.get("returning") is not None:
|
|
177
|
+
return SQL(statement)
|
|
178
|
+
|
|
178
179
|
msg = (
|
|
179
|
-
f"sql(...) only supports SELECT statements
|
|
180
|
+
f"sql(...) only supports SELECT statements or DML statements with RETURNING clause. "
|
|
181
|
+
f"Detected type: {actual_type_str}. "
|
|
180
182
|
f"Use sql.{actual_type_str.lower()}() instead."
|
|
181
183
|
)
|
|
182
184
|
raise SQLBuilderError(msg)
|
|
@@ -449,7 +451,7 @@ class SQLFactory:
|
|
|
449
451
|
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
|
|
450
452
|
|
|
451
453
|
if isinstance(parsed_expr, exp.Insert):
|
|
452
|
-
builder.
|
|
454
|
+
builder.set_expression(parsed_expr)
|
|
453
455
|
return builder
|
|
454
456
|
|
|
455
457
|
if isinstance(parsed_expr, exp.Select):
|
|
@@ -468,7 +470,7 @@ class SQLFactory:
|
|
|
468
470
|
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
|
|
469
471
|
|
|
470
472
|
if isinstance(parsed_expr, exp.Select):
|
|
471
|
-
builder.
|
|
473
|
+
builder.set_expression(parsed_expr)
|
|
472
474
|
return builder
|
|
473
475
|
|
|
474
476
|
logger.warning("Cannot create SELECT from %s statement", type(parsed_expr).__name__)
|
|
@@ -483,7 +485,7 @@ class SQLFactory:
|
|
|
483
485
|
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
|
|
484
486
|
|
|
485
487
|
if isinstance(parsed_expr, exp.Update):
|
|
486
|
-
builder.
|
|
488
|
+
builder.set_expression(parsed_expr)
|
|
487
489
|
return builder
|
|
488
490
|
|
|
489
491
|
logger.warning("Cannot create UPDATE from %s statement", type(parsed_expr).__name__)
|
|
@@ -498,7 +500,7 @@ class SQLFactory:
|
|
|
498
500
|
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
|
|
499
501
|
|
|
500
502
|
if isinstance(parsed_expr, exp.Delete):
|
|
501
|
-
builder.
|
|
503
|
+
builder.set_expression(parsed_expr)
|
|
502
504
|
return builder
|
|
503
505
|
|
|
504
506
|
logger.warning("Cannot create DELETE from %s statement", type(parsed_expr).__name__)
|
|
@@ -513,7 +515,7 @@ class SQLFactory:
|
|
|
513
515
|
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
|
|
514
516
|
|
|
515
517
|
if isinstance(parsed_expr, exp.Merge):
|
|
516
|
-
builder.
|
|
518
|
+
builder.set_expression(parsed_expr)
|
|
517
519
|
return builder
|
|
518
520
|
|
|
519
521
|
logger.warning("Cannot create MERGE from %s statement", type(parsed_expr).__name__)
|
|
@@ -722,19 +724,15 @@ class SQLFactory:
|
|
|
722
724
|
if not parameters:
|
|
723
725
|
try:
|
|
724
726
|
parsed: exp.Expression = exp.maybe_parse(sql_fragment)
|
|
725
|
-
return parsed
|
|
726
|
-
if sql_fragment.strip().replace("_", "").replace(".", "").isalnum():
|
|
727
|
-
return exp.to_identifier(sql_fragment)
|
|
728
|
-
return exp.Literal.string(sql_fragment)
|
|
729
727
|
except Exception as e:
|
|
730
728
|
msg = f"Failed to parse raw SQL fragment '{sql_fragment}': {e}"
|
|
731
729
|
raise SQLBuilderError(msg) from e
|
|
730
|
+
return parsed
|
|
732
731
|
|
|
733
732
|
return SQL(sql_fragment, parameters)
|
|
734
733
|
|
|
735
|
-
@staticmethod
|
|
736
734
|
def count(
|
|
737
|
-
column: Union[str, exp.Expression, "ExpressionWrapper", "Case", "Column"] = "*", distinct: bool = False
|
|
735
|
+
self, column: Union[str, exp.Expression, "ExpressionWrapper", "Case", "Column"] = "*", distinct: bool = False
|
|
738
736
|
) -> AggregateExpression:
|
|
739
737
|
"""Create a COUNT expression.
|
|
740
738
|
|
|
@@ -748,7 +746,7 @@ class SQLFactory:
|
|
|
748
746
|
if isinstance(column, str) and column == "*":
|
|
749
747
|
expr = exp.Count(this=exp.Star(), distinct=distinct)
|
|
750
748
|
else:
|
|
751
|
-
col_expr =
|
|
749
|
+
col_expr = self._extract_expression(column)
|
|
752
750
|
expr = exp.Count(this=col_expr, distinct=distinct)
|
|
753
751
|
return AggregateExpression(expr)
|
|
754
752
|
|
|
@@ -1066,11 +1064,11 @@ class SQLFactory:
|
|
|
1066
1064
|
if isinstance(value, str):
|
|
1067
1065
|
return exp.column(value)
|
|
1068
1066
|
if isinstance(value, Column):
|
|
1069
|
-
return value.
|
|
1067
|
+
return value.sqlglot_expression
|
|
1070
1068
|
if isinstance(value, ExpressionWrapper):
|
|
1071
1069
|
return value.expression
|
|
1072
1070
|
if isinstance(value, Case):
|
|
1073
|
-
return exp.Case(ifs=value.
|
|
1071
|
+
return exp.Case(ifs=value.conditions, default=value.default)
|
|
1074
1072
|
if isinstance(value, exp.Expression):
|
|
1075
1073
|
return value
|
|
1076
1074
|
return exp.convert(value)
|
sqlspec/_typing.py
CHANGED
|
@@ -606,6 +606,7 @@ except ImportError:
|
|
|
606
606
|
|
|
607
607
|
|
|
608
608
|
FSSPEC_INSTALLED = bool(find_spec("fsspec"))
|
|
609
|
+
NUMPY_INSTALLED = bool(find_spec("numpy"))
|
|
609
610
|
OBSTORE_INSTALLED = bool(find_spec("obstore"))
|
|
610
611
|
PGVECTOR_INSTALLED = bool(find_spec("pgvector"))
|
|
611
612
|
|
|
@@ -617,6 +618,7 @@ __all__ = (
|
|
|
617
618
|
"FSSPEC_INSTALLED",
|
|
618
619
|
"LITESTAR_INSTALLED",
|
|
619
620
|
"MSGSPEC_INSTALLED",
|
|
621
|
+
"NUMPY_INSTALLED",
|
|
620
622
|
"OBSTORE_INSTALLED",
|
|
621
623
|
"OPENTELEMETRY_INSTALLED",
|
|
622
624
|
"PGVECTOR_INSTALLED",
|
sqlspec/adapters/adbc/driver.py
CHANGED
|
@@ -521,7 +521,7 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
521
521
|
|
|
522
522
|
try:
|
|
523
523
|
if not prepared_parameters:
|
|
524
|
-
cursor._rowcount = 0
|
|
524
|
+
cursor._rowcount = 0 # pyright: ignore[reportPrivateUsage]
|
|
525
525
|
row_count = 0
|
|
526
526
|
elif isinstance(prepared_parameters, list) and prepared_parameters:
|
|
527
527
|
processed_params = []
|
|
@@ -596,7 +596,7 @@ class AdbcDriver(SyncDriverAdapterBase):
|
|
|
596
596
|
Execution result with statement counts
|
|
597
597
|
"""
|
|
598
598
|
if statement.is_script:
|
|
599
|
-
sql = statement.
|
|
599
|
+
sql = statement.raw_sql
|
|
600
600
|
prepared_parameters: list[Any] = []
|
|
601
601
|
else:
|
|
602
602
|
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
@@ -286,6 +286,11 @@ class OracleSyncDriver(SyncDriverAdapterBase):
|
|
|
286
286
|
msg = "execute_many requires parameters"
|
|
287
287
|
raise ValueError(msg)
|
|
288
288
|
|
|
289
|
+
# Oracle-specific fix: Ensure parameters are in list format for executemany
|
|
290
|
+
# Oracle expects a list of sequences, not a tuple of sequences
|
|
291
|
+
if isinstance(prepared_parameters, tuple):
|
|
292
|
+
prepared_parameters = list(prepared_parameters)
|
|
293
|
+
|
|
289
294
|
cursor.executemany(sql, prepared_parameters)
|
|
290
295
|
|
|
291
296
|
# Calculate affected rows based on parameter count
|
|
@@ -173,8 +173,7 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
173
173
|
logger.info("Closing Psycopg connection pool", extra={"adapter": "psycopg"})
|
|
174
174
|
|
|
175
175
|
try:
|
|
176
|
-
|
|
177
|
-
self.pool_instance._closed = True
|
|
176
|
+
self.pool_instance._closed = True # pyright: ignore[reportPrivateUsage]
|
|
178
177
|
|
|
179
178
|
self.pool_instance.close()
|
|
180
179
|
logger.info("Psycopg connection pool closed successfully", extra={"adapter": "psycopg"})
|
|
@@ -350,8 +349,7 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
350
349
|
return
|
|
351
350
|
|
|
352
351
|
try:
|
|
353
|
-
|
|
354
|
-
self.pool_instance._closed = True
|
|
352
|
+
self.pool_instance._closed = True # pyright: ignore[reportPrivateUsage]
|
|
355
353
|
|
|
356
354
|
await self.pool_instance.close()
|
|
357
355
|
finally:
|
sqlspec/base.py
CHANGED
|
@@ -15,9 +15,8 @@ from sqlspec.config import (
|
|
|
15
15
|
)
|
|
16
16
|
from sqlspec.core.cache import (
|
|
17
17
|
CacheConfig,
|
|
18
|
-
CacheStatsAggregate,
|
|
19
18
|
get_cache_config,
|
|
20
|
-
|
|
19
|
+
get_cache_statistics,
|
|
21
20
|
log_cache_stats,
|
|
22
21
|
reset_cache_stats,
|
|
23
22
|
update_cache_config,
|
|
@@ -532,13 +531,13 @@ class SQLSpec:
|
|
|
532
531
|
update_cache_config(config)
|
|
533
532
|
|
|
534
533
|
@staticmethod
|
|
535
|
-
def get_cache_stats() ->
|
|
534
|
+
def get_cache_stats() -> "dict[str, Any]":
|
|
536
535
|
"""Get current cache statistics.
|
|
537
536
|
|
|
538
537
|
Returns:
|
|
539
538
|
Cache statistics object with detailed metrics.
|
|
540
539
|
"""
|
|
541
|
-
return
|
|
540
|
+
return get_cache_statistics()
|
|
542
541
|
|
|
543
542
|
@staticmethod
|
|
544
543
|
def reset_cache_stats() -> None:
|
sqlspec/builder/_base.py
CHANGED
|
@@ -13,7 +13,7 @@ from sqlglot.errors import ParseError as SQLGlotParseError
|
|
|
13
13
|
from sqlglot.optimizer import optimize
|
|
14
14
|
from typing_extensions import Self
|
|
15
15
|
|
|
16
|
-
from sqlspec.core.cache import
|
|
16
|
+
from sqlspec.core.cache import get_cache, get_cache_config
|
|
17
17
|
from sqlspec.core.hashing import hash_optimized_expression
|
|
18
18
|
from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
|
|
19
19
|
from sqlspec.core.statement import SQL, StatementConfig
|
|
@@ -91,6 +91,36 @@ class QueryBuilder(ABC):
|
|
|
91
91
|
"QueryBuilder._create_base_expression must return a valid sqlglot expression."
|
|
92
92
|
)
|
|
93
93
|
|
|
94
|
+
def get_expression(self) -> Optional[exp.Expression]:
|
|
95
|
+
"""Get expression reference (no copy).
|
|
96
|
+
|
|
97
|
+
Returns:
|
|
98
|
+
The current SQLGlot expression or None if not set
|
|
99
|
+
"""
|
|
100
|
+
return self._expression
|
|
101
|
+
|
|
102
|
+
def set_expression(self, expression: exp.Expression) -> None:
|
|
103
|
+
"""Set expression with validation.
|
|
104
|
+
|
|
105
|
+
Args:
|
|
106
|
+
expression: SQLGlot expression to set
|
|
107
|
+
|
|
108
|
+
Raises:
|
|
109
|
+
TypeError: If expression is not a SQLGlot Expression
|
|
110
|
+
"""
|
|
111
|
+
if not isinstance(expression, exp.Expression):
|
|
112
|
+
msg = f"Expected Expression, got {type(expression)}"
|
|
113
|
+
raise TypeError(msg)
|
|
114
|
+
self._expression = expression
|
|
115
|
+
|
|
116
|
+
def has_expression(self) -> bool:
|
|
117
|
+
"""Check if expression exists.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
True if expression is set, False otherwise
|
|
121
|
+
"""
|
|
122
|
+
return self._expression is not None
|
|
123
|
+
|
|
94
124
|
@abstractmethod
|
|
95
125
|
def _create_base_expression(self) -> exp.Expression:
|
|
96
126
|
"""Create the base sqlglot expression for the specific query type.
|
|
@@ -307,12 +337,13 @@ class QueryBuilder(ABC):
|
|
|
307
337
|
cte_select_expression: exp.Select
|
|
308
338
|
|
|
309
339
|
if isinstance(query, QueryBuilder):
|
|
310
|
-
|
|
340
|
+
query_expr = query.get_expression()
|
|
341
|
+
if query_expr is None:
|
|
311
342
|
self._raise_sql_builder_error("CTE query builder has no expression.")
|
|
312
|
-
if not isinstance(
|
|
313
|
-
msg = f"CTE query builder expression must be a Select, got {type(
|
|
343
|
+
if not isinstance(query_expr, exp.Select):
|
|
344
|
+
msg = f"CTE query builder expression must be a Select, got {type(query_expr).__name__}."
|
|
314
345
|
self._raise_sql_builder_error(msg)
|
|
315
|
-
cte_select_expression =
|
|
346
|
+
cte_select_expression = query_expr
|
|
316
347
|
param_mapping = self._merge_cte_parameters(alias, query.parameters)
|
|
317
348
|
updated_expression = self._update_placeholders_in_expression(cte_select_expression, param_mapping)
|
|
318
349
|
if not isinstance(updated_expression, exp.Select):
|
|
@@ -398,9 +429,8 @@ class QueryBuilder(ABC):
|
|
|
398
429
|
expression, dialect=dialect_name, schema=self.schema, optimizer_settings=optimizer_settings
|
|
399
430
|
)
|
|
400
431
|
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
cached_optimized = unified_cache.get(cache_key_obj)
|
|
432
|
+
cache = get_cache()
|
|
433
|
+
cached_optimized = cache.get("optimized", cache_key)
|
|
404
434
|
if cached_optimized:
|
|
405
435
|
return cast("exp.Expression", cached_optimized)
|
|
406
436
|
|
|
@@ -409,7 +439,7 @@ class QueryBuilder(ABC):
|
|
|
409
439
|
expression, schema=self.schema, dialect=self.dialect_name, optimizer_settings=optimizer_settings
|
|
410
440
|
)
|
|
411
441
|
|
|
412
|
-
|
|
442
|
+
cache.put("optimized", cache_key, optimized)
|
|
413
443
|
|
|
414
444
|
except Exception:
|
|
415
445
|
return expression
|
|
@@ -430,15 +460,14 @@ class QueryBuilder(ABC):
|
|
|
430
460
|
return self._to_statement(config)
|
|
431
461
|
|
|
432
462
|
cache_key_str = self._generate_builder_cache_key(config)
|
|
433
|
-
cache_key = CacheKey((cache_key_str,))
|
|
434
463
|
|
|
435
|
-
|
|
436
|
-
cached_sql =
|
|
464
|
+
cache = get_cache()
|
|
465
|
+
cached_sql = cache.get("builder", cache_key_str)
|
|
437
466
|
if cached_sql is not None:
|
|
438
467
|
return cast("SQL", cached_sql)
|
|
439
468
|
|
|
440
469
|
sql_statement = self._to_statement(config)
|
|
441
|
-
|
|
470
|
+
cache.put("builder", cache_key_str, sql_statement)
|
|
442
471
|
|
|
443
472
|
return sql_statement
|
|
444
473
|
|
|
@@ -531,3 +560,16 @@ class QueryBuilder(ABC):
|
|
|
531
560
|
def parameters(self) -> dict[str, Any]:
|
|
532
561
|
"""Public access to query parameters."""
|
|
533
562
|
return self._parameters
|
|
563
|
+
|
|
564
|
+
def set_parameters(self, parameters: dict[str, Any]) -> None:
|
|
565
|
+
"""Set query parameters (public API)."""
|
|
566
|
+
self._parameters = parameters.copy()
|
|
567
|
+
|
|
568
|
+
@property
|
|
569
|
+
def with_ctes(self) -> "dict[str, exp.CTE]":
|
|
570
|
+
"""Get WITH clause CTEs (public API)."""
|
|
571
|
+
return dict(self._with_ctes)
|
|
572
|
+
|
|
573
|
+
def generate_unique_parameter_name(self, base_name: str) -> str:
|
|
574
|
+
"""Generate unique parameter name (public API)."""
|
|
575
|
+
return self._generate_unique_parameter_name(base_name)
|
sqlspec/builder/_column.py
CHANGED
|
@@ -254,6 +254,15 @@ class Column:
|
|
|
254
254
|
"""Hash based on table and column name."""
|
|
255
255
|
return hash((self.table, self.name))
|
|
256
256
|
|
|
257
|
+
@property
|
|
258
|
+
def sqlglot_expression(self) -> exp.Expression:
|
|
259
|
+
"""Get the underlying SQLGlot expression (public API).
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
The SQLGlot expression for this column
|
|
263
|
+
"""
|
|
264
|
+
return self._expression
|
|
265
|
+
|
|
257
266
|
|
|
258
267
|
class FunctionColumn:
|
|
259
268
|
"""Represents the result of a SQL function call on a column."""
|
sqlspec/builder/_ddl.py
CHANGED
|
@@ -973,10 +973,10 @@ class CreateTableAsSelect(DDLBuilder):
|
|
|
973
973
|
select_expr = self._select_query.expression
|
|
974
974
|
select_parameters = self._select_query.parameters
|
|
975
975
|
elif isinstance(self._select_query, Select):
|
|
976
|
-
select_expr = self._select_query.
|
|
977
|
-
select_parameters = self._select_query.
|
|
976
|
+
select_expr = self._select_query.get_expression()
|
|
977
|
+
select_parameters = self._select_query.parameters
|
|
978
978
|
|
|
979
|
-
with_ctes = self._select_query.
|
|
979
|
+
with_ctes = self._select_query.with_ctes
|
|
980
980
|
if with_ctes and select_expr and isinstance(select_expr, exp.Select):
|
|
981
981
|
for alias, cte in with_ctes.items():
|
|
982
982
|
if has_with_method(select_expr):
|
|
@@ -1100,8 +1100,8 @@ class CreateMaterializedView(DDLBuilder):
|
|
|
1100
1100
|
select_expr = self._select_query.expression
|
|
1101
1101
|
select_parameters = self._select_query.parameters
|
|
1102
1102
|
elif isinstance(self._select_query, Select):
|
|
1103
|
-
select_expr = self._select_query.
|
|
1104
|
-
select_parameters = self._select_query.
|
|
1103
|
+
select_expr = self._select_query.get_expression()
|
|
1104
|
+
select_parameters = self._select_query.parameters
|
|
1105
1105
|
elif isinstance(self._select_query, str):
|
|
1106
1106
|
select_expr = exp.maybe_parse(self._select_query)
|
|
1107
1107
|
select_parameters = None
|
|
@@ -1198,8 +1198,8 @@ class CreateView(DDLBuilder):
|
|
|
1198
1198
|
select_expr = self._select_query.expression
|
|
1199
1199
|
select_parameters = self._select_query.parameters
|
|
1200
1200
|
elif isinstance(self._select_query, Select):
|
|
1201
|
-
select_expr = self._select_query.
|
|
1202
|
-
select_parameters = self._select_query.
|
|
1201
|
+
select_expr = self._select_query.get_expression()
|
|
1202
|
+
select_parameters = self._select_query.parameters
|
|
1203
1203
|
elif isinstance(self._select_query, str):
|
|
1204
1204
|
select_expr = exp.maybe_parse(self._select_query)
|
|
1205
1205
|
select_parameters = None
|
sqlspec/builder/_insert.py
CHANGED
|
@@ -90,6 +90,10 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
90
90
|
raise SQLBuilderError(ERR_MSG_INTERNAL_EXPRESSION_TYPE)
|
|
91
91
|
return self._expression
|
|
92
92
|
|
|
93
|
+
def get_insert_expression(self) -> exp.Insert:
|
|
94
|
+
"""Get the insert expression (public API)."""
|
|
95
|
+
return self._get_insert_expression()
|
|
96
|
+
|
|
93
97
|
def values(self, *values: Any, **kwargs: Any) -> "Self":
|
|
94
98
|
"""Adds a row of values to the INSERT statement.
|
|
95
99
|
|
|
@@ -129,7 +133,7 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
129
133
|
if hasattr(values_0, "items") and hasattr(values_0, "keys"):
|
|
130
134
|
return self.values_from_dict(values_0)
|
|
131
135
|
|
|
132
|
-
insert_expr = self.
|
|
136
|
+
insert_expr = self.get_insert_expression()
|
|
133
137
|
|
|
134
138
|
if self._columns and len(values) != len(self._columns):
|
|
135
139
|
msg = ERR_MSG_VALUES_COLUMNS_MISMATCH.format(values_len=len(values), columns_len=len(self._columns))
|
|
@@ -160,9 +164,9 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
160
164
|
if self._columns and i < len(self._columns):
|
|
161
165
|
column_str = str(self._columns[i])
|
|
162
166
|
column_name = column_str.rsplit(".", maxsplit=1)[-1] if "." in column_str else column_str
|
|
163
|
-
param_name = self.
|
|
167
|
+
param_name = self.generate_unique_parameter_name(column_name)
|
|
164
168
|
else:
|
|
165
|
-
param_name = self.
|
|
169
|
+
param_name = self.generate_unique_parameter_name(f"value_{i + 1}")
|
|
166
170
|
_, param_name = self.add_parameter(value, name=param_name)
|
|
167
171
|
value_placeholders.append(exp.Placeholder(this=param_name))
|
|
168
172
|
|
|
@@ -336,7 +340,7 @@ class ConflictBuilder:
|
|
|
336
340
|
).do_nothing()
|
|
337
341
|
```
|
|
338
342
|
"""
|
|
339
|
-
insert_expr = self._insert_builder.
|
|
343
|
+
insert_expr = self._insert_builder.get_insert_expression()
|
|
340
344
|
|
|
341
345
|
# Create ON CONFLICT with proper structure
|
|
342
346
|
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
|
|
@@ -363,7 +367,7 @@ class ConflictBuilder:
|
|
|
363
367
|
)
|
|
364
368
|
```
|
|
365
369
|
"""
|
|
366
|
-
insert_expr = self._insert_builder.
|
|
370
|
+
insert_expr = self._insert_builder.get_insert_expression()
|
|
367
371
|
|
|
368
372
|
# Create SET expressions for the UPDATE
|
|
369
373
|
set_expressions = []
|
|
@@ -394,7 +398,7 @@ class ConflictBuilder:
|
|
|
394
398
|
value_expr = val
|
|
395
399
|
else:
|
|
396
400
|
# Create parameter for regular values
|
|
397
|
-
param_name = self._insert_builder.
|
|
401
|
+
param_name = self._insert_builder.generate_unique_parameter_name(col)
|
|
398
402
|
_, param_name = self._insert_builder.add_parameter(val, name=param_name)
|
|
399
403
|
value_expr = exp.Placeholder(this=param_name)
|
|
400
404
|
|
|
@@ -18,6 +18,27 @@ from sqlspec.utils.type_guards import (
|
|
|
18
18
|
)
|
|
19
19
|
|
|
20
20
|
|
|
21
|
+
def extract_column_name(column: Union[str, exp.Column]) -> str:
|
|
22
|
+
"""Extract column name from column expression for parameter naming.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
column: Column expression (string or SQLGlot Column)
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
Column name as string for use as parameter name
|
|
29
|
+
"""
|
|
30
|
+
if isinstance(column, str):
|
|
31
|
+
if "." in column:
|
|
32
|
+
return column.split(".")[-1]
|
|
33
|
+
return column
|
|
34
|
+
if isinstance(column, exp.Column):
|
|
35
|
+
try:
|
|
36
|
+
return str(column.this.this)
|
|
37
|
+
except AttributeError:
|
|
38
|
+
return str(column.this) if column.this else "column"
|
|
39
|
+
return "column"
|
|
40
|
+
|
|
41
|
+
|
|
21
42
|
def parse_column_expression(
|
|
22
43
|
column_input: Union[str, exp.Expression, Any], builder: Optional[Any] = None
|
|
23
44
|
) -> exp.Expression:
|
|
@@ -139,10 +160,8 @@ def parse_condition_expression(
|
|
|
139
160
|
if value is None:
|
|
140
161
|
return exp.Is(this=column_expr, expression=exp.null())
|
|
141
162
|
if builder and has_parameter_builder(builder):
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
column_name = _extract_column_name(column)
|
|
145
|
-
param_name = builder._generate_unique_parameter_name(column_name)
|
|
163
|
+
column_name = extract_column_name(column)
|
|
164
|
+
param_name = builder.generate_unique_parameter_name(column_name)
|
|
146
165
|
_, param_name = builder.add_parameter(value, name=param_name)
|
|
147
166
|
return exp.EQ(this=column_expr, expression=exp.Placeholder(this=param_name))
|
|
148
167
|
if isinstance(value, str):
|
sqlspec/builder/_update.py
CHANGED
|
@@ -131,7 +131,7 @@ class Update(
|
|
|
131
131
|
subquery_exp = exp.paren(exp.maybe_parse(subquery.sql, dialect=self.dialect))
|
|
132
132
|
table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
|
|
133
133
|
|
|
134
|
-
subquery_parameters = table.
|
|
134
|
+
subquery_parameters = table.parameters
|
|
135
135
|
if subquery_parameters:
|
|
136
136
|
for p_name, p_value in subquery_parameters.items():
|
|
137
137
|
self.add_parameter(p_value, name=p_name)
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
+
# pyright: reportPrivateUsage=false
|
|
1
2
|
"""CTE and set operation mixins.
|
|
2
3
|
|
|
3
4
|
Provides mixins for Common Table Expressions (WITH clause) and
|
|
4
5
|
set operations (UNION, INTERSECT, EXCEPT).
|
|
5
6
|
"""
|
|
6
7
|
|
|
7
|
-
from typing import Any, Optional, Union
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
8
9
|
|
|
9
10
|
from mypy_extensions import trait
|
|
10
11
|
from sqlglot import exp
|
|
@@ -12,6 +13,9 @@ from typing_extensions import Self
|
|
|
12
13
|
|
|
13
14
|
from sqlspec.exceptions import SQLBuilderError
|
|
14
15
|
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from sqlspec.builder._base import QueryBuilder
|
|
18
|
+
|
|
15
19
|
__all__ = ("CommonTableExpressionMixin", "SetOperationMixin")
|
|
16
20
|
|
|
17
21
|
|
|
@@ -20,8 +24,10 @@ class CommonTableExpressionMixin:
|
|
|
20
24
|
"""Mixin providing WITH clause (Common Table Expressions) support for SQL builders."""
|
|
21
25
|
|
|
22
26
|
__slots__ = ()
|
|
23
|
-
|
|
24
|
-
|
|
27
|
+
|
|
28
|
+
# Type annotations for PyRight - these will be provided by the base class
|
|
29
|
+
def get_expression(self) -> Optional[exp.Expression]: ...
|
|
30
|
+
def set_expression(self, expression: exp.Expression) -> None: ...
|
|
25
31
|
|
|
26
32
|
_with_ctes: Any # Provided by QueryBuilder
|
|
27
33
|
dialect: Any # Provided by QueryBuilder
|
|
@@ -60,12 +66,14 @@ class CommonTableExpressionMixin:
|
|
|
60
66
|
Returns:
|
|
61
67
|
The current builder instance for method chaining.
|
|
62
68
|
"""
|
|
63
|
-
|
|
69
|
+
builder = cast("QueryBuilder", self)
|
|
70
|
+
expression = builder.get_expression()
|
|
71
|
+
if expression is None:
|
|
64
72
|
msg = "Cannot add WITH clause: expression not initialized."
|
|
65
73
|
raise SQLBuilderError(msg)
|
|
66
74
|
|
|
67
|
-
if not isinstance(
|
|
68
|
-
msg = f"Cannot add WITH clause to {type(
|
|
75
|
+
if not isinstance(expression, (exp.Select, exp.Insert, exp.Update, exp.Delete)):
|
|
76
|
+
msg = f"Cannot add WITH clause to {type(expression).__name__} expression."
|
|
69
77
|
raise SQLBuilderError(msg)
|
|
70
78
|
|
|
71
79
|
cte_expr: Optional[exp.Expression] = None
|
|
@@ -103,19 +111,18 @@ class CommonTableExpressionMixin:
|
|
|
103
111
|
else:
|
|
104
112
|
cte_alias_expr = exp.alias_(cte_expr, name)
|
|
105
113
|
|
|
106
|
-
existing_with =
|
|
114
|
+
existing_with = expression.args.get("with")
|
|
107
115
|
if existing_with:
|
|
108
116
|
existing_with.expressions.append(cte_alias_expr)
|
|
109
117
|
if recursive:
|
|
110
118
|
existing_with.set("recursive", recursive)
|
|
111
119
|
else:
|
|
112
120
|
# Only SELECT, INSERT, UPDATE support WITH clauses
|
|
113
|
-
if hasattr(
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
self._expression = self._expression.with_(cte_alias_expr, as_=name, copy=False)
|
|
121
|
+
if hasattr(expression, "with_") and isinstance(expression, (exp.Select, exp.Insert, exp.Update)):
|
|
122
|
+
updated_expression = expression.with_(cte_alias_expr, as_=name, copy=False)
|
|
123
|
+
builder.set_expression(updated_expression)
|
|
117
124
|
if recursive:
|
|
118
|
-
with_clause =
|
|
125
|
+
with_clause = updated_expression.find(exp.With)
|
|
119
126
|
if with_clause:
|
|
120
127
|
with_clause.set("recursive", recursive)
|
|
121
128
|
self._with_ctes[name] = exp.CTE(this=cte_expr, alias=exp.to_table(name))
|
|
@@ -128,10 +135,12 @@ class SetOperationMixin:
|
|
|
128
135
|
"""Mixin providing set operations (UNION, INTERSECT, EXCEPT) for SELECT builders."""
|
|
129
136
|
|
|
130
137
|
__slots__ = ()
|
|
131
|
-
# Type annotation for PyRight - this will be provided by the base class
|
|
132
|
-
_expression: Optional[exp.Expression]
|
|
133
138
|
|
|
134
|
-
|
|
139
|
+
# Type annotations for PyRight - these will be provided by the base class
|
|
140
|
+
def get_expression(self) -> Optional[exp.Expression]: ...
|
|
141
|
+
def set_expression(self, expression: exp.Expression) -> None: ...
|
|
142
|
+
def set_parameters(self, parameters: "dict[str, Any]") -> None: ...
|
|
143
|
+
|
|
135
144
|
dialect: Any = None
|
|
136
145
|
|
|
137
146
|
def build(self) -> Any:
|
|
@@ -162,7 +171,7 @@ class SetOperationMixin:
|
|
|
162
171
|
union_expr = exp.union(left_expr, right_expr, distinct=not all_)
|
|
163
172
|
new_builder = type(self)()
|
|
164
173
|
new_builder.dialect = self.dialect
|
|
165
|
-
new_builder.
|
|
174
|
+
cast("QueryBuilder", new_builder).set_expression(union_expr)
|
|
166
175
|
merged_parameters = dict(left_query.parameters)
|
|
167
176
|
for param_name, param_value in right_query.parameters.items():
|
|
168
177
|
if param_name in merged_parameters:
|
|
@@ -181,11 +190,11 @@ class SetOperationMixin:
|
|
|
181
190
|
|
|
182
191
|
right_expr = right_expr.transform(rename_parameter)
|
|
183
192
|
union_expr = exp.union(left_expr, right_expr, distinct=not all_)
|
|
184
|
-
new_builder.
|
|
193
|
+
cast("QueryBuilder", new_builder).set_expression(union_expr)
|
|
185
194
|
merged_parameters[new_param_name] = param_value
|
|
186
195
|
else:
|
|
187
196
|
merged_parameters[param_name] = param_value
|
|
188
|
-
new_builder.
|
|
197
|
+
new_builder.set_parameters(merged_parameters)
|
|
189
198
|
return new_builder
|
|
190
199
|
|
|
191
200
|
def intersect(self, other: Any) -> Self:
|
|
@@ -210,10 +219,10 @@ class SetOperationMixin:
|
|
|
210
219
|
intersect_expr = exp.intersect(left_expr, right_expr, distinct=True)
|
|
211
220
|
new_builder = type(self)()
|
|
212
221
|
new_builder.dialect = self.dialect
|
|
213
|
-
new_builder.
|
|
222
|
+
cast("QueryBuilder", new_builder).set_expression(intersect_expr)
|
|
214
223
|
merged_parameters = dict(left_query.parameters)
|
|
215
224
|
merged_parameters.update(right_query.parameters)
|
|
216
|
-
new_builder.
|
|
225
|
+
new_builder.set_parameters(merged_parameters)
|
|
217
226
|
return new_builder
|
|
218
227
|
|
|
219
228
|
def except_(self, other: Any) -> Self:
|
|
@@ -238,8 +247,8 @@ class SetOperationMixin:
|
|
|
238
247
|
except_expr = exp.except_(left_expr, right_expr)
|
|
239
248
|
new_builder = type(self)()
|
|
240
249
|
new_builder.dialect = self.dialect
|
|
241
|
-
new_builder.
|
|
250
|
+
cast("QueryBuilder", new_builder).set_expression(except_expr)
|
|
242
251
|
merged_parameters = dict(left_query.parameters)
|
|
243
252
|
merged_parameters.update(right_query.parameters)
|
|
244
|
-
new_builder.
|
|
253
|
+
new_builder.set_parameters(merged_parameters)
|
|
245
254
|
return new_builder
|