sqlspec 0.15.0__py3-none-any.whl → 0.16.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +702 -44
- sqlspec/builder/_base.py +77 -44
- sqlspec/builder/_column.py +0 -4
- sqlspec/builder/_ddl.py +15 -52
- sqlspec/builder/_ddl_utils.py +0 -1
- sqlspec/builder/_delete.py +4 -5
- sqlspec/builder/_insert.py +235 -44
- sqlspec/builder/_merge.py +17 -2
- sqlspec/builder/_parsing_utils.py +42 -14
- sqlspec/builder/_select.py +29 -33
- sqlspec/builder/_update.py +4 -2
- sqlspec/builder/mixins/_cte_and_set_ops.py +47 -20
- sqlspec/builder/mixins/_delete_operations.py +6 -1
- sqlspec/builder/mixins/_insert_operations.py +126 -24
- sqlspec/builder/mixins/_join_operations.py +44 -10
- sqlspec/builder/mixins/_merge_operations.py +183 -25
- sqlspec/builder/mixins/_order_limit_operations.py +15 -3
- sqlspec/builder/mixins/_pivot_operations.py +11 -2
- sqlspec/builder/mixins/_select_operations.py +21 -14
- sqlspec/builder/mixins/_update_operations.py +80 -32
- sqlspec/builder/mixins/_where_clause.py +201 -66
- sqlspec/core/cache.py +26 -28
- sqlspec/core/compiler.py +58 -37
- sqlspec/core/filters.py +12 -10
- sqlspec/core/parameters.py +80 -52
- sqlspec/core/result.py +30 -17
- sqlspec/core/statement.py +47 -22
- sqlspec/driver/_async.py +76 -46
- sqlspec/driver/_common.py +25 -6
- sqlspec/driver/_sync.py +73 -43
- sqlspec/driver/mixins/_result_tools.py +62 -37
- sqlspec/driver/mixins/_sql_translator.py +61 -11
- sqlspec/extensions/litestar/cli.py +1 -1
- sqlspec/extensions/litestar/plugin.py +2 -2
- sqlspec/protocols.py +7 -0
- sqlspec/utils/sync_tools.py +1 -1
- sqlspec/utils/type_guards.py +7 -3
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/METADATA +1 -1
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/RECORD +43 -43
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/WHEEL +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/licenses/NOTICE +0 -0
sqlspec/builder/_insert.py
CHANGED
|
@@ -4,8 +4,7 @@ This module provides a fluent interface for building SQL queries safely,
|
|
|
4
4
|
with automatic parameter binding and validation.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from
|
|
8
|
-
from typing import TYPE_CHECKING, Any, Optional
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Final, Optional
|
|
9
8
|
|
|
10
9
|
from sqlglot import exp
|
|
11
10
|
from typing_extensions import Self
|
|
@@ -21,15 +20,14 @@ if TYPE_CHECKING:
|
|
|
21
20
|
|
|
22
21
|
__all__ = ("Insert",)
|
|
23
22
|
|
|
24
|
-
ERR_MSG_TABLE_NOT_SET = "The target table must be set using .into() before adding values."
|
|
25
|
-
ERR_MSG_VALUES_COLUMNS_MISMATCH = (
|
|
23
|
+
ERR_MSG_TABLE_NOT_SET: Final[str] = "The target table must be set using .into() before adding values."
|
|
24
|
+
ERR_MSG_VALUES_COLUMNS_MISMATCH: Final[str] = (
|
|
26
25
|
"Number of values ({values_len}) does not match the number of specified columns ({columns_len})."
|
|
27
26
|
)
|
|
28
|
-
ERR_MSG_INTERNAL_EXPRESSION_TYPE = "Internal error: expression is not an Insert instance as expected."
|
|
29
|
-
ERR_MSG_EXPRESSION_NOT_INITIALIZED = "Internal error: base expression not initialized."
|
|
27
|
+
ERR_MSG_INTERNAL_EXPRESSION_TYPE: Final[str] = "Internal error: expression is not an Insert instance as expected."
|
|
28
|
+
ERR_MSG_EXPRESSION_NOT_INITIALIZED: Final[str] = "Internal error: base expression not initialized."
|
|
30
29
|
|
|
31
30
|
|
|
32
|
-
@dataclass(unsafe_hash=True)
|
|
33
31
|
class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSelectMixin, InsertIntoClauseMixin):
|
|
34
32
|
"""Builder for INSERT statements.
|
|
35
33
|
|
|
@@ -37,9 +35,7 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
37
35
|
in a safe and dialect-agnostic manner with automatic parameter binding.
|
|
38
36
|
"""
|
|
39
37
|
|
|
40
|
-
|
|
41
|
-
_columns: list[str] = field(default_factory=list, init=False)
|
|
42
|
-
_values_added_count: int = field(default=0, init=False)
|
|
38
|
+
__slots__ = ("_columns", "_table", "_values_added_count")
|
|
43
39
|
|
|
44
40
|
def __init__(self, table: Optional[str] = None, **kwargs: Any) -> None:
|
|
45
41
|
"""Initialize INSERT with optional table.
|
|
@@ -50,9 +46,12 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
50
46
|
"""
|
|
51
47
|
super().__init__(**kwargs)
|
|
52
48
|
|
|
53
|
-
|
|
54
|
-
self.
|
|
55
|
-
self.
|
|
49
|
+
# Initialize Insert-specific attributes
|
|
50
|
+
self._table: Optional[str] = None
|
|
51
|
+
self._columns: list[str] = []
|
|
52
|
+
self._values_added_count: int = 0
|
|
53
|
+
|
|
54
|
+
self._initialize_expression()
|
|
56
55
|
|
|
57
56
|
if table:
|
|
58
57
|
self.into(table)
|
|
@@ -91,16 +90,22 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
91
90
|
raise SQLBuilderError(ERR_MSG_INTERNAL_EXPRESSION_TYPE)
|
|
92
91
|
return self._expression
|
|
93
92
|
|
|
94
|
-
def values(self, *values: Any) -> "Self":
|
|
93
|
+
def values(self, *values: Any, **kwargs: Any) -> "Self":
|
|
95
94
|
"""Adds a row of values to the INSERT statement.
|
|
96
95
|
|
|
97
96
|
This method can be called multiple times to insert multiple rows,
|
|
98
97
|
resulting in a multi-row INSERT statement like `VALUES (...), (...)`.
|
|
99
98
|
|
|
99
|
+
Supports:
|
|
100
|
+
- values(val1, val2, val3)
|
|
101
|
+
- values(col1=val1, col2=val2)
|
|
102
|
+
- values(mapping)
|
|
103
|
+
|
|
100
104
|
Args:
|
|
101
105
|
*values: The values for the row to be inserted. The number of values
|
|
102
106
|
must match the number of columns set by `columns()`, if `columns()` was called
|
|
103
107
|
and specified any non-empty list of columns.
|
|
108
|
+
**kwargs: Column-value pairs for named values.
|
|
104
109
|
|
|
105
110
|
Returns:
|
|
106
111
|
The current builder instance for method chaining.
|
|
@@ -113,25 +118,72 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
113
118
|
if not self._table:
|
|
114
119
|
raise SQLBuilderError(ERR_MSG_TABLE_NOT_SET)
|
|
115
120
|
|
|
121
|
+
if kwargs:
|
|
122
|
+
if values:
|
|
123
|
+
msg = "Cannot mix positional values with keyword values."
|
|
124
|
+
raise SQLBuilderError(msg)
|
|
125
|
+
return self.values_from_dict(kwargs)
|
|
126
|
+
|
|
127
|
+
if len(values) == 1:
|
|
128
|
+
try:
|
|
129
|
+
values_0 = values[0]
|
|
130
|
+
if hasattr(values_0, "items"):
|
|
131
|
+
return self.values_from_dict(values_0)
|
|
132
|
+
except (AttributeError, TypeError):
|
|
133
|
+
pass
|
|
134
|
+
|
|
116
135
|
insert_expr = self._get_insert_expression()
|
|
117
136
|
|
|
118
137
|
if self._columns and len(values) != len(self._columns):
|
|
119
138
|
msg = ERR_MSG_VALUES_COLUMNS_MISMATCH.format(values_len=len(values), columns_len=len(self._columns))
|
|
120
139
|
raise SQLBuilderError(msg)
|
|
121
140
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
141
|
+
value_placeholders: list[exp.Expression] = []
|
|
142
|
+
for i, value in enumerate(values):
|
|
143
|
+
if isinstance(value, exp.Expression):
|
|
144
|
+
value_placeholders.append(value)
|
|
145
|
+
elif hasattr(value, "expression") and hasattr(value, "sql"):
|
|
146
|
+
# Handle SQL objects (from sql.raw with parameters)
|
|
147
|
+
expression = getattr(value, "expression", None)
|
|
148
|
+
if expression is not None and isinstance(expression, exp.Expression):
|
|
149
|
+
# Merge parameters from SQL object into builder
|
|
150
|
+
if hasattr(value, "parameters"):
|
|
151
|
+
sql_parameters = getattr(value, "parameters", {})
|
|
152
|
+
for param_name, param_value in sql_parameters.items():
|
|
153
|
+
self.add_parameter(param_value, name=param_name)
|
|
154
|
+
value_placeholders.append(expression)
|
|
155
|
+
else:
|
|
156
|
+
# If expression is None, fall back to parsing the raw SQL
|
|
157
|
+
sql_text = getattr(value, "sql", "")
|
|
158
|
+
# Merge parameters even when parsing raw SQL
|
|
159
|
+
if hasattr(value, "parameters"):
|
|
160
|
+
sql_parameters = getattr(value, "parameters", {})
|
|
161
|
+
for param_name, param_value in sql_parameters.items():
|
|
162
|
+
self.add_parameter(param_value, name=param_name)
|
|
163
|
+
# Check if sql_text is callable (like Expression.sql method)
|
|
164
|
+
if callable(sql_text):
|
|
165
|
+
sql_text = str(value)
|
|
166
|
+
value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
|
|
167
|
+
value_placeholders.append(value_expr)
|
|
168
|
+
else:
|
|
169
|
+
if self._columns and i < len(self._columns):
|
|
170
|
+
column_str = str(self._columns[i])
|
|
171
|
+
column_name = column_str.rsplit(".", maxsplit=1)[-1] if "." in column_str else column_str
|
|
172
|
+
param_name = self._generate_unique_parameter_name(column_name)
|
|
173
|
+
else:
|
|
174
|
+
param_name = self._generate_unique_parameter_name(f"value_{i + 1}")
|
|
175
|
+
_, param_name = self.add_parameter(value, name=param_name)
|
|
176
|
+
value_placeholders.append(exp.var(param_name))
|
|
177
|
+
|
|
178
|
+
tuple_expr = exp.Tuple(expressions=value_placeholders)
|
|
127
179
|
if self._values_added_count == 0:
|
|
128
|
-
|
|
129
|
-
insert_expr.set("expression", new_values_node)
|
|
130
|
-
elif isinstance(current_values_expression, exp.Values):
|
|
131
|
-
current_values_expression.expressions.append(exp.Tuple(expressions=list(value_placeholders)))
|
|
180
|
+
insert_expr.set("expression", exp.Values(expressions=[tuple_expr]))
|
|
132
181
|
else:
|
|
133
|
-
|
|
134
|
-
|
|
182
|
+
current_values = insert_expr.args.get("expression")
|
|
183
|
+
if isinstance(current_values, exp.Values):
|
|
184
|
+
current_values.expressions.append(tuple_expr)
|
|
185
|
+
else:
|
|
186
|
+
insert_expr.set("expression", exp.Values(expressions=[tuple_expr]))
|
|
135
187
|
|
|
136
188
|
self._values_added_count += 1
|
|
137
189
|
return self
|
|
@@ -154,10 +206,11 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
154
206
|
if not self._table:
|
|
155
207
|
raise SQLBuilderError(ERR_MSG_TABLE_NOT_SET)
|
|
156
208
|
|
|
209
|
+
data_keys = list(data.keys())
|
|
157
210
|
if not self._columns:
|
|
158
|
-
self.columns(*
|
|
159
|
-
elif set(self._columns) != set(
|
|
160
|
-
msg = f"Dictionary keys {set(
|
|
211
|
+
self.columns(*data_keys)
|
|
212
|
+
elif set(self._columns) != set(data_keys):
|
|
213
|
+
msg = f"Dictionary keys {set(data_keys)} do not match existing columns {set(self._columns)}."
|
|
161
214
|
raise SQLBuilderError(msg)
|
|
162
215
|
|
|
163
216
|
return self.values(*[data[col] for col in self._columns])
|
|
@@ -198,33 +251,171 @@ class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSe
|
|
|
198
251
|
|
|
199
252
|
return self
|
|
200
253
|
|
|
201
|
-
def
|
|
202
|
-
"""Adds an ON CONFLICT
|
|
254
|
+
def on_conflict(self, *columns: str) -> "ConflictBuilder":
|
|
255
|
+
"""Adds an ON CONFLICT clause with specified columns.
|
|
203
256
|
|
|
204
|
-
|
|
257
|
+
Args:
|
|
258
|
+
*columns: Column names that define the conflict. If no columns provided,
|
|
259
|
+
creates an ON CONFLICT without specific columns (catches all conflicts).
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
A ConflictBuilder instance for chaining conflict resolution methods.
|
|
263
|
+
|
|
264
|
+
Example:
|
|
265
|
+
```python
|
|
266
|
+
# ON CONFLICT (id) DO NOTHING
|
|
267
|
+
sql.insert("users").values(id=1, name="John").on_conflict(
|
|
268
|
+
"id"
|
|
269
|
+
).do_nothing()
|
|
270
|
+
|
|
271
|
+
# ON CONFLICT (email, username) DO UPDATE SET updated_at = NOW()
|
|
272
|
+
sql.insert("users").values(...).on_conflict(
|
|
273
|
+
"email", "username"
|
|
274
|
+
).do_update(updated_at=sql.raw("NOW()"))
|
|
275
|
+
|
|
276
|
+
# ON CONFLICT DO NOTHING (catches all conflicts)
|
|
277
|
+
sql.insert("users").values(...).on_conflict().do_nothing()
|
|
278
|
+
```
|
|
279
|
+
"""
|
|
280
|
+
return ConflictBuilder(self, columns)
|
|
281
|
+
|
|
282
|
+
def on_conflict_do_nothing(self, *columns: str) -> "Insert":
|
|
283
|
+
"""Adds an ON CONFLICT DO NOTHING clause (convenience method).
|
|
284
|
+
|
|
285
|
+
Args:
|
|
286
|
+
*columns: Column names that define the conflict. If no columns provided,
|
|
287
|
+
creates an ON CONFLICT without specific columns.
|
|
205
288
|
|
|
206
289
|
Returns:
|
|
207
290
|
The current builder instance for method chaining.
|
|
208
291
|
|
|
209
292
|
Note:
|
|
210
|
-
This is
|
|
211
|
-
For a more general solution, you might need dialect-specific handling.
|
|
293
|
+
This is a convenience method. For more control, use on_conflict().do_nothing().
|
|
212
294
|
"""
|
|
213
|
-
|
|
214
|
-
try:
|
|
215
|
-
on_conflict = exp.OnConflict(this=None, expressions=[])
|
|
216
|
-
insert_expr.set("on", on_conflict)
|
|
217
|
-
except AttributeError:
|
|
218
|
-
pass
|
|
219
|
-
return self
|
|
295
|
+
return self.on_conflict(*columns).do_nothing()
|
|
220
296
|
|
|
221
|
-
def on_duplicate_key_update(self, **
|
|
222
|
-
"""Adds
|
|
297
|
+
def on_duplicate_key_update(self, **kwargs: Any) -> "Insert":
|
|
298
|
+
"""Adds conflict resolution using the ON CONFLICT syntax (cross-database compatible).
|
|
223
299
|
|
|
224
300
|
Args:
|
|
225
|
-
**
|
|
301
|
+
**kwargs: Column-value pairs to update on conflict.
|
|
226
302
|
|
|
227
303
|
Returns:
|
|
228
304
|
The current builder instance for method chaining.
|
|
305
|
+
|
|
306
|
+
Note:
|
|
307
|
+
This method uses PostgreSQL-style ON CONFLICT syntax but SQLGlot will
|
|
308
|
+
transpile it to the appropriate syntax for each database (MySQL's
|
|
309
|
+
ON DUPLICATE KEY UPDATE, etc.).
|
|
229
310
|
"""
|
|
230
|
-
|
|
311
|
+
if not kwargs:
|
|
312
|
+
return self
|
|
313
|
+
return self.on_conflict().do_update(**kwargs)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class ConflictBuilder:
|
|
317
|
+
"""Builder for ON CONFLICT clauses in INSERT statements.
|
|
318
|
+
|
|
319
|
+
This builder provides a fluent interface for constructing conflict resolution
|
|
320
|
+
clauses using PostgreSQL-style syntax, which SQLGlot can transpile to other dialects.
|
|
321
|
+
"""
|
|
322
|
+
|
|
323
|
+
__slots__ = ("_columns", "_insert_builder")
|
|
324
|
+
|
|
325
|
+
def __init__(self, insert_builder: "Insert", columns: tuple[str, ...]) -> None:
|
|
326
|
+
"""Initialize ConflictBuilder.
|
|
327
|
+
|
|
328
|
+
Args:
|
|
329
|
+
insert_builder: The parent Insert builder
|
|
330
|
+
columns: Column names that define the conflict
|
|
331
|
+
"""
|
|
332
|
+
self._insert_builder = insert_builder
|
|
333
|
+
self._columns = columns
|
|
334
|
+
|
|
335
|
+
def do_nothing(self) -> "Insert":
|
|
336
|
+
"""Add DO NOTHING conflict resolution.
|
|
337
|
+
|
|
338
|
+
Returns:
|
|
339
|
+
The parent Insert builder for method chaining.
|
|
340
|
+
|
|
341
|
+
Example:
|
|
342
|
+
```python
|
|
343
|
+
sql.insert("users").values(id=1, name="John").on_conflict(
|
|
344
|
+
"id"
|
|
345
|
+
).do_nothing()
|
|
346
|
+
```
|
|
347
|
+
"""
|
|
348
|
+
insert_expr = self._insert_builder._get_insert_expression()
|
|
349
|
+
|
|
350
|
+
# Create ON CONFLICT with proper structure
|
|
351
|
+
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
|
|
352
|
+
on_conflict = exp.OnConflict(conflict_keys=conflict_keys, action=exp.var("DO NOTHING"))
|
|
353
|
+
|
|
354
|
+
insert_expr.set("conflict", on_conflict)
|
|
355
|
+
return self._insert_builder
|
|
356
|
+
|
|
357
|
+
def do_update(self, **kwargs: Any) -> "Insert":
|
|
358
|
+
"""Add DO UPDATE conflict resolution with SET clauses.
|
|
359
|
+
|
|
360
|
+
Args:
|
|
361
|
+
**kwargs: Column-value pairs to update on conflict.
|
|
362
|
+
|
|
363
|
+
Returns:
|
|
364
|
+
The parent Insert builder for method chaining.
|
|
365
|
+
|
|
366
|
+
Example:
|
|
367
|
+
```python
|
|
368
|
+
sql.insert("users").values(id=1, name="John").on_conflict(
|
|
369
|
+
"id"
|
|
370
|
+
).do_update(
|
|
371
|
+
name="Updated Name", updated_at=sql.raw("NOW()")
|
|
372
|
+
)
|
|
373
|
+
```
|
|
374
|
+
"""
|
|
375
|
+
insert_expr = self._insert_builder._get_insert_expression()
|
|
376
|
+
|
|
377
|
+
# Create SET expressions for the UPDATE
|
|
378
|
+
set_expressions = []
|
|
379
|
+
for col, val in kwargs.items():
|
|
380
|
+
if hasattr(val, "expression") and hasattr(val, "sql"):
|
|
381
|
+
# Handle SQL objects (from sql.raw with parameters)
|
|
382
|
+
expression = getattr(val, "expression", None)
|
|
383
|
+
if expression is not None and isinstance(expression, exp.Expression):
|
|
384
|
+
# Merge parameters from SQL object into builder
|
|
385
|
+
if hasattr(val, "parameters"):
|
|
386
|
+
sql_parameters = getattr(val, "parameters", {})
|
|
387
|
+
for param_name, param_value in sql_parameters.items():
|
|
388
|
+
self._insert_builder.add_parameter(param_value, name=param_name)
|
|
389
|
+
value_expr = expression
|
|
390
|
+
else:
|
|
391
|
+
# If expression is None, fall back to parsing the raw SQL
|
|
392
|
+
sql_text = getattr(val, "sql", "")
|
|
393
|
+
# Merge parameters even when parsing raw SQL
|
|
394
|
+
if hasattr(val, "parameters"):
|
|
395
|
+
sql_parameters = getattr(val, "parameters", {})
|
|
396
|
+
for param_name, param_value in sql_parameters.items():
|
|
397
|
+
self._insert_builder.add_parameter(param_value, name=param_name)
|
|
398
|
+
# Check if sql_text is callable (like Expression.sql method)
|
|
399
|
+
if callable(sql_text):
|
|
400
|
+
sql_text = str(val)
|
|
401
|
+
value_expr = exp.maybe_parse(sql_text) or exp.convert(str(sql_text))
|
|
402
|
+
elif isinstance(val, exp.Expression):
|
|
403
|
+
value_expr = val
|
|
404
|
+
else:
|
|
405
|
+
# Create parameter for regular values
|
|
406
|
+
param_name = self._insert_builder._generate_unique_parameter_name(col)
|
|
407
|
+
_, param_name = self._insert_builder.add_parameter(val, name=param_name)
|
|
408
|
+
value_expr = exp.Placeholder(this=param_name)
|
|
409
|
+
|
|
410
|
+
set_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))
|
|
411
|
+
|
|
412
|
+
# Create ON CONFLICT with proper structure
|
|
413
|
+
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
|
|
414
|
+
on_conflict = exp.OnConflict(
|
|
415
|
+
conflict_keys=conflict_keys,
|
|
416
|
+
action=exp.var("DO UPDATE"),
|
|
417
|
+
expressions=set_expressions if set_expressions else None,
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
insert_expr.set("conflict", on_conflict)
|
|
421
|
+
return self._insert_builder
|
sqlspec/builder/_merge.py
CHANGED
|
@@ -4,7 +4,7 @@ This module provides a fluent interface for building SQL queries safely,
|
|
|
4
4
|
with automatic parameter binding and validation.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from
|
|
7
|
+
from typing import Any, Optional
|
|
8
8
|
|
|
9
9
|
from sqlglot import exp
|
|
10
10
|
|
|
@@ -22,7 +22,6 @@ from sqlspec.core.result import SQLResult
|
|
|
22
22
|
__all__ = ("Merge",)
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
@dataclass(unsafe_hash=True)
|
|
26
25
|
class Merge(
|
|
27
26
|
QueryBuilder,
|
|
28
27
|
MergeUsingClauseMixin,
|
|
@@ -38,6 +37,22 @@ class Merge(
|
|
|
38
37
|
(also known as UPSERT in some databases) with automatic parameter binding and validation.
|
|
39
38
|
"""
|
|
40
39
|
|
|
40
|
+
__slots__ = ()
|
|
41
|
+
_expression: Optional[exp.Expression]
|
|
42
|
+
|
|
43
|
+
def __init__(self, target_table: Optional[str] = None, **kwargs: Any) -> None:
|
|
44
|
+
"""Initialize MERGE with optional target table.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
target_table: Target table name
|
|
48
|
+
**kwargs: Additional QueryBuilder arguments
|
|
49
|
+
"""
|
|
50
|
+
super().__init__(**kwargs)
|
|
51
|
+
self._initialize_expression()
|
|
52
|
+
|
|
53
|
+
if target_table:
|
|
54
|
+
self.into(target_table)
|
|
55
|
+
|
|
41
56
|
@property
|
|
42
57
|
def _expected_result_type(self) -> "type[SQLResult]":
|
|
43
58
|
"""Return the expected result type for this builder.
|
|
@@ -5,14 +5,16 @@ that users might pass as strings to various builder methods.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import contextlib
|
|
8
|
-
from typing import Any, Optional, Union, cast
|
|
8
|
+
from typing import Any, Final, Optional, Union, cast
|
|
9
9
|
|
|
10
10
|
from sqlglot import exp, maybe_parse, parse_one
|
|
11
11
|
|
|
12
12
|
from sqlspec.utils.type_guards import has_expression_attr, has_parameter_builder
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
def parse_column_expression(
|
|
15
|
+
def parse_column_expression(
|
|
16
|
+
column_input: Union[str, exp.Expression, Any], builder: Optional[Any] = None
|
|
17
|
+
) -> exp.Expression:
|
|
16
18
|
"""Parse a column input that might be a complex expression.
|
|
17
19
|
|
|
18
20
|
Handles cases like:
|
|
@@ -22,9 +24,11 @@ def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> ex
|
|
|
22
24
|
- Function calls: "MAX(price)" -> Max(this=Column(price))
|
|
23
25
|
- Complex expressions: "CASE WHEN ... END" -> Case(...)
|
|
24
26
|
- Custom Column objects from our builder
|
|
27
|
+
- SQL objects with raw SQL expressions
|
|
25
28
|
|
|
26
29
|
Args:
|
|
27
|
-
column_input: String, SQLGlot expression, or Column object
|
|
30
|
+
column_input: String, SQLGlot expression, SQL object, or Column object
|
|
31
|
+
builder: Optional builder instance for parameter merging
|
|
28
32
|
|
|
29
33
|
Returns:
|
|
30
34
|
exp.Expression: Parsed SQLGlot expression
|
|
@@ -32,10 +36,33 @@ def parse_column_expression(column_input: Union[str, exp.Expression, Any]) -> ex
|
|
|
32
36
|
if isinstance(column_input, exp.Expression):
|
|
33
37
|
return column_input
|
|
34
38
|
|
|
39
|
+
# Handle SQL objects (from sql.raw with parameters)
|
|
40
|
+
if hasattr(column_input, "expression") and hasattr(column_input, "sql"):
|
|
41
|
+
# This is likely a SQL object
|
|
42
|
+
expression = getattr(column_input, "expression", None)
|
|
43
|
+
if expression is not None and isinstance(expression, exp.Expression):
|
|
44
|
+
# Merge parameters from SQL object into builder if available
|
|
45
|
+
if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"):
|
|
46
|
+
sql_parameters = getattr(column_input, "parameters", {})
|
|
47
|
+
for param_name, param_value in sql_parameters.items():
|
|
48
|
+
builder.add_parameter(param_value, name=param_name)
|
|
49
|
+
return cast("exp.Expression", expression)
|
|
50
|
+
# If expression is None, fall back to parsing the raw SQL
|
|
51
|
+
sql_text = getattr(column_input, "sql", "")
|
|
52
|
+
# Merge parameters even when parsing raw SQL
|
|
53
|
+
if builder and hasattr(column_input, "parameters") and hasattr(builder, "add_parameter"):
|
|
54
|
+
sql_parameters = getattr(column_input, "parameters", {})
|
|
55
|
+
for param_name, param_value in sql_parameters.items():
|
|
56
|
+
builder.add_parameter(param_value, name=param_name)
|
|
57
|
+
return exp.maybe_parse(sql_text) or exp.column(str(sql_text))
|
|
58
|
+
|
|
35
59
|
if has_expression_attr(column_input):
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
60
|
+
try:
|
|
61
|
+
attr_value = column_input._expression
|
|
62
|
+
if isinstance(attr_value, exp.Expression):
|
|
63
|
+
return attr_value
|
|
64
|
+
except AttributeError:
|
|
65
|
+
pass
|
|
39
66
|
|
|
40
67
|
return exp.maybe_parse(column_input) or exp.column(str(column_input))
|
|
41
68
|
|
|
@@ -102,14 +129,18 @@ def parse_condition_expression(
|
|
|
102
129
|
if isinstance(condition_input, exp.Expression):
|
|
103
130
|
return condition_input
|
|
104
131
|
|
|
105
|
-
tuple_condition_parts = 2
|
|
132
|
+
tuple_condition_parts: Final[int] = 2
|
|
106
133
|
if isinstance(condition_input, tuple) and len(condition_input) == tuple_condition_parts:
|
|
107
134
|
column, value = condition_input
|
|
108
135
|
column_expr = parse_column_expression(column)
|
|
109
136
|
if value is None:
|
|
110
137
|
return exp.Is(this=column_expr, expression=exp.null())
|
|
111
138
|
if builder and has_parameter_builder(builder):
|
|
112
|
-
|
|
139
|
+
from sqlspec.builder.mixins._where_clause import _extract_column_name
|
|
140
|
+
|
|
141
|
+
column_name = _extract_column_name(column)
|
|
142
|
+
param_name = builder._generate_unique_parameter_name(column_name)
|
|
143
|
+
_, param_name = builder.add_parameter(value, name=param_name)
|
|
113
144
|
return exp.EQ(this=column_expr, expression=exp.Placeholder(this=param_name))
|
|
114
145
|
if isinstance(value, str):
|
|
115
146
|
return exp.EQ(this=column_expr, expression=exp.convert(value))
|
|
@@ -125,12 +156,9 @@ def parse_condition_expression(
|
|
|
125
156
|
except Exception:
|
|
126
157
|
try:
|
|
127
158
|
parsed = exp.maybe_parse(condition_input) # type: ignore[var-annotated]
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
pass
|
|
132
|
-
|
|
133
|
-
return exp.condition(condition_input)
|
|
159
|
+
return parsed or exp.condition(condition_input)
|
|
160
|
+
except Exception:
|
|
161
|
+
return exp.condition(condition_input)
|
|
134
162
|
|
|
135
163
|
|
|
136
164
|
__all__ = ("parse_column_expression", "parse_condition_expression", "parse_order_expression", "parse_table_expression")
|
sqlspec/builder/_select.py
CHANGED
|
@@ -5,8 +5,7 @@ with automatic parameter binding and validation.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import re
|
|
8
|
-
from
|
|
9
|
-
from typing import Any, Optional, Union
|
|
8
|
+
from typing import Any, Callable, Final, Optional, Union
|
|
10
9
|
|
|
11
10
|
from sqlglot import exp
|
|
12
11
|
from typing_extensions import Self
|
|
@@ -29,10 +28,9 @@ from sqlspec.core.result import SQLResult
|
|
|
29
28
|
__all__ = ("Select",)
|
|
30
29
|
|
|
31
30
|
|
|
32
|
-
TABLE_HINT_PATTERN = r"\b{}\b(\s+AS\s+\w+)?"
|
|
31
|
+
TABLE_HINT_PATTERN: Final[str] = r"\b{}\b(\s+AS\s+\w+)?"
|
|
33
32
|
|
|
34
33
|
|
|
35
|
-
@dataclass
|
|
36
34
|
class Select(
|
|
37
35
|
QueryBuilder,
|
|
38
36
|
WhereClauseMixin,
|
|
@@ -58,9 +56,8 @@ class Select(
|
|
|
58
56
|
>>> result = driver.execute(builder)
|
|
59
57
|
"""
|
|
60
58
|
|
|
61
|
-
|
|
62
|
-
_expression: Optional[exp.Expression]
|
|
63
|
-
_hints: "list[dict[str, object]]" = field(default_factory=list, init=False, repr=False)
|
|
59
|
+
__slots__ = ("_hints", "_with_parts")
|
|
60
|
+
_expression: Optional[exp.Expression]
|
|
64
61
|
|
|
65
62
|
def __init__(self, *columns: str, **kwargs: Any) -> None:
|
|
66
63
|
"""Initialize SELECT with optional columns.
|
|
@@ -75,11 +72,11 @@ class Select(
|
|
|
75
72
|
"""
|
|
76
73
|
super().__init__(**kwargs)
|
|
77
74
|
|
|
78
|
-
|
|
79
|
-
self.
|
|
80
|
-
self._hints = []
|
|
75
|
+
# Initialize Select-specific attributes
|
|
76
|
+
self._with_parts: dict[str, Union[exp.CTE, Select]] = {}
|
|
77
|
+
self._hints: list[dict[str, object]] = []
|
|
81
78
|
|
|
82
|
-
self.
|
|
79
|
+
self._initialize_expression()
|
|
83
80
|
|
|
84
81
|
if columns:
|
|
85
82
|
self.select(*columns)
|
|
@@ -93,7 +90,8 @@ class Select(
|
|
|
93
90
|
"""
|
|
94
91
|
return SQLResult
|
|
95
92
|
|
|
96
|
-
def _create_base_expression(self) ->
|
|
93
|
+
def _create_base_expression(self) -> exp.Select:
|
|
94
|
+
"""Create base SELECT expression."""
|
|
97
95
|
if self._expression is None or not isinstance(self._expression, exp.Select):
|
|
98
96
|
self._expression = exp.Select()
|
|
99
97
|
return self._expression
|
|
@@ -131,44 +129,42 @@ class Select(
|
|
|
131
129
|
if not self._hints:
|
|
132
130
|
return safe_query
|
|
133
131
|
|
|
134
|
-
modified_expr = self._expression
|
|
132
|
+
modified_expr = self._expression or self._create_base_expression()
|
|
135
133
|
|
|
136
134
|
if isinstance(modified_expr, exp.Select):
|
|
137
135
|
statement_hints = [h["hint"] for h in self._hints if h.get("location") == "statement"]
|
|
138
136
|
if statement_hints:
|
|
139
|
-
hint_expressions = []
|
|
140
137
|
|
|
141
|
-
def
|
|
142
|
-
"""Parse a single hint."""
|
|
138
|
+
def parse_hint_safely(hint: Any) -> exp.Expression:
|
|
143
139
|
try:
|
|
144
|
-
hint_str = str(hint)
|
|
140
|
+
hint_str = str(hint)
|
|
145
141
|
hint_expr: Optional[exp.Expression] = exp.maybe_parse(hint_str, dialect=self.dialect_name)
|
|
146
|
-
|
|
147
|
-
return hint_expr
|
|
148
|
-
return exp.Anonymous(this=hint_str)
|
|
142
|
+
return hint_expr or exp.Anonymous(this=hint_str)
|
|
149
143
|
except Exception:
|
|
150
144
|
return exp.Anonymous(this=str(hint))
|
|
151
145
|
|
|
152
|
-
hint_expressions = [
|
|
146
|
+
hint_expressions: list[exp.Expression] = [parse_hint_safely(hint) for hint in statement_hints]
|
|
153
147
|
|
|
154
148
|
if hint_expressions:
|
|
155
|
-
|
|
156
|
-
modified_expr.set("hint", hint_node)
|
|
149
|
+
modified_expr.set("hint", exp.Hint(expressions=hint_expressions))
|
|
157
150
|
|
|
158
151
|
modified_sql = modified_expr.sql(dialect=self.dialect_name, pretty=True)
|
|
159
152
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
hint = th["hint"]
|
|
153
|
+
for hint_dict in self._hints:
|
|
154
|
+
if hint_dict.get("location") == "table" and hint_dict.get("table"):
|
|
155
|
+
table = str(hint_dict["table"])
|
|
156
|
+
hint = str(hint_dict["hint"])
|
|
165
157
|
pattern = TABLE_HINT_PATTERN.format(re.escape(table))
|
|
166
|
-
compiled_pattern = re.compile(pattern, re.IGNORECASE)
|
|
167
158
|
|
|
168
|
-
def
|
|
169
|
-
|
|
170
|
-
|
|
159
|
+
def make_replacement(hint_val: str, table_val: str) -> "Callable[[re.Match[str]], str]":
|
|
160
|
+
def replacement_func(match: re.Match[str]) -> str:
|
|
161
|
+
alias_part = match.group(1) or ""
|
|
162
|
+
return f"/*+ {hint_val} */ {table_val}{alias_part}"
|
|
171
163
|
|
|
172
|
-
|
|
164
|
+
return replacement_func
|
|
165
|
+
|
|
166
|
+
modified_sql = re.sub(
|
|
167
|
+
pattern, make_replacement(hint, table), modified_sql, count=1, flags=re.IGNORECASE
|
|
168
|
+
)
|
|
173
169
|
|
|
174
170
|
return SafeQuery(sql=modified_sql, parameters=safe_query.parameters, dialect=safe_query.dialect)
|
sqlspec/builder/_update.py
CHANGED
|
@@ -4,7 +4,6 @@ This module provides a fluent interface for building SQL queries safely,
|
|
|
4
4
|
with automatic parameter binding and validation.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from dataclasses import dataclass
|
|
8
7
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
9
8
|
|
|
10
9
|
from sqlglot import exp
|
|
@@ -27,7 +26,6 @@ if TYPE_CHECKING:
|
|
|
27
26
|
__all__ = ("Update",)
|
|
28
27
|
|
|
29
28
|
|
|
30
|
-
@dataclass(unsafe_hash=True)
|
|
31
29
|
class Update(
|
|
32
30
|
QueryBuilder,
|
|
33
31
|
WhereClauseMixin,
|
|
@@ -72,6 +70,9 @@ class Update(
|
|
|
72
70
|
```
|
|
73
71
|
"""
|
|
74
72
|
|
|
73
|
+
__slots__ = ("_table",)
|
|
74
|
+
_expression: Optional[exp.Expression]
|
|
75
|
+
|
|
75
76
|
def __init__(self, table: Optional[str] = None, **kwargs: Any) -> None:
|
|
76
77
|
"""Initialize UPDATE with optional table.
|
|
77
78
|
|
|
@@ -80,6 +81,7 @@ class Update(
|
|
|
80
81
|
**kwargs: Additional QueryBuilder arguments
|
|
81
82
|
"""
|
|
82
83
|
super().__init__(**kwargs)
|
|
84
|
+
self._initialize_expression()
|
|
83
85
|
|
|
84
86
|
if table:
|
|
85
87
|
self.table(table)
|