sqlspec 0.15.0__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's advisory page for more details.

Files changed (43):
  1. sqlspec/_sql.py +699 -43
  2. sqlspec/builder/_base.py +77 -44
  3. sqlspec/builder/_column.py +0 -4
  4. sqlspec/builder/_ddl.py +15 -52
  5. sqlspec/builder/_ddl_utils.py +0 -1
  6. sqlspec/builder/_delete.py +4 -5
  7. sqlspec/builder/_insert.py +61 -35
  8. sqlspec/builder/_merge.py +17 -2
  9. sqlspec/builder/_parsing_utils.py +16 -12
  10. sqlspec/builder/_select.py +29 -33
  11. sqlspec/builder/_update.py +4 -2
  12. sqlspec/builder/mixins/_cte_and_set_ops.py +47 -20
  13. sqlspec/builder/mixins/_delete_operations.py +6 -1
  14. sqlspec/builder/mixins/_insert_operations.py +126 -24
  15. sqlspec/builder/mixins/_join_operations.py +11 -4
  16. sqlspec/builder/mixins/_merge_operations.py +91 -19
  17. sqlspec/builder/mixins/_order_limit_operations.py +15 -3
  18. sqlspec/builder/mixins/_pivot_operations.py +11 -2
  19. sqlspec/builder/mixins/_select_operations.py +16 -10
  20. sqlspec/builder/mixins/_update_operations.py +43 -10
  21. sqlspec/builder/mixins/_where_clause.py +177 -65
  22. sqlspec/core/cache.py +26 -28
  23. sqlspec/core/compiler.py +58 -37
  24. sqlspec/core/filters.py +12 -10
  25. sqlspec/core/parameters.py +80 -52
  26. sqlspec/core/result.py +30 -17
  27. sqlspec/core/statement.py +47 -22
  28. sqlspec/driver/_async.py +76 -46
  29. sqlspec/driver/_common.py +25 -6
  30. sqlspec/driver/_sync.py +73 -43
  31. sqlspec/driver/mixins/_result_tools.py +62 -37
  32. sqlspec/driver/mixins/_sql_translator.py +61 -11
  33. sqlspec/extensions/litestar/cli.py +1 -1
  34. sqlspec/extensions/litestar/plugin.py +2 -2
  35. sqlspec/protocols.py +7 -0
  36. sqlspec/utils/sync_tools.py +1 -1
  37. sqlspec/utils/type_guards.py +7 -3
  38. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/METADATA +1 -1
  39. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/RECORD +43 -43
  40. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/WHEEL +0 -0
  41. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/entry_points.txt +0 -0
  42. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/LICENSE +0 -0
  43. {sqlspec-0.15.0.dist-info → sqlspec-0.16.1.dist-info}/licenses/NOTICE +0 -0
@@ -4,8 +4,7 @@ This module provides a fluent interface for building SQL queries safely,
4
4
  with automatic parameter binding and validation.
5
5
  """
6
6
 
7
- from dataclasses import dataclass, field
8
- from typing import TYPE_CHECKING, Any, Optional
7
+ from typing import TYPE_CHECKING, Any, Final, Optional
9
8
 
10
9
  from sqlglot import exp
11
10
  from typing_extensions import Self
@@ -21,15 +20,14 @@ if TYPE_CHECKING:
21
20
 
22
21
  __all__ = ("Insert",)
23
22
 
24
- ERR_MSG_TABLE_NOT_SET = "The target table must be set using .into() before adding values."
25
- ERR_MSG_VALUES_COLUMNS_MISMATCH = (
23
+ ERR_MSG_TABLE_NOT_SET: Final[str] = "The target table must be set using .into() before adding values."
24
+ ERR_MSG_VALUES_COLUMNS_MISMATCH: Final[str] = (
26
25
  "Number of values ({values_len}) does not match the number of specified columns ({columns_len})."
27
26
  )
28
- ERR_MSG_INTERNAL_EXPRESSION_TYPE = "Internal error: expression is not an Insert instance as expected."
29
- ERR_MSG_EXPRESSION_NOT_INITIALIZED = "Internal error: base expression not initialized."
27
+ ERR_MSG_INTERNAL_EXPRESSION_TYPE: Final[str] = "Internal error: expression is not an Insert instance as expected."
28
+ ERR_MSG_EXPRESSION_NOT_INITIALIZED: Final[str] = "Internal error: base expression not initialized."
30
29
 
31
30
 
32
- @dataclass(unsafe_hash=True)
33
31
  class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSelectMixin, InsertIntoClauseMixin):
34
32
  """Builder for INSERT statements.
35
33
 
@@ -37,9 +35,7 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
37
35
  in a safe and dialect-agnostic manner with automatic parameter binding.
38
36
  """
39
37
 
40
- _table: "Optional[str]" = field(default=None, init=False)
41
- _columns: list[str] = field(default_factory=list, init=False)
42
- _values_added_count: int = field(default=0, init=False)
38
+ __slots__ = ("_columns", "_table", "_values_added_count")
43
39
 
44
40
  def __init__(self, table: Optional[str] = None, **kwargs: Any) -> None:
45
41
  """Initialize INSERT with optional table.
@@ -50,9 +46,12 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
50
46
  """
51
47
  super().__init__(**kwargs)
52
48
 
53
- self._table = None
54
- self._columns = []
55
- self._values_added_count = 0
49
+ # Initialize Insert-specific attributes
50
+ self._table: Optional[str] = None
51
+ self._columns: list[str] = []
52
+ self._values_added_count: int = 0
53
+
54
+ self._initialize_expression()
56
55
 
57
56
  if table:
58
57
  self.into(table)
@@ -91,16 +90,22 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
91
90
  raise SQLBuilderError(ERR_MSG_INTERNAL_EXPRESSION_TYPE)
92
91
  return self._expression
93
92
 
94
- def values(self, *values: Any) -> "Self":
93
+ def values(self, *values: Any, **kwargs: Any) -> "Self":
95
94
  """Adds a row of values to the INSERT statement.
96
95
 
97
96
  This method can be called multiple times to insert multiple rows,
98
97
  resulting in a multi-row INSERT statement like `VALUES (...), (...)`.
99
98
 
99
+ Supports:
100
+ - values(val1, val2, val3)
101
+ - values(col1=val1, col2=val2)
102
+ - values(mapping)
103
+
100
104
  Args:
101
105
  *values: The values for the row to be inserted. The number of values
102
106
  must match the number of columns set by `columns()`, if `columns()` was called
103
107
  and specified any non-empty list of columns.
108
+ **kwargs: Column-value pairs for named values.
104
109
 
105
110
  Returns:
106
111
  The current builder instance for method chaining.
@@ -113,25 +118,49 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
113
118
  if not self._table:
114
119
  raise SQLBuilderError(ERR_MSG_TABLE_NOT_SET)
115
120
 
121
+ if kwargs:
122
+ if values:
123
+ msg = "Cannot mix positional values with keyword values."
124
+ raise SQLBuilderError(msg)
125
+ return self.values_from_dict(kwargs)
126
+
127
+ if len(values) == 1:
128
+ try:
129
+ values_0 = values[0]
130
+ if hasattr(values_0, "items"):
131
+ return self.values_from_dict(values_0)
132
+ except (AttributeError, TypeError):
133
+ pass
134
+
116
135
  insert_expr = self._get_insert_expression()
117
136
 
118
137
  if self._columns and len(values) != len(self._columns):
119
138
  msg = ERR_MSG_VALUES_COLUMNS_MISMATCH.format(values_len=len(values), columns_len=len(self._columns))
120
139
  raise SQLBuilderError(msg)
121
140
 
122
- param_names = [self._add_parameter(value) for value in values]
123
- value_placeholders = tuple(exp.var(name) for name in param_names)
124
-
125
- current_values_expression = insert_expr.args.get("expression")
126
-
141
+ value_placeholders: list[exp.Expression] = []
142
+ for i, value in enumerate(values):
143
+ if isinstance(value, exp.Expression):
144
+ value_placeholders.append(value)
145
+ else:
146
+ if self._columns and i < len(self._columns):
147
+ column_str = str(self._columns[i])
148
+ column_name = column_str.rsplit(".", maxsplit=1)[-1] if "." in column_str else column_str
149
+ param_name = self._generate_unique_parameter_name(column_name)
150
+ else:
151
+ param_name = self._generate_unique_parameter_name(f"value_{i + 1}")
152
+ _, param_name = self.add_parameter(value, name=param_name)
153
+ value_placeholders.append(exp.var(param_name))
154
+
155
+ tuple_expr = exp.Tuple(expressions=value_placeholders)
127
156
  if self._values_added_count == 0:
128
- new_values_node = exp.Values(expressions=[exp.Tuple(expressions=list(value_placeholders))])
129
- insert_expr.set("expression", new_values_node)
130
- elif isinstance(current_values_expression, exp.Values):
131
- current_values_expression.expressions.append(exp.Tuple(expressions=list(value_placeholders)))
157
+ insert_expr.set("expression", exp.Values(expressions=[tuple_expr]))
132
158
  else:
133
- new_values_node = exp.Values(expressions=[exp.Tuple(expressions=list(value_placeholders))])
134
- insert_expr.set("expression", new_values_node)
159
+ current_values = insert_expr.args.get("expression")
160
+ if isinstance(current_values, exp.Values):
161
+ current_values.expressions.append(tuple_expr)
162
+ else:
163
+ insert_expr.set("expression", exp.Values(expressions=[tuple_expr]))
135
164
 
136
165
  self._values_added_count += 1
137
166
  return self
@@ -154,10 +183,11 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
154
183
  if not self._table:
155
184
  raise SQLBuilderError(ERR_MSG_TABLE_NOT_SET)
156
185
 
186
+ data_keys = list(data.keys())
157
187
  if not self._columns:
158
- self.columns(*data.keys())
159
- elif set(self._columns) != set(data.keys()):
160
- msg = f"Dictionary keys {set(data.keys())} do not match existing columns {set(self._columns)}."
188
+ self.columns(*data_keys)
189
+ elif set(self._columns) != set(data_keys):
190
+ msg = f"Dictionary keys {set(data_keys)} do not match existing columns {set(self._columns)}."
161
191
  raise SQLBuilderError(msg)
162
192
 
163
193
  return self.values(*[data[col] for col in self._columns])
@@ -211,18 +241,14 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
211
241
  For a more general solution, you might need dialect-specific handling.
212
242
  """
213
243
  insert_expr = self._get_insert_expression()
214
- try:
215
- on_conflict = exp.OnConflict(this=None, expressions=[])
216
- insert_expr.set("on", on_conflict)
217
- except AttributeError:
218
- pass
244
+ insert_expr.set("on", exp.OnConflict(this=None, expressions=[]))
219
245
  return self
220
246
 
221
- def on_duplicate_key_update(self, **set_values: Any) -> "Self":
247
+ def on_duplicate_key_update(self, **_: Any) -> "Self":
222
248
  """Adds an ON DUPLICATE KEY UPDATE clause (MySQL syntax).
223
249
 
224
250
  Args:
225
- **set_values: Column-value pairs to update on duplicate key.
251
+ **_: Column-value pairs to update on duplicate key.
226
252
 
227
253
  Returns:
228
254
  The current builder instance for method chaining.
sqlspec/builder/_merge.py CHANGED
@@ -4,7 +4,7 @@ This module provides a fluent interface for building SQL queries safely,
4
4
  with automatic parameter binding and validation.
5
5
  """
6
6
 
7
- from dataclasses import dataclass
7
+ from typing import Any, Optional
8
8
 
9
9
  from sqlglot import exp
10
10
 
@@ -22,7 +22,6 @@ from sqlspec.core.result import SQLResult
22
22
  __all__ = ("Merge",)
23
23
 
24
24
 
25
- @dataclass(unsafe_hash=True)
26
25
  class Merge(
27
26
  QueryBuilder,
28
27
  MergeUsingClauseMixin,
@@ -38,6 +37,22 @@ class Merge(
38
37
  (also known as UPSERT in some databases) with automatic parameter binding and validation.
39
38
  """
40
39
 
40
+ __slots__ = ()
41
+ _expression: Optional[exp.Expression]
42
+
43
+ def __init__(self, target_table: Optional[str] = None, **kwargs: Any) -> None:
44
+ """Initialize MERGE with optional target table.
45
+
46
+ Args:
47
+ target_table: Target table name
48
+ **kwargs: Additional QueryBuilder arguments
49
+ """
50
+ super().__init__(**kwargs)
51
+ self._initialize_expression()
52
+
53
+ if target_table:
54
+ self.into(target_table)
55
+
41
56
  @property
42
57
  def _expected_result_type(self) -> "type[SQLResult]":
43
58
  """Return the expected result type for this builder.
@@ -5,7 +5,7 @@ that users might pass as strings to various builder methods.
5
5
  """
6
6
 
7
7
  import contextlib
8
- from typing import Any, Optional, Union, cast
8
+ from typing import Any, Final, Optional, Union, cast
9
9
 
10
10
  from sqlglot import exp, maybe_parse, parse_one
11
11
 
@@ -33,9 +33,12 @@ def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> ex
33
33
  return column_input
34
34
 
35
35
  if has_expression_attr(column_input):
36
- attr_value = getattr(column_input, "_expression", None)
37
- if isinstance(attr_value, exp.Expression):
38
- return attr_value
36
+ try:
37
+ attr_value = column_input._expression
38
+ if isinstance(attr_value, exp.Expression):
39
+ return attr_value
40
+ except AttributeError:
41
+ pass
39
42
 
40
43
  return exp.maybe_parse(column_input) or exp.column(str(column_input))
41
44
 
@@ -102,14 +105,18 @@ def parse_condition_expression(
102
105
  if isinstance(condition_input, exp.Expression):
103
106
  return condition_input
104
107
 
105
- tuple_condition_parts = 2
108
+ tuple_condition_parts: Final[int] = 2
106
109
  if isinstance(condition_input, tuple) and len(condition_input) == tuple_condition_parts:
107
110
  column, value = condition_input
108
111
  column_expr = parse_column_expression(column)
109
112
  if value is None:
110
113
  return exp.Is(this=column_expr, expression=exp.null())
111
114
  if builder and has_parameter_builder(builder):
112
- _, param_name = builder.add_parameter(value)
115
+ from sqlspec.builder.mixins._where_clause import _extract_column_name
116
+
117
+ column_name = _extract_column_name(column)
118
+ param_name = builder._generate_unique_parameter_name(column_name)
119
+ _, param_name = builder.add_parameter(value, name=param_name)
113
120
  return exp.EQ(this=column_expr, expression=exp.Placeholder(this=param_name))
114
121
  if isinstance(value, str):
115
122
  return exp.EQ(this=column_expr, expression=exp.convert(value))
@@ -125,12 +132,9 @@ def parse_condition_expression(
125
132
  except Exception:
126
133
  try:
127
134
  parsed = exp.maybe_parse(condition_input) # type: ignore[var-annotated]
128
- if parsed:
129
- return parsed # type:ignore[no-any-return]
130
- except Exception: # noqa: S110
131
- pass
132
-
133
- return exp.condition(condition_input)
135
+ return parsed or exp.condition(condition_input)
136
+ except Exception:
137
+ return exp.condition(condition_input)
134
138
 
135
139
 
136
140
  __all__ = ("parse_column_expression", "parse_condition_expression", "parse_order_expression", "parse_table_expression")
@@ -5,8 +5,7 @@ with automatic parameter binding and validation.
5
5
  """
6
6
 
7
7
  import re
8
- from dataclasses import dataclass, field
9
- from typing import Any, Optional, Union
8
+ from typing import Any, Callable, Final, Optional, Union
10
9
 
11
10
  from sqlglot import exp
12
11
  from typing_extensions import Self
@@ -29,10 +28,9 @@ from sqlspec.core.result import SQLResult
29
28
  __all__ = ("Select",)
30
29
 
31
30
 
32
- TABLE_HINT_PATTERN = r"\b{}\b(\s+AS\s+\w+)?"
31
+ TABLE_HINT_PATTERN: Final[str] = r"\b{}\b(\s+AS\s+\w+)?"
33
32
 
34
33
 
35
- @dataclass
36
34
  class Select(
37
35
  QueryBuilder,
38
36
  WhereClauseMixin,
@@ -58,9 +56,8 @@ class Select(
58
56
  >>> result = driver.execute(builder)
59
57
  """
60
58
 
61
- _with_parts: "dict[str, Union[exp.CTE, Select]]" = field(default_factory=dict, init=False)
62
- _expression: Optional[exp.Expression] = field(default=None, init=False, repr=False, compare=False, hash=False)
63
- _hints: "list[dict[str, object]]" = field(default_factory=list, init=False, repr=False)
59
+ __slots__ = ("_hints", "_with_parts")
60
+ _expression: Optional[exp.Expression]
64
61
 
65
62
  def __init__(self, *columns: str, **kwargs: Any) -> None:
66
63
  """Initialize SELECT with optional columns.
@@ -75,11 +72,11 @@ class Select(
75
72
  """
76
73
  super().__init__(**kwargs)
77
74
 
78
- self._with_parts = {}
79
- self._expression = None
80
- self._hints = []
75
+ # Initialize Select-specific attributes
76
+ self._with_parts: dict[str, Union[exp.CTE, Select]] = {}
77
+ self._hints: list[dict[str, object]] = []
81
78
 
82
- self._create_base_expression()
79
+ self._initialize_expression()
83
80
 
84
81
  if columns:
85
82
  self.select(*columns)
@@ -93,7 +90,8 @@ class Select(
93
90
  """
94
91
  return SQLResult
95
92
 
96
- def _create_base_expression(self) -> "exp.Select":
93
+ def _create_base_expression(self) -> exp.Select:
94
+ """Create base SELECT expression."""
97
95
  if self._expression is None or not isinstance(self._expression, exp.Select):
98
96
  self._expression = exp.Select()
99
97
  return self._expression
@@ -131,44 +129,42 @@ class Select(
131
129
  if not self._hints:
132
130
  return safe_query
133
131
 
134
- modified_expr = self._expression.copy() if self._expression else self._create_base_expression()
132
+ modified_expr = self._expression or self._create_base_expression()
135
133
 
136
134
  if isinstance(modified_expr, exp.Select):
137
135
  statement_hints = [h["hint"] for h in self._hints if h.get("location") == "statement"]
138
136
  if statement_hints:
139
- hint_expressions = []
140
137
 
141
- def parse_hint(hint: Any) -> exp.Expression:
142
- """Parse a single hint."""
138
+ def parse_hint_safely(hint: Any) -> exp.Expression:
143
139
  try:
144
- hint_str = str(hint) # Ensure hint is a string
140
+ hint_str = str(hint)
145
141
  hint_expr: Optional[exp.Expression] = exp.maybe_parse(hint_str, dialect=self.dialect_name)
146
- if hint_expr:
147
- return hint_expr
148
- return exp.Anonymous(this=hint_str)
142
+ return hint_expr or exp.Anonymous(this=hint_str)
149
143
  except Exception:
150
144
  return exp.Anonymous(this=str(hint))
151
145
 
152
- hint_expressions = [parse_hint(hint) for hint in statement_hints]
146
+ hint_expressions: list[exp.Expression] = [parse_hint_safely(hint) for hint in statement_hints]
153
147
 
154
148
  if hint_expressions:
155
- hint_node = exp.Hint(expressions=hint_expressions)
156
- modified_expr.set("hint", hint_node)
149
+ modified_expr.set("hint", exp.Hint(expressions=hint_expressions))
157
150
 
158
151
  modified_sql = modified_expr.sql(dialect=self.dialect_name, pretty=True)
159
152
 
160
- table_hints = [h for h in self._hints if h.get("location") == "table" and h.get("table")]
161
- if table_hints:
162
- for th in table_hints:
163
- table = str(th["table"])
164
- hint = th["hint"]
153
+ for hint_dict in self._hints:
154
+ if hint_dict.get("location") == "table" and hint_dict.get("table"):
155
+ table = str(hint_dict["table"])
156
+ hint = str(hint_dict["hint"])
165
157
  pattern = TABLE_HINT_PATTERN.format(re.escape(table))
166
- compiled_pattern = re.compile(pattern, re.IGNORECASE)
167
158
 
168
- def replacement_func(match: re.Match[str]) -> str:
169
- alias_part = match.group(1) or ""
170
- return f"/*+ {hint} */ {table}{alias_part}" # noqa: B023
159
+ def make_replacement(hint_val: str, table_val: str) -> "Callable[[re.Match[str]], str]":
160
+ def replacement_func(match: re.Match[str]) -> str:
161
+ alias_part = match.group(1) or ""
162
+ return f"/*+ {hint_val} */ {table_val}{alias_part}"
171
163
 
172
- modified_sql = compiled_pattern.sub(replacement_func, modified_sql, count=1)
164
+ return replacement_func
165
+
166
+ modified_sql = re.sub(
167
+ pattern, make_replacement(hint, table), modified_sql, count=1, flags=re.IGNORECASE
168
+ )
173
169
 
174
170
  return SafeQuery(sql=modified_sql, parameters=safe_query.parameters, dialect=safe_query.dialect)
@@ -4,7 +4,6 @@ This module provides a fluent interface for building SQL queries safely,
4
4
  with automatic parameter binding and validation.
5
5
  """
6
6
 
7
- from dataclasses import dataclass
8
7
  from typing import TYPE_CHECKING, Any, Optional, Union
9
8
 
10
9
  from sqlglot import exp
@@ -27,7 +26,6 @@ if TYPE_CHECKING:
27
26
  __all__ = ("Update",)
28
27
 
29
28
 
30
- @dataclass(unsafe_hash=True)
31
29
  class Update(
32
30
  QueryBuilder,
33
31
  WhereClauseMixin,
@@ -72,6 +70,9 @@ class Update(
72
70
  ```
73
71
  """
74
72
 
73
+ __slots__ = ("_table",)
74
+ _expression: Optional[exp.Expression]
75
+
75
76
  def __init__(self, table: Optional[str] = None, **kwargs: Any) -> None:
76
77
  """Initialize UPDATE with optional table.
77
78
 
@@ -80,6 +81,7 @@ class Update(
80
81
  **kwargs: Additional QueryBuilder arguments
81
82
  """
82
83
  super().__init__(**kwargs)
84
+ self._initialize_expression()
83
85
 
84
86
  if table:
85
87
  self.table(table)
@@ -2,6 +2,7 @@
2
2
 
3
3
  from typing import Any, Optional, Union
4
4
 
5
+ from mypy_extensions import trait
5
6
  from sqlglot import exp
6
7
  from typing_extensions import Self
7
8
 
@@ -10,10 +11,21 @@ from sqlspec.exceptions import SQLBuilderError
10
11
  __all__ = ("CommonTableExpressionMixin", "SetOperationMixin")
11
12
 
12
13
 
14
+ @trait
13
15
  class CommonTableExpressionMixin:
14
16
  """Mixin providing WITH clause (Common Table Expressions) support for SQL builders."""
15
17
 
16
- _expression: Optional[exp.Expression] = None
18
+ __slots__ = ()
19
+ # Type annotation for PyRight - this will be provided by the base class
20
+ _expression: Optional[exp.Expression]
21
+
22
+ _with_ctes: Any # Provided by QueryBuilder
23
+ dialect: Any # Provided by QueryBuilder
24
+
25
+ def add_parameter(self, value: Any, name: Optional[str] = None) -> tuple[Any, str]:
26
+ """Add parameter - provided by QueryBuilder."""
27
+ msg = "Method must be provided by QueryBuilder subclass"
28
+ raise NotImplementedError(msg)
17
29
 
18
30
  def with_(
19
31
  self, name: str, query: Union[Any, str], recursive: bool = False, columns: Optional[list[str]] = None
@@ -42,22 +54,22 @@ class CommonTableExpressionMixin:
42
54
 
43
55
  cte_expr: Optional[exp.Expression] = None
44
56
  if isinstance(query, str):
45
- cte_expr = exp.maybe_parse(query, dialect=self.dialect) # type: ignore[attr-defined]
57
+ cte_expr = exp.maybe_parse(query, dialect=self.dialect)
46
58
  elif isinstance(query, exp.Expression):
47
59
  cte_expr = query
48
60
  else:
49
- built_query = query.to_statement() # pyright: ignore
61
+ built_query = query.to_statement()
50
62
  cte_sql = built_query.sql
51
- cte_expr = exp.maybe_parse(cte_sql, dialect=self.dialect) # type: ignore[attr-defined]
63
+ cte_expr = exp.maybe_parse(cte_sql, dialect=self.dialect)
52
64
 
53
65
  parameters = built_query.parameters
54
66
  if parameters:
55
67
  if isinstance(parameters, dict):
56
68
  for param_name, param_value in parameters.items():
57
- self.add_parameter(param_value, name=param_name) # type: ignore[attr-defined]
69
+ self.add_parameter(param_value, name=param_name)
58
70
  elif isinstance(parameters, (list, tuple)):
59
71
  for param_value in parameters:
60
- self.add_parameter(param_value) # type: ignore[attr-defined]
72
+ self.add_parameter(param_value)
61
73
 
62
74
  if not cte_expr:
63
75
  msg = f"Could not parse CTE query: {query}"
@@ -68,29 +80,42 @@ class CommonTableExpressionMixin:
68
80
  else:
69
81
  cte_alias_expr = exp.alias_(cte_expr, name)
70
82
 
71
- existing_with = self._expression.args.get("with") # pyright: ignore
83
+ existing_with = self._expression.args.get("with")
72
84
  if existing_with:
73
85
  existing_with.expressions.append(cte_alias_expr)
74
86
  if recursive:
75
87
  existing_with.set("recursive", recursive)
76
88
  else:
77
- self._expression = self._expression.with_(cte_alias_expr, as_=name, copy=False) # type: ignore[union-attr]
78
- if recursive:
79
- with_clause = self._expression.find(exp.With)
80
- if with_clause:
81
- with_clause.set("recursive", recursive)
82
- self._with_ctes[name] = exp.CTE(this=cte_expr, alias=exp.to_table(name)) # type: ignore[attr-defined]
89
+ # Only SELECT, INSERT, UPDATE support WITH clauses
90
+ if hasattr(self._expression, "with_") and isinstance(
91
+ self._expression, (exp.Select, exp.Insert, exp.Update)
92
+ ):
93
+ self._expression = self._expression.with_(cte_alias_expr, as_=name, copy=False)
94
+ if recursive:
95
+ with_clause = self._expression.find(exp.With)
96
+ if with_clause:
97
+ with_clause.set("recursive", recursive)
98
+ self._with_ctes[name] = exp.CTE(this=cte_expr, alias=exp.to_table(name))
83
99
 
84
100
  return self
85
101
 
86
102
 
103
+ @trait
87
104
  class SetOperationMixin:
88
105
  """Mixin providing set operations (UNION, INTERSECT, EXCEPT) for SELECT builders."""
89
106
 
90
- _expression: Any = None
107
+ __slots__ = ()
108
+ # Type annotation for PyRight - this will be provided by the base class
109
+ _expression: Optional[exp.Expression]
110
+
91
111
  _parameters: dict[str, Any]
92
112
  dialect: Any = None
93
113
 
114
+ def build(self) -> Any:
115
+ """Build the query - provided by QueryBuilder."""
116
+ msg = "Method must be provided by QueryBuilder subclass"
117
+ raise NotImplementedError(msg)
118
+
94
119
  def union(self, other: Any, all_: bool = False) -> Self:
95
120
  """Combine this query with another using UNION.
96
121
 
@@ -104,7 +129,7 @@ class SetOperationMixin:
104
129
  Returns:
105
130
  The new builder instance for the union query.
106
131
  """
107
- left_query = self.build() # type: ignore[attr-defined]
132
+ left_query = self.build()
108
133
  right_query = other.build()
109
134
  left_expr: Optional[exp.Expression] = exp.maybe_parse(left_query.sql, dialect=self.dialect)
110
135
  right_expr: Optional[exp.Expression] = exp.maybe_parse(right_query.sql, dialect=self.dialect)
@@ -124,9 +149,11 @@ class SetOperationMixin:
124
149
  counter += 1
125
150
  new_param_name = f"{param_name}_right_{counter}"
126
151
 
127
- def rename_parameter(node: exp.Expression) -> exp.Expression:
128
- if isinstance(node, exp.Placeholder) and node.name == param_name: # noqa: B023
129
- return exp.Placeholder(this=new_param_name) # noqa: B023
152
+ def rename_parameter(
153
+ node: exp.Expression, old_name: str = param_name, new_name: str = new_param_name
154
+ ) -> exp.Expression:
155
+ if isinstance(node, exp.Placeholder) and node.name == old_name:
156
+ return exp.Placeholder(this=new_name)
130
157
  return node
131
158
 
132
159
  right_expr = right_expr.transform(rename_parameter)
@@ -150,7 +177,7 @@ class SetOperationMixin:
150
177
  Returns:
151
178
  The new builder instance for the intersect query.
152
179
  """
153
- left_query = self.build() # type: ignore[attr-defined]
180
+ left_query = self.build()
154
181
  right_query = other.build()
155
182
  left_expr: Optional[exp.Expression] = exp.maybe_parse(left_query.sql, dialect=self.dialect)
156
183
  right_expr: Optional[exp.Expression] = exp.maybe_parse(right_query.sql, dialect=self.dialect)
@@ -178,7 +205,7 @@ class SetOperationMixin:
178
205
  Returns:
179
206
  The new builder instance for the except query.
180
207
  """
181
- left_query = self.build() # type: ignore[attr-defined]
208
+ left_query = self.build()
182
209
  right_query = other.build()
183
210
  left_expr: Optional[exp.Expression] = exp.maybe_parse(left_query.sql, dialect=self.dialect)
184
211
  right_expr: Optional[exp.Expression] = exp.maybe_parse(right_query.sql, dialect=self.dialect)
@@ -2,6 +2,7 @@
2
2
 
3
3
  from typing import Optional
4
4
 
5
+ from mypy_extensions import trait
5
6
  from sqlglot import exp
6
7
  from typing_extensions import Self
7
8
 
@@ -10,10 +11,14 @@ from sqlspec.exceptions import SQLBuilderError
10
11
  __all__ = ("DeleteFromClauseMixin",)
11
12
 
12
13
 
14
+ @trait
13
15
  class DeleteFromClauseMixin:
14
16
  """Mixin providing FROM clause for DELETE builders."""
15
17
 
16
- _expression: Optional[exp.Expression] = None
18
+ __slots__ = ()
19
+
20
+ # Type annotation for PyRight - this will be provided by the base class
21
+ _expression: Optional[exp.Expression]
17
22
 
18
23
  def from_(self, table: str) -> Self:
19
24
  """Set the target table for the DELETE statement.