sqlspec-0.15.0-py3-none-any.whl → sqlspec-0.16.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of sqlspec might be problematic. See the registry's advisory page for this release for more details.

Files changed (43)
  1. sqlspec/_sql.py +702 -44
  2. sqlspec/builder/_base.py +77 -44
  3. sqlspec/builder/_column.py +0 -4
  4. sqlspec/builder/_ddl.py +15 -52
  5. sqlspec/builder/_ddl_utils.py +0 -1
  6. sqlspec/builder/_delete.py +4 -5
  7. sqlspec/builder/_insert.py +235 -44
  8. sqlspec/builder/_merge.py +17 -2
  9. sqlspec/builder/_parsing_utils.py +42 -14
  10. sqlspec/builder/_select.py +29 -33
  11. sqlspec/builder/_update.py +4 -2
  12. sqlspec/builder/mixins/_cte_and_set_ops.py +47 -20
  13. sqlspec/builder/mixins/_delete_operations.py +6 -1
  14. sqlspec/builder/mixins/_insert_operations.py +126 -24
  15. sqlspec/builder/mixins/_join_operations.py +44 -10
  16. sqlspec/builder/mixins/_merge_operations.py +183 -25
  17. sqlspec/builder/mixins/_order_limit_operations.py +15 -3
  18. sqlspec/builder/mixins/_pivot_operations.py +11 -2
  19. sqlspec/builder/mixins/_select_operations.py +21 -14
  20. sqlspec/builder/mixins/_update_operations.py +80 -32
  21. sqlspec/builder/mixins/_where_clause.py +201 -66
  22. sqlspec/core/cache.py +26 -28
  23. sqlspec/core/compiler.py +58 -37
  24. sqlspec/core/filters.py +12 -10
  25. sqlspec/core/parameters.py +80 -52
  26. sqlspec/core/result.py +30 -17
  27. sqlspec/core/statement.py +47 -22
  28. sqlspec/driver/_async.py +76 -46
  29. sqlspec/driver/_common.py +25 -6
  30. sqlspec/driver/_sync.py +73 -43
  31. sqlspec/driver/mixins/_result_tools.py +62 -37
  32. sqlspec/driver/mixins/_sql_translator.py +61 -11
  33. sqlspec/extensions/litestar/cli.py +1 -1
  34. sqlspec/extensions/litestar/plugin.py +2 -2
  35. sqlspec/protocols.py +7 -0
  36. sqlspec/utils/sync_tools.py +1 -1
  37. sqlspec/utils/type_guards.py +7 -3
  38. {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/METADATA +1 -1
  39. {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/RECORD +43 -43
  40. {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/WHEEL +0 -0
  41. {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/entry_points.txt +0 -0
  42. {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/licenses/LICENSE +0 -0
  43. {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/licenses/NOTICE +0 -0
@@ -2,6 +2,7 @@
2
2
 
3
3
  from typing import Any, Optional, Union
4
4
 
5
+ from mypy_extensions import trait
5
6
  from sqlglot import exp
6
7
  from typing_extensions import Self
7
8
 
@@ -18,10 +19,12 @@ __all__ = (
18
19
  )
19
20
 
20
21
 
22
+ @trait
21
23
  class MergeIntoClauseMixin:
22
24
  """Mixin providing INTO clause for MERGE builders."""
23
25
 
24
- _expression: Optional[exp.Expression] = None
26
+ __slots__ = ()
27
+ _expression: Optional[exp.Expression]
25
28
 
26
29
  def into(self, table: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
27
30
  """Set the target table for the MERGE operation (INTO clause).
@@ -35,17 +38,24 @@ class MergeIntoClauseMixin:
35
38
  The current builder instance for method chaining.
36
39
  """
37
40
  if self._expression is None:
38
- self._expression = exp.Merge(this=None, using=None, on=None, whens=exp.Whens(expressions=[])) # pyright: ignore
39
- if not isinstance(self._expression, exp.Merge): # pyright: ignore
40
- self._expression = exp.Merge(this=None, using=None, on=None, whens=exp.Whens(expressions=[])) # pyright: ignore
41
+ self._expression = exp.Merge(this=None, using=None, on=None, whens=exp.Whens(expressions=[]))
42
+ if not isinstance(self._expression, exp.Merge):
43
+ self._expression = exp.Merge(this=None, using=None, on=None, whens=exp.Whens(expressions=[]))
41
44
  self._expression.set("this", exp.to_table(table, alias=alias) if isinstance(table, str) else table)
42
45
  return self
43
46
 
44
47
 
48
+ @trait
45
49
  class MergeUsingClauseMixin:
46
50
  """Mixin providing USING clause for MERGE builders."""
47
51
 
48
- _expression: Optional[exp.Expression] = None
52
+ __slots__ = ()
53
+ _expression: Optional[exp.Expression]
54
+
55
+ def add_parameter(self, value: Any, name: Optional[str] = None) -> tuple[Any, str]:
56
+ """Add parameter - provided by QueryBuilder."""
57
+ msg = "Method must be provided by QueryBuilder subclass"
58
+ raise NotImplementedError(msg)
49
59
 
50
60
  def using(self, source: Union[str, exp.Expression, Any], alias: Optional[str] = None) -> Self:
51
61
  """Set the source data for the MERGE operation (USING clause).
@@ -73,7 +83,7 @@ class MergeUsingClauseMixin:
73
83
  subquery_builder_parameters = source.parameters
74
84
  if subquery_builder_parameters:
75
85
  for p_name, p_value in subquery_builder_parameters.items():
76
- self.add_parameter(p_value, name=p_name) # type: ignore[attr-defined]
86
+ self.add_parameter(p_value, name=p_name)
77
87
 
78
88
  subquery_exp = exp.paren(getattr(source, "_expression", exp.select()))
79
89
  source_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
@@ -89,10 +99,12 @@ class MergeUsingClauseMixin:
89
99
  return self
90
100
 
91
101
 
102
+ @trait
92
103
  class MergeOnClauseMixin:
93
104
  """Mixin providing ON clause for MERGE builders."""
94
105
 
95
- _expression: Optional[exp.Expression] = None
106
+ __slots__ = ()
107
+ _expression: Optional[exp.Expression]
96
108
 
97
109
  def on(self, condition: Union[str, exp.Expression]) -> Self:
98
110
  """Set the join condition for the MERGE operation (ON clause).
@@ -131,10 +143,22 @@ class MergeOnClauseMixin:
131
143
  return self
132
144
 
133
145
 
146
+ @trait
134
147
  class MergeMatchedClauseMixin:
135
148
  """Mixin providing WHEN MATCHED THEN ... clauses for MERGE builders."""
136
149
 
137
- _expression: Optional[exp.Expression] = None
150
+ __slots__ = ()
151
+ _expression: Optional[exp.Expression]
152
+
153
+ def add_parameter(self, value: Any, name: Optional[str] = None) -> tuple[Any, str]:
154
+ """Add parameter - provided by QueryBuilder."""
155
+ msg = "Method must be provided by QueryBuilder subclass"
156
+ raise NotImplementedError(msg)
157
+
158
+ def _generate_unique_parameter_name(self, base_name: str) -> str:
159
+ """Generate unique parameter name - provided by QueryBuilder."""
160
+ msg = "Method must be provided by QueryBuilder subclass"
161
+ raise NotImplementedError(msg)
138
162
 
139
163
  def _add_when_clause(self, when_clause: exp.When) -> None:
140
164
  """Helper to add a WHEN clause to the MERGE statement.
@@ -143,9 +167,9 @@ class MergeMatchedClauseMixin:
143
167
  when_clause: The WHEN clause to add.
144
168
  """
145
169
  if self._expression is None:
146
- self._expression = exp.Merge(this=None, using=None, on=None, whens=exp.Whens(expressions=[]))
170
+ self._expression = exp.Merge(this=None, using=None, on=None, whens=exp.Whens(expressions=[])) # type: ignore[misc]
147
171
  if not isinstance(self._expression, exp.Merge):
148
- self._expression = exp.Merge(this=None, using=None, on=None, whens=exp.Whens(expressions=[]))
172
+ self._expression = exp.Merge(this=None, using=None, on=None, whens=exp.Whens(expressions=[])) # type: ignore[misc]
149
173
 
150
174
  whens = self._expression.args.get("whens")
151
175
  if not whens:
@@ -155,14 +179,23 @@ class MergeMatchedClauseMixin:
155
179
  whens.append("expressions", when_clause)
156
180
 
157
181
  def when_matched_then_update(
158
- self, set_values: dict[str, Any], condition: Optional[Union[str, exp.Expression]] = None
182
+ self,
183
+ set_values: Optional[dict[str, Any]] = None,
184
+ condition: Optional[Union[str, exp.Expression]] = None,
185
+ **kwargs: Any,
159
186
  ) -> Self:
160
187
  """Define the UPDATE action for matched rows.
161
188
 
189
+ Supports:
190
+ - when_matched_then_update({"column": value})
191
+ - when_matched_then_update(column=value, other_column=other_value)
192
+ - when_matched_then_update({"column": value}, other_column=other_value)
193
+
162
194
  Args:
163
195
  set_values: A dictionary of column names and their new values to set.
164
196
  The values will be parameterized.
165
197
  condition: An optional additional condition for this specific action.
198
+ **kwargs: Column-value pairs to update on match.
166
199
 
167
200
  Raises:
168
201
  SQLBuilderError: If the condition type is unsupported.
@@ -170,10 +203,48 @@ class MergeMatchedClauseMixin:
170
203
  Returns:
171
204
  The current builder instance for method chaining.
172
205
  """
206
+ # Combine set_values dict and kwargs
207
+ all_values = dict(set_values or {}, **kwargs)
208
+
209
+ if not all_values:
210
+ msg = "No update values provided. Use set_values dict or kwargs."
211
+ raise SQLBuilderError(msg)
212
+
173
213
  update_expressions: list[exp.EQ] = []
174
- for col, val in set_values.items():
175
- param_name = self.add_parameter(val)[1] # type: ignore[attr-defined]
176
- update_expressions.append(exp.EQ(this=exp.column(col), expression=exp.var(param_name)))
214
+ for col, val in all_values.items():
215
+ if hasattr(val, "expression") and hasattr(val, "sql"):
216
+ # Handle SQL objects (from sql.raw with parameters)
217
+ expression = getattr(val, "expression", None)
218
+ if expression is not None and isinstance(expression, exp.Expression):
219
+ # Merge parameters from SQL object into builder
220
+ if hasattr(val, "parameters"):
221
+ sql_parameters = getattr(val, "parameters", {})
222
+ for param_name, param_value in sql_parameters.items():
223
+ self.add_parameter(param_value, name=param_name)
224
+ value_expr = expression
225
+ else:
226
+ # If expression is None, fall back to parsing the raw SQL
227
+ sql_text = getattr(val, "sql", "")
228
+ # Merge parameters even when parsing raw SQL
229
+ if hasattr(val, "parameters"):
230
+ sql_parameters = getattr(val, "parameters", {})
231
+ for param_name, param_value in sql_parameters.items():
232
+ self.add_parameter(param_value, name=param_name)
233
+ # Check if sql_text is callable (like Expression.sql method)
234
+ if callable(sql_text):
235
+ sql_text = str(val)
236
+ value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
237
+ elif isinstance(val, exp.Expression):
238
+ value_expr = val
239
+ else:
240
+ column_name = col if isinstance(col, str) else str(col)
241
+ if "." in column_name:
242
+ column_name = column_name.split(".")[-1]
243
+ param_name = self._generate_unique_parameter_name(column_name)
244
+ param_name = self.add_parameter(val, name=param_name)[1]
245
+ value_expr = exp.Placeholder(this=param_name)
246
+
247
+ update_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))
177
248
 
178
249
  when_args: dict[str, Any] = {"matched": True, "then": exp.Update(expressions=update_expressions)}
179
250
 
@@ -234,10 +305,28 @@ class MergeMatchedClauseMixin:
234
305
  return self
235
306
 
236
307
 
308
+ @trait
237
309
  class MergeNotMatchedClauseMixin:
238
310
  """Mixin providing WHEN NOT MATCHED THEN ... clauses for MERGE builders."""
239
311
 
240
- _expression: Optional[exp.Expression] = None
312
+ __slots__ = ()
313
+
314
+ _expression: Optional[exp.Expression]
315
+
316
+ def add_parameter(self, value: Any, name: Optional[str] = None) -> tuple[Any, str]:
317
+ """Add parameter - provided by QueryBuilder."""
318
+ msg = "Method must be provided by QueryBuilder subclass"
319
+ raise NotImplementedError(msg)
320
+
321
+ def _generate_unique_parameter_name(self, base_name: str) -> str:
322
+ """Generate unique parameter name - provided by QueryBuilder."""
323
+ msg = "Method must be provided by QueryBuilder subclass"
324
+ raise NotImplementedError(msg)
325
+
326
+ def _add_when_clause(self, when_clause: exp.When) -> None:
327
+ """Helper to add a WHEN clause to the MERGE statement - provided by QueryBuilder."""
328
+ msg = "Method must be provided by QueryBuilder subclass"
329
+ raise NotImplementedError(msg)
241
330
 
242
331
  def when_not_matched_then_insert(
243
332
  self,
@@ -270,8 +359,12 @@ class MergeNotMatchedClauseMixin:
270
359
  raise SQLBuilderError(msg)
271
360
 
272
361
  parameterized_values: list[exp.Expression] = []
273
- for val in values:
274
- param_name = self.add_parameter(val)[1] # type: ignore[attr-defined]
362
+ for i, val in enumerate(values):
363
+ column_name = columns[i] if isinstance(columns[i], str) else str(columns[i])
364
+ if "." in column_name:
365
+ column_name = column_name.split(".")[-1]
366
+ param_name = self._generate_unique_parameter_name(column_name)
367
+ param_name = self.add_parameter(val, name=param_name)[1]
275
368
  parameterized_values.append(exp.var(param_name))
276
369
 
277
370
  insert_args["this"] = exp.Tuple(expressions=[exp.column(c) for c in columns])
@@ -308,25 +401,52 @@ class MergeNotMatchedClauseMixin:
308
401
  when_args["this"] = condition_expr
309
402
 
310
403
  when_clause = exp.When(**when_args)
311
- self._add_when_clause(when_clause) # type: ignore[attr-defined]
404
+ self._add_when_clause(when_clause)
312
405
  return self
313
406
 
314
407
 
408
+ @trait
315
409
  class MergeNotMatchedBySourceClauseMixin:
316
410
  """Mixin providing WHEN NOT MATCHED BY SOURCE THEN ... clauses for MERGE builders."""
317
411
 
318
- _expression: Optional[exp.Expression] = None
412
+ __slots__ = ()
413
+
414
+ _expression: Optional[exp.Expression]
415
+
416
+ def add_parameter(self, value: Any, name: Optional[str] = None) -> tuple[Any, str]:
417
+ """Add parameter - provided by QueryBuilder."""
418
+ msg = "Method must be provided by QueryBuilder subclass"
419
+ raise NotImplementedError(msg)
420
+
421
+ def _generate_unique_parameter_name(self, base_name: str) -> str:
422
+ """Generate unique parameter name - provided by QueryBuilder."""
423
+ msg = "Method must be provided by QueryBuilder subclass"
424
+ raise NotImplementedError(msg)
425
+
426
+ def _add_when_clause(self, when_clause: exp.When) -> None:
427
+ """Helper to add a WHEN clause to the MERGE statement - provided by QueryBuilder."""
428
+ msg = "Method must be provided by QueryBuilder subclass"
429
+ raise NotImplementedError(msg)
319
430
 
320
431
  def when_not_matched_by_source_then_update(
321
- self, set_values: dict[str, Any], condition: Optional[Union[str, exp.Expression]] = None
432
+ self,
433
+ set_values: Optional[dict[str, Any]] = None,
434
+ condition: Optional[Union[str, exp.Expression]] = None,
435
+ **kwargs: Any,
322
436
  ) -> Self:
323
437
  """Define the UPDATE action for rows not matched by source.
324
438
 
325
439
  This is useful for handling rows that exist in the target but not in the source.
326
440
 
441
+ Supports:
442
+ - when_not_matched_by_source_then_update({"column": value})
443
+ - when_not_matched_by_source_then_update(column=value, other_column=other_value)
444
+ - when_not_matched_by_source_then_update({"column": value}, other_column=other_value)
445
+
327
446
  Args:
328
447
  set_values: A dictionary of column names and their new values to set.
329
448
  condition: An optional additional condition for this specific action.
449
+ **kwargs: Column-value pairs to update when not matched by source.
330
450
 
331
451
  Raises:
332
452
  SQLBuilderError: If the condition type is unsupported.
@@ -334,10 +454,48 @@ class MergeNotMatchedBySourceClauseMixin:
334
454
  Returns:
335
455
  The current builder instance for method chaining.
336
456
  """
457
+ # Combine set_values dict and kwargs
458
+ all_values = dict(set_values or {}, **kwargs)
459
+
460
+ if not all_values:
461
+ msg = "No update values provided. Use set_values dict or kwargs."
462
+ raise SQLBuilderError(msg)
463
+
337
464
  update_expressions: list[exp.EQ] = []
338
- for col, val in set_values.items():
339
- param_name = self.add_parameter(val)[1] # type: ignore[attr-defined]
340
- update_expressions.append(exp.EQ(this=exp.column(col), expression=exp.var(param_name)))
465
+ for col, val in all_values.items():
466
+ if hasattr(val, "expression") and hasattr(val, "sql"):
467
+ # Handle SQL objects (from sql.raw with parameters)
468
+ expression = getattr(val, "expression", None)
469
+ if expression is not None and isinstance(expression, exp.Expression):
470
+ # Merge parameters from SQL object into builder
471
+ if hasattr(val, "parameters"):
472
+ sql_parameters = getattr(val, "parameters", {})
473
+ for param_name, param_value in sql_parameters.items():
474
+ self.add_parameter(param_value, name=param_name)
475
+ value_expr = expression
476
+ else:
477
+ # If expression is None, fall back to parsing the raw SQL
478
+ sql_text = getattr(val, "sql", "")
479
+ # Merge parameters even when parsing raw SQL
480
+ if hasattr(val, "parameters"):
481
+ sql_parameters = getattr(val, "parameters", {})
482
+ for param_name, param_value in sql_parameters.items():
483
+ self.add_parameter(param_value, name=param_name)
484
+ # Check if sql_text is callable (like Expression.sql method)
485
+ if callable(sql_text):
486
+ sql_text = str(val)
487
+ value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
488
+ elif isinstance(val, exp.Expression):
489
+ value_expr = val
490
+ else:
491
+ column_name = col if isinstance(col, str) else str(col)
492
+ if "." in column_name:
493
+ column_name = column_name.split(".")[-1]
494
+ param_name = self._generate_unique_parameter_name(column_name)
495
+ param_name = self.add_parameter(val, name=param_name)[1]
496
+ value_expr = exp.Placeholder(this=param_name)
497
+
498
+ update_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))
341
499
 
342
500
  when_args: dict[str, Any] = {
343
501
  "matched": False,
@@ -363,7 +521,7 @@ class MergeNotMatchedBySourceClauseMixin:
363
521
  when_args["this"] = condition_expr
364
522
 
365
523
  when_clause = exp.When(**when_args)
366
- self._add_when_clause(when_clause) # type: ignore[attr-defined]
524
+ self._add_when_clause(when_clause)
367
525
  return self
368
526
 
369
527
  def when_not_matched_by_source_then_delete(self, condition: Optional[Union[str, exp.Expression]] = None) -> Self:
@@ -400,5 +558,5 @@ class MergeNotMatchedBySourceClauseMixin:
400
558
  when_args["this"] = condition_expr
401
559
 
402
560
  when_clause = exp.When(**when_args)
403
- self._add_when_clause(when_clause) # type: ignore[attr-defined]
561
+ self._add_when_clause(when_clause)
404
562
  return self
@@ -2,6 +2,7 @@
2
2
 
3
3
  from typing import TYPE_CHECKING, Optional, Union, cast
4
4
 
5
+ from mypy_extensions import trait
5
6
  from sqlglot import exp
6
7
  from typing_extensions import Self
7
8
 
@@ -14,10 +15,14 @@ if TYPE_CHECKING:
14
15
  __all__ = ("LimitOffsetClauseMixin", "OrderByClauseMixin", "ReturningClauseMixin")
15
16
 
16
17
 
18
+ @trait
17
19
  class OrderByClauseMixin:
18
20
  """Mixin providing ORDER BY clause."""
19
21
 
20
- _expression: Optional[exp.Expression] = None
22
+ __slots__ = ()
23
+
24
+ # Type annotation for PyRight - this will be provided by the base class
25
+ _expression: Optional[exp.Expression]
21
26
 
22
27
  def order_by(self, *items: Union[str, exp.Ordered], desc: bool = False) -> Self:
23
28
  """Add ORDER BY clause.
@@ -50,10 +55,14 @@ class OrderByClauseMixin:
50
55
  return cast("Self", builder)
51
56
 
52
57
 
58
+ @trait
53
59
  class LimitOffsetClauseMixin:
54
60
  """Mixin providing LIMIT and OFFSET clauses."""
55
61
 
56
- _expression: Optional[exp.Expression] = None
62
+ __slots__ = ()
63
+
64
+ # Type annotation for PyRight - this will be provided by the base class
65
+ _expression: Optional[exp.Expression]
57
66
 
58
67
  def limit(self, value: int) -> Self:
59
68
  """Add LIMIT clause.
@@ -94,10 +103,13 @@ class LimitOffsetClauseMixin:
94
103
  return cast("Self", builder)
95
104
 
96
105
 
106
+ @trait
97
107
  class ReturningClauseMixin:
98
108
  """Mixin providing RETURNING clause."""
99
109
 
100
- _expression: Optional[exp.Expression] = None
110
+ __slots__ = ()
111
+ # Type annotation for PyRight - this will be provided by the base class
112
+ _expression: Optional[exp.Expression]
101
113
 
102
114
  def returning(self, *columns: Union[str, exp.Expression]) -> Self:
103
115
  """Add RETURNING clause to the statement.
@@ -2,6 +2,7 @@
2
2
 
3
3
  from typing import TYPE_CHECKING, Optional, Union, cast
4
4
 
5
+ from mypy_extensions import trait
5
6
  from sqlglot import exp
6
7
 
7
8
  if TYPE_CHECKING:
@@ -12,10 +13,14 @@ if TYPE_CHECKING:
12
13
  __all__ = ("PivotClauseMixin", "UnpivotClauseMixin")
13
14
 
14
15
 
16
+ @trait
15
17
  class PivotClauseMixin:
16
18
  """Mixin class to add PIVOT functionality to a Select."""
17
19
 
18
- _expression: "Optional[exp.Expression]" = None
20
+ __slots__ = ()
21
+ # Type annotation for PyRight - this will be provided by the base class
22
+ _expression: Optional[exp.Expression]
23
+
19
24
  dialect: "DialectType" = None
20
25
 
21
26
  def pivot(
@@ -79,10 +84,14 @@ class PivotClauseMixin:
79
84
  return cast("Select", self)
80
85
 
81
86
 
87
+ @trait
82
88
  class UnpivotClauseMixin:
83
89
  """Mixin class to add UNPIVOT functionality to a Select."""
84
90
 
85
- _expression: "Optional[exp.Expression]" = None
91
+ __slots__ = ()
92
+ # Type annotation for PyRight - this will be provided by the base class
93
+ _expression: Optional[exp.Expression]
94
+
86
95
  dialect: "DialectType" = None
87
96
 
88
97
  def unpivot(
@@ -3,6 +3,7 @@
3
3
  from dataclasses import dataclass
4
4
  from typing import TYPE_CHECKING, Any, Optional, Union, cast
5
5
 
6
+ from mypy_extensions import trait
6
7
  from sqlglot import exp
7
8
  from typing_extensions import Self
8
9
 
@@ -11,19 +12,23 @@ from sqlspec.exceptions import SQLBuilderError
11
12
  from sqlspec.utils.type_guards import has_query_builder_parameters, is_expression
12
13
 
13
14
  if TYPE_CHECKING:
14
- from sqlspec.builder._base import QueryBuilder
15
15
  from sqlspec.builder._column import Column, FunctionColumn
16
+ from sqlspec.core.statement import SQL
16
17
  from sqlspec.protocols import SelectBuilderProtocol, SQLBuilderProtocol
17
18
 
18
19
  __all__ = ("CaseBuilder", "SelectClauseMixin")
19
20
 
20
21
 
22
+ @trait
21
23
  class SelectClauseMixin:
22
24
  """Consolidated mixin providing all SELECT-related clauses and functionality."""
23
25
 
24
- _expression: Optional[exp.Expression] = None
26
+ __slots__ = ()
25
27
 
26
- def select(self, *columns: Union[str, exp.Expression, "Column", "FunctionColumn"]) -> Self:
28
+ # Type annotation for PyRight - this will be provided by the base class
29
+ _expression: Optional[exp.Expression]
30
+
31
+ def select(self, *columns: Union[str, exp.Expression, "Column", "FunctionColumn", "SQL"]) -> Self:
27
32
  """Add columns to SELECT clause.
28
33
 
29
34
  Raises:
@@ -39,10 +44,10 @@ class SelectClauseMixin:
39
44
  msg = "Cannot add select columns to a non-SELECT expression."
40
45
  raise SQLBuilderError(msg)
41
46
  for column in columns:
42
- builder._expression = builder._expression.select(parse_column_expression(column), copy=False)
47
+ builder._expression = builder._expression.select(parse_column_expression(column, builder), copy=False)
43
48
  return cast("Self", builder)
44
49
 
45
- def distinct(self, *columns: Union[str, exp.Expression, "Column", "FunctionColumn"]) -> Self:
50
+ def distinct(self, *columns: Union[str, exp.Expression, "Column", "FunctionColumn", "SQL"]) -> Self:
46
51
  """Add DISTINCT clause to SELECT.
47
52
 
48
53
  Args:
@@ -63,7 +68,7 @@ class SelectClauseMixin:
63
68
  if not columns:
64
69
  builder._expression.set("distinct", exp.Distinct())
65
70
  else:
66
- distinct_columns = [parse_column_expression(column) for column in columns]
71
+ distinct_columns = [parse_column_expression(column, builder) for column in columns]
67
72
  builder._expression.set("distinct", exp.Distinct(expressions=distinct_columns))
68
73
  return cast("Self", builder)
69
74
 
@@ -529,7 +534,7 @@ class SelectClauseMixin:
529
534
  Returns:
530
535
  CaseBuilder: A CaseBuilder instance for building the CASE expression.
531
536
  """
532
- builder = cast("QueryBuilder", self) # pyright: ignore
537
+ builder = cast("SelectBuilderProtocol", self)
533
538
  return CaseBuilder(builder, alias)
534
539
 
535
540
 
@@ -537,15 +542,15 @@ class SelectClauseMixin:
537
542
  class CaseBuilder:
538
543
  """Builder for CASE expressions."""
539
544
 
540
- _parent: "QueryBuilder" # pyright: ignore
545
+ _parent: "SelectBuilderProtocol"
541
546
  _alias: Optional[str]
542
547
  _case_expr: exp.Case
543
548
 
544
- def __init__(self, parent: "QueryBuilder", alias: "Optional[str]" = None) -> None:
549
+ def __init__(self, parent: "SelectBuilderProtocol", alias: "Optional[str]" = None) -> None:
545
550
  """Initialize CaseBuilder.
546
551
 
547
552
  Args:
548
- parent: The parent builder.
553
+ parent: The parent builder with select capabilities.
549
554
  alias: Optional alias for the CASE expression.
550
555
  """
551
556
  self._parent = parent
@@ -563,7 +568,8 @@ class CaseBuilder:
563
568
  CaseBuilder: The current builder instance for method chaining.
564
569
  """
565
570
  cond_expr = exp.condition(condition) if isinstance(condition, str) else condition
566
- param_name = self._parent.add_parameter(value)[1]
571
+ param_name = self._parent._generate_unique_parameter_name("case_when_value")
572
+ param_name = self._parent.add_parameter(value, name=param_name)[1]
567
573
  value_expr = exp.Placeholder(this=param_name)
568
574
 
569
575
  when_clause = exp.When(this=cond_expr, then=value_expr)
@@ -582,16 +588,17 @@ class CaseBuilder:
582
588
  Returns:
583
589
  CaseBuilder: The current builder instance for method chaining.
584
590
  """
585
- param_name = self._parent.add_parameter(value)[1]
591
+ param_name = self._parent._generate_unique_parameter_name("case_else_value")
592
+ param_name = self._parent.add_parameter(value, name=param_name)[1]
586
593
  value_expr = exp.Placeholder(this=param_name)
587
594
  self._case_expr.set("default", value_expr)
588
595
  return self
589
596
 
590
- def end(self) -> "QueryBuilder":
597
+ def end(self) -> "SelectBuilderProtocol":
591
598
  """Finalize the CASE expression and add it to the SELECT clause.
592
599
 
593
600
  Returns:
594
601
  The parent builder instance.
595
602
  """
596
603
  select_expr = exp.alias_(self._case_expr, self._alias) if self._alias else self._case_expr
597
- return cast("QueryBuilder", self._parent.select(select_expr)) # type: ignore[attr-defined]
604
+ return self._parent.select(select_expr)