sqlspec-0.12.1-py3-none-any.whl → sqlspec-0.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's release page for more details.

Files changed (113):
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +116 -141
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +231 -181
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +132 -124
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +34 -30
  19. sqlspec/adapters/psycopg/driver.py +342 -214
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +150 -104
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +149 -216
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +31 -118
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +70 -23
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +102 -65
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +885 -379
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +82 -35
  91. sqlspec/storage/backends/obstore.py +66 -49
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -170
  110. sqlspec-0.12.1.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
25
25
  from collections.abc import Mapping, Sequence
26
26
 
27
27
 
28
- __all__ = ("InsertBuilder",)
28
+ __all__ = ("Insert",)
29
29
 
30
30
  ERR_MSG_TABLE_NOT_SET = "The target table must be set using .into() before adding values."
31
31
  ERR_MSG_VALUES_COLUMNS_MISMATCH = (
@@ -36,9 +36,7 @@ ERR_MSG_EXPRESSION_NOT_INITIALIZED = "Internal error: base expression not initia
36
36
 
37
37
 
38
38
  @dataclass(unsafe_hash=True)
39
- class InsertBuilder(
40
- QueryBuilder[RowT], ReturningClauseMixin, InsertValuesMixin, InsertFromSelectMixin, InsertIntoClauseMixin
41
- ):
39
+ class Insert(QueryBuilder[RowT], ReturningClauseMixin, InsertValuesMixin, InsertFromSelectMixin, InsertIntoClauseMixin):
42
40
  """Builder for INSERT statements.
43
41
 
44
42
  This builder facilitates the construction of SQL INSERT queries
@@ -48,15 +46,20 @@ class InsertBuilder(
48
46
  ```python
49
47
  # Basic INSERT with values
50
48
  insert_query = (
51
- InsertBuilder()
49
+ Insert()
52
50
  .into("users")
53
51
  .columns("name", "email", "age")
54
52
  .values("John Doe", "john@example.com", 30)
55
53
  )
56
54
 
55
+ # Even more concise with constructor
56
+ insert_query = Insert("users").values(
57
+ {"name": "John", "age": 30}
58
+ )
59
+
57
60
  # Multi-row INSERT
58
61
  insert_query = (
59
- InsertBuilder()
62
+ Insert()
60
63
  .into("users")
61
64
  .columns("name", "email")
62
65
  .values("John", "john@example.com")
@@ -65,7 +68,7 @@ class InsertBuilder(
65
68
 
66
69
  # INSERT from dictionary
67
70
  insert_query = (
68
- InsertBuilder()
71
+ Insert()
69
72
  .into("users")
70
73
  .values_from_dict(
71
74
  {"name": "John", "email": "john@example.com"}
@@ -74,10 +77,10 @@ class InsertBuilder(
74
77
 
75
78
  # INSERT from SELECT
76
79
  insert_query = (
77
- InsertBuilder()
80
+ Insert()
78
81
  .into("users_backup")
79
82
  .from_select(
80
- SelectBuilder()
83
+ Select()
81
84
  .select("name", "email")
82
85
  .from_("users")
83
86
  .where("active = true")
@@ -90,6 +93,23 @@ class InsertBuilder(
90
93
  _columns: list[str] = field(default_factory=list, init=False)
91
94
  _values_added_count: int = field(default=0, init=False)
92
95
 
96
+ def __init__(self, table: Optional[str] = None, **kwargs: Any) -> None:
97
+ """Initialize INSERT with optional table.
98
+
99
+ Args:
100
+ table: Target table name
101
+ **kwargs: Additional QueryBuilder arguments
102
+ """
103
+ super().__init__(**kwargs)
104
+
105
+ # Initialize fields from dataclass
106
+ self._table = None
107
+ self._columns = []
108
+ self._values_added_count = 0
109
+
110
+ if table:
111
+ self.into(table)
112
+
93
113
  def _create_base_expression(self) -> exp.Insert:
94
114
  """Create a base INSERT expression.
95
115
 
@@ -165,7 +185,6 @@ class InsertBuilder(
165
185
  else:
166
186
  # This case should ideally not be reached if logic is correct:
167
187
  # means _values_added_count > 0 but expression is not exp.Values.
168
- # Fallback to creating a new Values node, though this might indicate an issue.
169
188
  new_values_node = exp.Values(expressions=[exp.Tuple(expressions=list(value_placeholders))])
170
189
  insert_expr.set("expression", new_values_node)
171
190
 
@@ -191,14 +210,12 @@ class InsertBuilder(
191
210
  raise SQLBuilderError(ERR_MSG_TABLE_NOT_SET)
192
211
 
193
212
  if not self._columns:
194
- # Set columns from dictionary keys if not already set
195
213
  self.columns(*data.keys())
196
214
  elif set(self._columns) != set(data.keys()):
197
215
  # Verify that dictionary keys match existing columns
198
216
  msg = f"Dictionary keys {set(data.keys())} do not match existing columns {set(self._columns)}."
199
217
  raise SQLBuilderError(msg)
200
218
 
201
- # Add values in the same order as columns
202
219
  return self.values(*[data[col] for col in self._columns])
203
220
 
204
221
  def values_from_dicts(self, data: "Sequence[Mapping[str, Any]]") -> "Self":
@@ -219,12 +236,10 @@ class InsertBuilder(
219
236
  if not data:
220
237
  return self
221
238
 
222
- # Use the first dictionary to establish columns
223
239
  first_dict = data[0]
224
240
  if not self._columns:
225
241
  self.columns(*first_dict.keys())
226
242
 
227
- # Validate that all dictionaries have the same keys
228
243
  expected_keys = set(self._columns)
229
244
  for i, row_dict in enumerate(data):
230
245
  if set(row_dict.keys()) != expected_keys:
@@ -234,7 +249,6 @@ class InsertBuilder(
234
249
  )
235
250
  raise SQLBuilderError(msg)
236
251
 
237
- # Add each row
238
252
  for row_dict in data:
239
253
  self.values(*[row_dict[col] for col in self._columns])
240
254
 
@@ -20,11 +20,11 @@ from sqlspec.statement.builder.mixins import (
20
20
  from sqlspec.statement.result import SQLResult
21
21
  from sqlspec.typing import RowT
22
22
 
23
- __all__ = ("MergeBuilder",)
23
+ __all__ = ("Merge",)
24
24
 
25
25
 
26
26
  @dataclass(unsafe_hash=True)
27
- class MergeBuilder(
27
+ class Merge(
28
28
  QueryBuilder[RowT],
29
29
  MergeUsingClauseMixin,
30
30
  MergeOnClauseMixin,
@@ -42,7 +42,7 @@ class MergeBuilder(
42
42
  ```python
43
43
  # Basic MERGE statement
44
44
  merge_query = (
45
- MergeBuilder()
45
+ Merge()
46
46
  .into("target_table")
47
47
  .using("source_table", "src")
48
48
  .on("target_table.id = src.id")
@@ -64,7 +64,7 @@ class MergeBuilder(
64
64
  )
65
65
 
66
66
  merge_query = (
67
- MergeBuilder()
67
+ Merge()
68
68
  .into("users")
69
69
  .using(source_query, "src")
70
70
  .on("users.email = src.email")
@@ -4,7 +4,7 @@ from sqlglot import exp
4
4
  from typing_extensions import Self
5
5
 
6
6
  if TYPE_CHECKING:
7
- from sqlspec.statement.builder.protocols import SelectBuilderProtocol
7
+ from sqlspec.protocols import SelectBuilderProtocol
8
8
 
9
9
  __all__ = ("AggregateFunctionsMixin",)
10
10
 
@@ -112,40 +112,139 @@ class AggregateFunctionsMixin:
112
112
  select_expr = exp.alias_(array_agg_expr, alias) if alias else array_agg_expr
113
113
  return cast("Self", builder.select(select_expr))
114
114
 
115
- def bool_and(self, column: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
116
- """Add BOOL_AND aggregate function to SELECT clause (PostgreSQL, DuckDB, etc).
115
+ def count_distinct(self, column: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
116
+ """Add COUNT(DISTINCT column) to SELECT clause.
117
117
 
118
118
  Args:
119
- column: The boolean column to aggregate.
119
+ column: The column to count distinct values of.
120
+ alias: Optional alias for the count.
121
+
122
+ Returns:
123
+ The current builder instance for method chaining.
124
+ """
125
+ builder = cast("SelectBuilderProtocol", self)
126
+ col_expr = exp.column(column) if isinstance(column, str) else column
127
+ count_expr = exp.Count(this=exp.Distinct(expressions=[col_expr]))
128
+ select_expr = exp.alias_(count_expr, alias) if alias else count_expr
129
+ return cast("Self", builder.select(select_expr))
130
+
131
+ def stddev(self, column: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
132
+ """Add STDDEV aggregate function to SELECT clause.
133
+
134
+ Args:
135
+ column: The column to calculate standard deviation of.
120
136
  alias: Optional alias for the result.
121
137
 
122
138
  Returns:
123
139
  The current builder instance for method chaining.
140
+ """
141
+ builder = cast("SelectBuilderProtocol", self)
142
+ col_expr = exp.column(column) if isinstance(column, str) else column
143
+ stddev_expr = exp.Stddev(this=col_expr)
144
+ select_expr = exp.alias_(stddev_expr, alias) if alias else stddev_expr
145
+ return cast("Self", builder.select(select_expr))
124
146
 
125
- Note:
126
- Uses exp.Anonymous for BOOL_AND. Not all dialects support this function.
147
+ def stddev_pop(self, column: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
148
+ """Add STDDEV_POP aggregate function to SELECT clause.
149
+
150
+ Args:
151
+ column: The column to calculate population standard deviation of.
152
+ alias: Optional alias for the result.
153
+
154
+ Returns:
155
+ The current builder instance for method chaining.
127
156
  """
128
157
  builder = cast("SelectBuilderProtocol", self)
129
158
  col_expr = exp.column(column) if isinstance(column, str) else column
130
- bool_and_expr = exp.Anonymous(this="BOOL_AND", expressions=[col_expr])
131
- select_expr = exp.alias_(bool_and_expr, alias) if alias else bool_and_expr
159
+ stddev_pop_expr = exp.StddevPop(this=col_expr)
160
+ select_expr = exp.alias_(stddev_pop_expr, alias) if alias else stddev_pop_expr
132
161
  return cast("Self", builder.select(select_expr))
133
162
 
134
- def bool_or(self, column: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
135
- """Add BOOL_OR aggregate function to SELECT clause (PostgreSQL, DuckDB, etc).
163
+ def stddev_samp(self, column: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
164
+ """Add STDDEV_SAMP aggregate function to SELECT clause.
136
165
 
137
166
  Args:
138
- column: The boolean column to aggregate.
167
+ column: The column to calculate sample standard deviation of.
168
+ alias: Optional alias for the result.
169
+
170
+ Returns:
171
+ The current builder instance for method chaining.
172
+ """
173
+ builder = cast("SelectBuilderProtocol", self)
174
+ col_expr = exp.column(column) if isinstance(column, str) else column
175
+ stddev_samp_expr = exp.StddevSamp(this=col_expr)
176
+ select_expr = exp.alias_(stddev_samp_expr, alias) if alias else stddev_samp_expr
177
+ return cast("Self", builder.select(select_expr))
178
+
179
+ def variance(self, column: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
180
+ """Add VARIANCE aggregate function to SELECT clause.
181
+
182
+ Args:
183
+ column: The column to calculate variance of.
184
+ alias: Optional alias for the result.
185
+
186
+ Returns:
187
+ The current builder instance for method chaining.
188
+ """
189
+ builder = cast("SelectBuilderProtocol", self)
190
+ col_expr = exp.column(column) if isinstance(column, str) else column
191
+ variance_expr = exp.Variance(this=col_expr)
192
+ select_expr = exp.alias_(variance_expr, alias) if alias else variance_expr
193
+ return cast("Self", builder.select(select_expr))
194
+
195
+ def var_pop(self, column: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
196
+ """Add VAR_POP aggregate function to SELECT clause.
197
+
198
+ Args:
199
+ column: The column to calculate population variance of.
200
+ alias: Optional alias for the result.
201
+
202
+ Returns:
203
+ The current builder instance for method chaining.
204
+ """
205
+ builder = cast("SelectBuilderProtocol", self)
206
+ col_expr = exp.column(column) if isinstance(column, str) else column
207
+ var_pop_expr = exp.VariancePop(this=col_expr)
208
+ select_expr = exp.alias_(var_pop_expr, alias) if alias else var_pop_expr
209
+ return cast("Self", builder.select(select_expr))
210
+
211
+ def string_agg(self, column: Union[str, exp.Expression], separator: str = ",", alias: Optional[str] = None) -> Self:
212
+ """Add STRING_AGG aggregate function to SELECT clause.
213
+
214
+ Args:
215
+ column: The column to aggregate into a string.
216
+ separator: The separator between values (default is comma).
139
217
  alias: Optional alias for the result.
140
218
 
141
219
  Returns:
142
220
  The current builder instance for method chaining.
143
221
 
144
222
  Note:
145
- Uses exp.Anonymous for BOOL_OR. Not all dialects support this function.
223
+ Different databases have different names for this function:
224
+ - PostgreSQL: STRING_AGG
225
+ - MySQL: GROUP_CONCAT
226
+ - SQLite: GROUP_CONCAT
227
+ SQLGlot will handle the translation.
228
+ """
229
+ builder = cast("SelectBuilderProtocol", self)
230
+ col_expr = exp.column(column) if isinstance(column, str) else column
231
+ # Use GroupConcat which SQLGlot can translate to STRING_AGG for Postgres
232
+ string_agg_expr = exp.GroupConcat(this=col_expr, separator=exp.Literal.string(separator))
233
+ select_expr = exp.alias_(string_agg_expr, alias) if alias else string_agg_expr
234
+ return cast("Self", builder.select(select_expr))
235
+
236
+ def json_agg(self, column: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
237
+ """Add JSON_AGG aggregate function to SELECT clause.
238
+
239
+ Args:
240
+ column: The column to aggregate into a JSON array.
241
+ alias: Optional alias for the result.
242
+
243
+ Returns:
244
+ The current builder instance for method chaining.
146
245
  """
147
246
  builder = cast("SelectBuilderProtocol", self)
148
247
  col_expr = exp.column(column) if isinstance(column, str) else column
149
- bool_or_expr = exp.Anonymous(this="BOOL_OR", expressions=[col_expr])
150
- select_expr = exp.alias_(bool_or_expr, alias) if alias else bool_or_expr
248
+ json_agg_expr = exp.JSONArrayAgg(this=col_expr)
249
+ select_expr = exp.alias_(json_agg_expr, alias) if alias else json_agg_expr
151
250
  return cast("Self", builder.select(select_expr))
@@ -61,7 +61,6 @@ class CommonTableExpressionMixin:
61
61
  msg = f"Could not parse CTE query: {query}"
62
62
  raise SQLBuilderError(msg)
63
63
 
64
- # Create a proper CTE with table alias
65
64
  if columns:
66
65
  # CTE with explicit column list: name(col1, col2, ...)
67
66
  cte_alias_expr = exp.alias_(cte_expr, name, table=[exp.to_identifier(col) for col in columns])
@@ -26,7 +26,7 @@ class DeleteFromClauseMixin:
26
26
  self._expression = exp.Delete()
27
27
  if not isinstance(self._expression, exp.Delete):
28
28
  current_expr_type = type(self._expression).__name__
29
- msg = f"Base expression for DeleteBuilder is {current_expr_type}, expected Delete."
29
+ msg = f"Base expression for Delete is {current_expr_type}, expected Delete."
30
30
  raise SQLBuilderError(msg)
31
31
 
32
32
  setattr(self, "_table", table)
@@ -5,10 +5,10 @@ from typing_extensions import Self
5
5
 
6
6
  from sqlspec.exceptions import SQLBuilderError
7
7
  from sqlspec.statement.builder._parsing_utils import parse_table_expression
8
- from sqlspec.typing import is_expression
8
+ from sqlspec.utils.type_guards import has_query_builder_parameters, is_expression
9
9
 
10
10
  if TYPE_CHECKING:
11
- from sqlspec.statement.builder.protocols import BuilderProtocol
11
+ from sqlspec.protocols import SQLBuilderProtocol
12
12
 
13
13
  __all__ = ("FromClauseMixin",)
14
14
 
@@ -29,7 +29,7 @@ class FromClauseMixin:
29
29
  Returns:
30
30
  The current builder instance for method chaining.
31
31
  """
32
- builder = cast("BuilderProtocol", self)
32
+ builder = cast("SQLBuilderProtocol", self)
33
33
  if builder._expression is None:
34
34
  builder._expression = exp.Select()
35
35
  if not isinstance(builder._expression, exp.Select):
@@ -41,16 +41,18 @@ class FromClauseMixin:
41
41
  elif is_expression(table):
42
42
  # Direct sqlglot expression - use as is
43
43
  from_expr = exp.alias_(table, alias) if alias else table
44
- elif hasattr(table, "build"):
44
+ elif has_query_builder_parameters(table):
45
45
  # Query builder with build() method
46
- subquery = table.build() # pyright: ignore
47
- subquery_exp = exp.paren(exp.maybe_parse(subquery.sql, dialect=getattr(builder, "dialect", None)))
46
+ subquery = table.build()
47
+ sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
48
+ subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
48
49
  from_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
49
50
  current_params = getattr(builder, "_parameters", None)
50
51
  merged_params = getattr(type(builder), "ParameterConverter", None)
51
- if merged_params:
52
+ if merged_params and hasattr(subquery, "parameters"):
53
+ subquery_params = getattr(subquery, "parameters", {})
52
54
  merged_params = merged_params.merge_parameters(
53
- parameters=subquery.parameters,
55
+ parameters=subquery_params,
54
56
  args=current_params if isinstance(current_params, list) else None,
55
57
  kwargs=current_params if isinstance(current_params, dict) else {},
56
58
  )
@@ -106,7 +106,6 @@ class GroupByClauseMixin:
106
106
  for column_set in column_sets:
107
107
  if isinstance(column_set, (tuple, list)):
108
108
  if len(column_set) == 0:
109
- # Empty set for grand total
110
109
  set_expressions.append(exp.Tuple(expressions=[]))
111
110
  else:
112
111
  columns = [exp.column(col) for col in column_set]
@@ -38,7 +38,6 @@ class InsertFromSelectMixin:
38
38
  if subquery_params:
39
39
  for p_name, p_value in subquery_params.items():
40
40
  self.add_parameter(p_value, name=p_name) # type: ignore[attr-defined]
41
- # Set the SELECT expression as the source
42
41
  select_expr = getattr(select_builder, "_expression", None)
43
42
  if select_expr and isinstance(select_expr, exp.Select):
44
43
  self._expression.set("expression", select_expr.copy())
@@ -39,7 +39,6 @@ class InsertValuesMixin:
39
39
  if not isinstance(self._expression, exp.Insert):
40
40
  msg = "Cannot add values to a non-INSERT expression."
41
41
  raise SQLBuilderError(msg)
42
- # Validate value count if _columns is present and non-empty
43
42
  if (
44
43
  hasattr(self, "_columns") and getattr(self, "_columns", []) and len(values) != len(self._columns) # pyright: ignore
45
44
  ):
@@ -50,7 +49,6 @@ class InsertValuesMixin:
50
49
  if isinstance(v, exp.Expression):
51
50
  row_exprs.append(v)
52
51
  else:
53
- # Add as parameter
54
52
  _, param_name = self.add_parameter(v) # type: ignore[attr-defined]
55
53
  row_exprs.append(exp.var(param_name))
56
54
  values_expr = exp.Values(expressions=[row_exprs])
@@ -5,9 +5,10 @@ from typing_extensions import Self
5
5
 
6
6
  from sqlspec.exceptions import SQLBuilderError
7
7
  from sqlspec.statement.builder._parsing_utils import parse_table_expression
8
+ from sqlspec.utils.type_guards import has_query_builder_parameters
8
9
 
9
10
  if TYPE_CHECKING:
10
- from sqlspec.statement.builder.protocols import BuilderProtocol
11
+ from sqlspec.protocols import SQLBuilderProtocol
11
12
 
12
13
  __all__ = ("JoinClauseMixin",)
13
14
 
@@ -22,7 +23,7 @@ class JoinClauseMixin:
22
23
  alias: Optional[str] = None,
23
24
  join_type: str = "INNER",
24
25
  ) -> Self:
25
- builder = cast("BuilderProtocol", self)
26
+ builder = cast("SQLBuilderProtocol", self)
26
27
  if builder._expression is None:
27
28
  builder._expression = exp.Select()
28
29
  if not isinstance(builder._expression, exp.Select):
@@ -31,16 +32,19 @@ class JoinClauseMixin:
31
32
  table_expr: exp.Expression
32
33
  if isinstance(table, str):
33
34
  table_expr = parse_table_expression(table, alias)
34
- elif hasattr(table, "build"):
35
- # Handle builder objects with build() method
35
+ elif has_query_builder_parameters(table):
36
36
  # Work directly with AST when possible to avoid string parsing
37
37
  if hasattr(table, "_expression") and getattr(table, "_expression", None) is not None:
38
- subquery_exp = exp.paren(table._expression.copy()) # pyright: ignore
38
+ table_expr_value = getattr(table, "_expression", None)
39
+ if table_expr_value is not None:
40
+ subquery_exp = exp.paren(table_expr_value.copy()) # pyright: ignore
41
+ else:
42
+ subquery_exp = exp.paren(exp.Anonymous(this=""))
39
43
  table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
40
44
  else:
41
- # Fallback to string parsing
42
45
  subquery = table.build() # pyright: ignore
43
- subquery_exp = exp.paren(exp.maybe_parse(subquery.sql, dialect=getattr(builder, "dialect", None)))
46
+ sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
47
+ subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
44
48
  table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
45
49
  # Parameter merging logic can be added here if needed
46
50
  else:
@@ -84,7 +88,7 @@ class JoinClauseMixin:
84
88
  return self.join(table, on, alias, "FULL")
85
89
 
86
90
  def cross_join(self, table: Union[str, exp.Expression, Any], alias: Optional[str] = None) -> Self:
87
- builder = cast("BuilderProtocol", self)
91
+ builder = cast("SQLBuilderProtocol", self)
88
92
  if builder._expression is None:
89
93
  builder._expression = exp.Select()
90
94
  if not isinstance(builder._expression, exp.Select):
@@ -93,15 +97,18 @@ class JoinClauseMixin:
93
97
  table_expr: exp.Expression
94
98
  if isinstance(table, str):
95
99
  table_expr = parse_table_expression(table, alias)
96
- elif hasattr(table, "build"):
97
- # Handle builder objects with build() method
100
+ elif has_query_builder_parameters(table):
98
101
  if hasattr(table, "_expression") and getattr(table, "_expression", None) is not None:
99
- subquery_exp = exp.paren(table._expression.copy()) # pyright: ignore
102
+ table_expr_value = getattr(table, "_expression", None)
103
+ if table_expr_value is not None:
104
+ subquery_exp = exp.paren(table_expr_value.copy()) # pyright: ignore
105
+ else:
106
+ subquery_exp = exp.paren(exp.Anonymous(this=""))
100
107
  table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
101
108
  else:
102
- # Fallback to string parsing
103
109
  subquery = table.build() # pyright: ignore
104
- subquery_exp = exp.paren(exp.maybe_parse(subquery.sql, dialect=getattr(builder, "dialect", None)))
110
+ sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
111
+ subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
105
112
  table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
106
113
  else:
107
114
  table_expr = table
@@ -4,7 +4,7 @@ from sqlglot import exp
4
4
  from typing_extensions import Self
5
5
 
6
6
  if TYPE_CHECKING:
7
- from sqlspec.statement.builder.protocols import BuilderProtocol
7
+ from sqlspec.protocols import SQLBuilderProtocol
8
8
 
9
9
  from sqlspec.exceptions import SQLBuilderError
10
10
 
@@ -26,7 +26,7 @@ class LimitOffsetClauseMixin:
26
26
  Returns:
27
27
  The current builder instance for method chaining.
28
28
  """
29
- builder = cast("BuilderProtocol", self)
29
+ builder = cast("SQLBuilderProtocol", self)
30
30
  if not isinstance(builder._expression, exp.Select):
31
31
  msg = "LIMIT is only supported for SELECT statements."
32
32
  raise SQLBuilderError(msg)
@@ -45,7 +45,7 @@ class LimitOffsetClauseMixin:
45
45
  Returns:
46
46
  The current builder instance for method chaining.
47
47
  """
48
- builder = cast("BuilderProtocol", self)
48
+ builder = cast("SQLBuilderProtocol", self)
49
49
  if not isinstance(builder._expression, exp.Select):
50
50
  msg = "OFFSET is only supported for SELECT statements."
51
51
  raise SQLBuilderError(msg)
@@ -4,6 +4,7 @@ from sqlglot import exp
4
4
  from typing_extensions import Self
5
5
 
6
6
  from sqlspec.exceptions import SQLBuilderError
7
+ from sqlspec.utils.type_guards import has_query_builder_parameters
7
8
 
8
9
  __all__ = (
9
10
  "MergeIntoClauseMixin",
@@ -66,9 +67,9 @@ class MergeUsingClauseMixin:
66
67
  source_expr: exp.Expression
67
68
  if isinstance(source, str):
68
69
  source_expr = exp.to_table(source, alias=alias)
69
- elif hasattr(source, "_parameters") and hasattr(source, "_expression"):
70
+ elif has_query_builder_parameters(source) and hasattr(source, "_expression"):
70
71
  # Merge parameters from the SELECT builder or other builder
71
- subquery_builder_params = getattr(source, "_parameters", {})
72
+ subquery_builder_params = source.parameters
72
73
  if subquery_builder_params:
73
74
  for p_name, p_value in subquery_builder_params.items():
74
75
  self.add_parameter(p_value, name=p_name) # type: ignore[attr-defined]
@@ -145,13 +146,11 @@ class MergeMatchedClauseMixin:
145
146
  if not isinstance(self._expression, exp.Merge):
146
147
  self._expression = exp.Merge(this=None, using=None, on=None, whens=exp.Whens(expressions=[]))
147
148
 
148
- # Get or create the whens object
149
149
  whens = self._expression.args.get("whens")
150
150
  if not whens:
151
151
  whens = exp.Whens(expressions=[])
152
152
  self._expression.set("whens", whens)
153
153
 
154
- # Add the when clause to the whens expressions using SQLGlot's append method
155
154
  whens.append("expressions", when_clause)
156
155
 
157
156
  def when_matched_then_update(
@@ -7,7 +7,7 @@ from sqlspec.exceptions import SQLBuilderError
7
7
  from sqlspec.statement.builder._parsing_utils import parse_order_expression
8
8
 
9
9
  if TYPE_CHECKING:
10
- from sqlspec.statement.builder.protocols import BuilderProtocol
10
+ from sqlspec.protocols import SQLBuilderProtocol
11
11
 
12
12
  __all__ = ("OrderByClauseMixin",)
13
13
 
@@ -28,7 +28,7 @@ class OrderByClauseMixin:
28
28
  Returns:
29
29
  The current builder instance for method chaining.
30
30
  """
31
- builder = cast("BuilderProtocol", self)
31
+ builder = cast("SQLBuilderProtocol", self)
32
32
  if not isinstance(builder._expression, exp.Select):
33
33
  msg = "ORDER BY is only supported for SELECT statements."
34
34
  raise SQLBuilderError(msg)
@@ -5,13 +5,13 @@ from sqlglot import exp
5
5
  if TYPE_CHECKING:
6
6
  from sqlglot.dialects.dialect import DialectType
7
7
 
8
- from sqlspec.statement.builder.select import SelectBuilder
8
+ from sqlspec.statement.builder.select import Select
9
9
 
10
10
  __all__ = ("PivotClauseMixin",)
11
11
 
12
12
 
13
13
  class PivotClauseMixin:
14
- """Mixin class to add PIVOT functionality to a SelectBuilder."""
14
+ """Mixin class to add PIVOT functionality to a Select."""
15
15
 
16
16
  _expression: "Optional[exp.Expression]" = None
17
17
  dialect: "DialectType" = None
@@ -23,7 +23,7 @@ class PivotClauseMixin:
23
23
  pivot_column: Union[str, exp.Expression],
24
24
  pivot_values: list[Union[str, int, float, exp.Expression]],
25
25
  alias: Optional[str] = None,
26
- ) -> "SelectBuilder":
26
+ ) -> "Select":
27
27
  """Adds a PIVOT clause to the SELECT statement.
28
28
 
29
29
  Example:
@@ -61,7 +61,6 @@ class PivotClauseMixin:
61
61
  else:
62
62
  pivot_value_exprs.append(exp.Literal.string(str(val)))
63
63
 
64
- # Create the pivot expression with proper fields structure
65
64
  in_expr = exp.In(this=pivot_col_expr, expressions=pivot_value_exprs)
66
65
 
67
66
  pivot_node = exp.Pivot(expressions=[pivot_agg_expr], fields=[in_expr], unpivot=False)
@@ -69,14 +68,12 @@ class PivotClauseMixin:
69
68
  if alias:
70
69
  pivot_node.set("alias", exp.TableAlias(this=exp.to_identifier(alias)))
71
70
 
72
- # Add pivot to the table in the FROM clause
73
71
  from_clause = current_expr.args.get("from")
74
72
  if from_clause and isinstance(from_clause, exp.From):
75
73
  table = from_clause.this
76
74
  if isinstance(table, exp.Table):
77
- # Add to pivots array
78
75
  existing_pivots = table.args.get("pivots", [])
79
76
  existing_pivots.append(pivot_node)
80
77
  table.set("pivots", existing_pivots)
81
78
 
82
- return cast("SelectBuilder", self)
79
+ return cast("Select", self)