sql_fusion 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_fusion/__init__.py +15 -0
- sql_fusion/composite_table.py +689 -0
- sql_fusion/operators.py +119 -0
- sql_fusion/query/__init__.py +0 -0
- sql_fusion/query/delete.py +84 -0
- sql_fusion/query/insert.py +72 -0
- sql_fusion/query/select.py +527 -0
- sql_fusion/query/update.py +76 -0
- sql_fusion-1.0.0.dist-info/METADATA +580 -0
- sql_fusion-1.0.0.dist-info/RECORD +12 -0
- sql_fusion-1.0.0.dist-info/WHEEL +4 -0
- sql_fusion-1.0.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,527 @@
|
|
|
1
|
+
import operator
from copy import copy
from typing import Any, Self

from sql_fusion.composite_table import (
    AbstractQuery,
    Alias,
    AliasRegistry,
    Column,
    Condition,
    FunctionCall,
    Table,
)
from sql_fusion.operators import EqualOperator
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class select(AbstractQuery):
|
|
17
|
+
def __init__(self, *columns: Column | Alias | FunctionCall) -> None:
    """Start a SELECT over *columns*.

    With no columns the query renders ``SELECT *``. The base table is not
    set here (``table=None``); it is supplied later through ``from_``.
    """
    super().__init__(table=None, columns=columns)

    # Output modifiers.
    self._distinct: bool = False
    self._limit: int | None = None
    self._offset: int | None = None

    # (join_type, table, condition or None for CROSS JOIN)
    self._joins: list[tuple[str, Table, Condition | None]] = []

    # Grouping state; the mode string selects how GROUP BY is rendered.
    self._group_by_type: str = "normal"
    self._group_by_columns: tuple[Column, ...] = ()
    self._grouping_sets: tuple[tuple[Column, ...], ...] = ()
    self._having_condition: Condition | None = None

    # (expression, descending) pairs in the order they were added.
    self._order_by_columns: tuple[
        tuple[Column | Alias | FunctionCall, bool],
        ...,
    ] = ()
|
|
33
|
+
|
|
34
|
+
def build_query(  # noqa: C901, PLR0912, PLR0915
    self,
    alias_registry: AliasRegistry | None = None,
) -> tuple[str, tuple[Any, ...]]:
    """Render this SELECT into SQL text plus its positional parameters.

    Clauses are emitted in the order WITH, SELECT, FROM, JOIN, WHERE,
    GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET. Parameters are appended in
    that same order so they line up with the placeholders in the text.

    Args:
        alias_registry: Registry used to resolve table aliases. When it is
            omitted (or falsy) the query's own registry is used.

    Returns:
        The ``(sql, params)`` pair after ``_apply_compile_expressions``.
    """
    registry = alias_registry or self._alias_registry
    fragments: list[str] = []
    bound: list[Any] = []

    cte_sql, cte_params = self._build_with_clause(registry)
    bound.extend(cte_params)

    # The FROM source and every joined source are resolved before any column
    # is rendered, so their aliases already exist in the registry.
    source_sql, source_params, source_alias = self._prepare_table_entry(
        self._get_table(),
        registry,
    )
    join_entries = self._prepare_join_entries(registry)

    projection = "*"
    if self._columns:
        rendered: list[str] = []
        for item in self._columns:
            if isinstance(item, FunctionCall):
                item_sql, item_params = item.to_sql(
                    registry,
                    include_alias=True,
                )
                rendered.append(item_sql)
                bound.extend(item_params)
            elif isinstance(item, Alias):
                rendered.append(item.to_sql(registry))
            else:
                # Plain column from any table in scope.
                rendered.append(item.get_ref(registry))
        projection = ", ".join(rendered)

    if cte_sql:
        fragments.append(cte_sql)
    select_keyword = "SELECT DISTINCT" if self._distinct else "SELECT"
    fragments.append(self._build_clause("SELECT", select_keyword, projection))
    fragments.append(
        self._build_clause(
            "FROM",
            "FROM",
            f'{source_sql} AS "{source_alias.name}"',
        ),
    )
    bound.extend(source_params)

    if join_entries:
        join_sql, join_params = self._build_joins_from_entries(
            registry,
            join_entries,
        )
        fragments.append(join_sql)
        bound.extend(join_params)

    if self._where_condition:
        where_sql, where_params = self._where_condition.to_sql(registry)
        fragments.append(self._build_clause("WHERE", "WHERE", where_sql))
        bound.extend(where_params)

    wants_grouping = (
        self._group_by_columns
        or self._grouping_sets
        or self._group_by_type == "all"
    )
    if wants_grouping:
        group_sql, group_params = self._build_group_by_clause(registry)
        fragments.append(group_sql)
        bound.extend(group_params)

    if self._having_condition:
        having_sql, having_params = self._having_condition.to_sql(registry)
        fragments.append(self._build_clause("HAVING", "HAVING", having_sql))
        bound.extend(having_params)

    if self._order_by_columns:
        ordering: list[str] = []
        for item, descending in self._order_by_columns:
            if isinstance(item, FunctionCall):
                item_sql, item_params = item.to_sql(registry)
                bound.extend(item_params)
            elif isinstance(item, Alias):
                item_sql = item.to_sql(registry)
            else:
                item_sql = item.get_ref(registry)
            ordering.append(f"{item_sql} DESC" if descending else item_sql)
        fragments.append(
            self._build_clause("ORDER BY", "ORDER BY", ", ".join(ordering)),
        )

    # LIMIT and OFFSET are inlined as text, LIMIT first.
    for keyword, amount in (("LIMIT", self._limit), ("OFFSET", self._offset)):
        if amount is not None:
            fragments.append(self._build_clause(keyword, keyword, str(amount)))

    return self._apply_compile_expressions(" ".join(fragments), tuple(bound))
|
|
156
|
+
|
|
157
|
+
def _prepare_table_entry(
    self,
    table: Table,
    alias_registry: AliasRegistry,
) -> tuple[str, tuple[Any, ...], Alias]:
    """Render *table* and resolve its alias in *alias_registry*.

    Returns ``(sql, params, alias)``. The two registry-touching calls run in
    a different order per source kind:
    - a subquery is rendered before its outer alias is requested;
    - a plain table has its alias requested first.

    NOTE(review): presumably the ordering decides which alias names the
    inner vs. outer tables receive, so keep it unchanged.
    """
    if table._subquery is None:  # pyright: ignore[reportPrivateUsage]
        resolved_alias = alias_registry.get_alias_for_table(table)
        rendered_sql, rendered_params = table.to_sql(alias_registry)
    else:
        rendered_sql, rendered_params = table.to_sql(alias_registry)
        resolved_alias = alias_registry.get_alias_for_table(table)
    return rendered_sql, rendered_params, resolved_alias
|
|
170
|
+
|
|
171
|
+
def _prepare_join_entries(
    self,
    alias_registry: AliasRegistry,
) -> list[
    tuple[str, Table, Condition | None, str, tuple[Any, ...], Alias]
]:
    """Render every registered join source, in the order joins were added.

    Each result extends the stored ``(join_type, table, condition)`` triple
    with the source's rendered SQL, its parameters and its resolved alias,
    as returned by ``_prepare_table_entry``.
    """
    return [
        (
            kind,
            target,
            on_condition,
            *self._prepare_table_entry(target, alias_registry),
        )
        for kind, target, on_condition in self._joins
    ]
|
|
198
|
+
|
|
199
|
+
def _build_joins_from_entries(
    self,
    alias_registry: AliasRegistry,
    join_entries: list[
        tuple[str, Table, Condition | None, str, tuple[Any, ...], Alias]
    ],
) -> tuple[str, list[Any]]:
    """Build JOIN clauses and return SQL string and parameters.

    Each entry becomes ``<TYPE> JOIN <source> AS "<alias>"``. It is followed
    by ``ON <condition>`` when a condition is present; CROSS JOIN entries
    carry ``None`` and get no ON clause. Within each join, the source's
    parameters come before the condition's parameters.
    """
    clauses: list[str] = []
    collected: list[Any] = []

    for entry in join_entries:
        kind, _target, on_condition, source_sql, source_params, alias = entry
        collected.extend(source_params)
        body = f'{source_sql} AS "{alias.name}"'

        if on_condition is not None:
            on_sql, on_params = on_condition.to_sql(alias_registry)
            body = f"{body} ON {on_sql}"
            collected.extend(on_params)

        clauses.append(self._build_clause("JOIN", f"{kind} JOIN", body))

    return " ".join(clauses), collected
|
|
238
|
+
|
|
239
|
+
def _with_join(
    self,
    join_type: str,
    table: Table | AbstractQuery,
    condition: Condition | None,
) -> Self:
    """Return a copy of this query with one more join appended.

    Shared implementation of the public ``*_join`` methods.
    - A nested query is wrapped in ``Table`` so it renders as a derived table.
    - The join list is copied first, so appending never mutates the joins of
      the query this was called on.

    Args:
        join_type: Keyword placed before ``JOIN`` (e.g. ``"INNER"``).
        table: Table or nested query to join.
        condition: ON condition, or ``None`` for a CROSS JOIN.
    """
    qs = copy(self)
    qs._joins = self._joins.copy()
    if isinstance(table, AbstractQuery):
        table = Table(table)
    qs._joins.append((join_type, table, condition))
    return qs

def join(
    self,
    table: Table | AbstractQuery,
    condition: Condition,
) -> Self:
    """Add an INNER JOIN clause; returns a new query."""
    return self._with_join("INNER", table, condition)

def left_join(
    self,
    table: Table | AbstractQuery,
    condition: Condition,
) -> Self:
    """Add a LEFT JOIN clause; returns a new query."""
    return self._with_join("LEFT", table, condition)

def right_join(
    self,
    table: Table | AbstractQuery,
    condition: Condition,
) -> Self:
    """Add a RIGHT JOIN clause; returns a new query."""
    return self._with_join("RIGHT", table, condition)

def full_join(
    self,
    table: Table | AbstractQuery,
    condition: Condition,
) -> Self:
    """Add a FULL OUTER JOIN clause; returns a new query."""
    return self._with_join("FULL OUTER", table, condition)

def cross_join(self, table: Table | AbstractQuery) -> Self:
    """Add a CROSS JOIN clause (cartesian product); returns a new query."""
    return self._with_join("CROSS", table, None)

def semi_join(
    self,
    table: Table | AbstractQuery,
    condition: Condition,
) -> Self:
    """Add a SEMI JOIN clause (exists check); returns a new query."""
    return self._with_join("SEMI", table, condition)

def anti_join(
    self,
    table: Table | AbstractQuery,
    condition: Condition,
) -> Self:
    """Add an ANTI JOIN clause (not exists check); returns a new query."""
    return self._with_join("ANTI", table, condition)
|
|
325
|
+
|
|
326
|
+
def limit(self, n: int) -> Self:
    """Return a copy of this query with ``LIMIT n``.

    The value is written straight into the SQL text rather than bound as a
    parameter. It is therefore coerced with ``operator.index`` first:
    - integer-like values are accepted (``True`` becomes ``1``);
    - floats, strings and other non-integers raise ``TypeError`` instead of
      producing malformed SQL such as ``LIMIT 1.5``.

    Raises:
        TypeError: If *n* is not integer-like.
        ValueError: If *n* is negative.
    """
    n = operator.index(n)
    if n < 0:
        raise ValueError("LIMIT must be non-negative")
    qs = copy(self)
    qs._limit = n
    return qs
|
|
332
|
+
|
|
333
|
+
def offset(self, n: int) -> Self:
    """Return a copy of this query with ``OFFSET n``.

    The value is written straight into the SQL text rather than bound as a
    parameter. It is therefore coerced with ``operator.index`` first:
    - integer-like values are accepted (``True`` becomes ``1``);
    - floats, strings and other non-integers raise ``TypeError`` instead of
      producing malformed SQL such as ``OFFSET 2.5``.

    Raises:
        TypeError: If *n* is not integer-like.
        ValueError: If *n* is negative.
    """
    n = operator.index(n)
    if n < 0:
        raise ValueError("OFFSET must be non-negative")
    qs = copy(self)
    qs._offset = n
    return qs
|
|
339
|
+
|
|
340
|
+
def order_by(
    self,
    *columns: Column | Alias | FunctionCall,
    descending: bool = False,
) -> Self:
    """Return a copy with *columns* appended to the ORDER BY list.

    Every column in one call shares the same *descending* flag; call
    ``order_by`` again to mix directions. Previously added orderings are
    kept and come first.

    Raises:
        ValueError: If no columns are given.
    """
    if not columns:
        raise ValueError("order_by() requires at least one column")

    additions = tuple((column, descending) for column in columns)
    ordered = copy(self)
    ordered._order_by_columns = (*self._order_by_columns, *additions)
    return ordered
|
|
353
|
+
|
|
354
|
+
def distinct(self) -> Self:
    """Add DISTINCT clause to select only unique rows.

    Returns a copy whose SELECT keyword renders as ``SELECT DISTINCT``.
    """
    unique_rows = copy(self)
    unique_rows._distinct = True
    return unique_rows
|
|
359
|
+
|
|
360
|
+
def _build_group_by_clause(
|
|
361
|
+
self,
|
|
362
|
+
alias_registry: AliasRegistry | None = None,
|
|
363
|
+
) -> tuple[str, list[Any]]:
|
|
364
|
+
registry = alias_registry or self._alias_registry
|
|
365
|
+
if self._group_by_type == "all":
|
|
366
|
+
return self._build_clause("GROUP BY", "GROUP BY", "ALL"), []
|
|
367
|
+
|
|
368
|
+
col_refs: str = ", ".join(
|
|
369
|
+
col.get_ref(registry) for col in self._group_by_columns
|
|
370
|
+
)
|
|
371
|
+
|
|
372
|
+
if self._group_by_type == "rollup":
|
|
373
|
+
return (
|
|
374
|
+
self._build_clause(
|
|
375
|
+
"GROUP BY",
|
|
376
|
+
"GROUP BY",
|
|
377
|
+
f"ROLLUP ({col_refs})",
|
|
378
|
+
),
|
|
379
|
+
[],
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
if self._group_by_type == "cube":
|
|
383
|
+
return (
|
|
384
|
+
self._build_clause(
|
|
385
|
+
"GROUP BY",
|
|
386
|
+
"GROUP BY",
|
|
387
|
+
f"CUBE ({col_refs})",
|
|
388
|
+
),
|
|
389
|
+
[],
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
if self._group_by_type == "grouping_sets":
|
|
393
|
+
gr_sets = (
|
|
394
|
+
self._extract_col_set_with_registry(col_set, registry)
|
|
395
|
+
if col_set
|
|
396
|
+
else "()"
|
|
397
|
+
for col_set in self._grouping_sets
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
sets_sql: str = ", ".join(gr_sets)
|
|
401
|
+
return (
|
|
402
|
+
self._build_clause(
|
|
403
|
+
"GROUP BY",
|
|
404
|
+
"GROUP BY",
|
|
405
|
+
f"GROUPING SETS ({sets_sql})",
|
|
406
|
+
),
|
|
407
|
+
[],
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
return self._build_clause("GROUP BY", "GROUP BY", col_refs), []
|
|
411
|
+
|
|
412
|
+
def _extract_col_set(self, col_set: tuple[Column, ...]) -> str:
    """Render *col_set* as ``(a, b, ...)`` using this query's own registry."""
    own_registry = self._alias_registry
    return self._extract_col_set_with_registry(col_set, own_registry)
|
|
417
|
+
|
|
418
|
+
def _extract_col_set_with_registry(
    self,
    col_set: tuple[Column, ...],
    alias_registry: AliasRegistry,
) -> str:
    """Render *col_set* as a parenthesized, comma-separated column list.

    Each column is resolved through *alias_registry*. An empty tuple yields
    ``"()"``.
    """
    refs = ", ".join(column.get_ref(alias_registry) for column in col_set)
    return "(" + refs + ")"
|
|
430
|
+
|
|
431
|
+
def having(self, *conditions: Condition) -> Self:
    """Return a copy with *conditions* AND-ed into the HAVING clause.

    The conditions are combined left to right with ``&``. The result is then
    AND-ed onto any HAVING condition already present. With no conditions,
    an unchanged copy is returned.

    Raises:
        ValueError: If no grouping method (``group_by``, ``group_by_rollup``,
            ``group_by_cube``, ``group_by_grouping_sets``) has been applied.
    """
    if self._group_by_type == "normal" and not self._group_by_columns:
        raise ValueError("Cannot use having() without group_by()")

    result = copy(self)
    merged: Condition | None = None
    for cond in conditions:
        merged = cond if merged is None else merged & cond

    if merged:
        existing = result._having_condition
        result._having_condition = (
            merged if existing is None else existing & merged
        )

    return result
|
|
453
|
+
|
|
454
|
+
def having_by(self, **kwargs: Any) -> Self:
    """Return a copy with ``column = value`` tests AND-ed into HAVING.

    Each keyword names a column that is attached to the main (FROM) table.
    The equality tests are combined in keyword order, then AND-ed onto any
    existing HAVING condition.

    Raises:
        ValueError: If no grouping method has been applied.
    """
    if self._group_by_type == "normal" and not self._group_by_columns:
        raise ValueError("Cannot use having_by() without group_by()")

    result = copy(self)
    merged: Condition | None = None
    main_table = self._get_table()
    # Called only for its effect on the registry; the alias itself is unused.
    # Presumably this ensures the table is registered before its columns are
    # referenced — confirm.
    result._alias_registry.get_alias_for_table(main_table)

    for name, expected in kwargs.items():
        column: Column = Column(name)
        column._attach_table(main_table)  # pyright: ignore[reportPrivateUsage]
        equality = Condition(
            column=column,
            operator=EqualOperator,
            value=expected,
        )
        merged = equality if merged is None else merged & equality

    if merged:
        existing = result._having_condition
        result._having_condition = (
            merged if existing is None else existing & merged
        )

    return result
|
|
485
|
+
|
|
486
|
+
def group_by(self, *columns: Column) -> Self:
    """Return a copy grouped by *columns*.

    With no columns the query renders ``GROUP BY ALL``; any previously stored
    group-by columns are left in place but ignored in that mode.
    """
    grouped = copy(self)
    if columns:
        grouped._group_by_columns = columns
        grouped._group_by_type = "normal"
    else:
        grouped._group_by_type = "all"
    return grouped
|
|
494
|
+
|
|
495
|
+
def group_by_rollup(self, *columns: Column) -> Self:
    """Return a copy rendered with ``GROUP BY ROLLUP (columns...)``.

    Raises:
        ValueError: If no columns are given.
    """
    if len(columns) == 0:
        raise ValueError("group_by_rollup() requires at least one column")

    rolled = copy(self)
    rolled._group_by_columns, rolled._group_by_type = columns, "rollup"
    return rolled
|
|
503
|
+
|
|
504
|
+
def group_by_cube(self, *columns: Column) -> Self:
    """Return a copy rendered with ``GROUP BY CUBE (columns...)``.

    Raises:
        ValueError: If no columns are given.
    """
    if len(columns) == 0:
        raise ValueError("group_by_cube() requires at least one column")

    cubed = copy(self)
    cubed._group_by_columns, cubed._group_by_type = columns, "cube"
    return cubed
|
|
512
|
+
|
|
513
|
+
def group_by_grouping_sets(self, *column_sets: tuple[Column, ...]) -> Self:
    """Return a copy rendered with ``GROUP BY GROUPING SETS (...)``.

    Each argument is one grouping set; an empty tuple renders as ``()``
    (the grand-total set).

    Raises:
        ValueError: If no sets are given.
    """
    if len(column_sets) == 0:
        raise ValueError(
            "group_by_grouping_sets() requires at least one set",
        )

    grouped = copy(self)
    grouped._group_by_type, grouped._grouping_sets = (
        "grouping_sets",
        column_sets,
    )
    return grouped
|
|
523
|
+
|
|
524
|
+
def from_(self, table: Table | AbstractQuery) -> Self:
    """Return a copy that selects from *table*.

    Anything that is not already a ``Table`` (i.e. a nested query) is
    wrapped in ``Table`` so it renders as a derived table.
    """
    sourced = copy(self)
    if not isinstance(table, Table):
        table = Table(table)
    sourced._table = table
    return sourced
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from typing import Any, Self
|
|
2
|
+
|
|
3
|
+
from sql_fusion.composite_table import (
|
|
4
|
+
AbstractQuery,
|
|
5
|
+
AliasRegistry,
|
|
6
|
+
BinaryExpression,
|
|
7
|
+
Column,
|
|
8
|
+
FunctionCall,
|
|
9
|
+
Table,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class update(AbstractQuery):
    """Builder for ``[WITH ...] UPDATE ... SET ... [WHERE ...]`` statements."""

    def __init__(self, table: Table) -> None:
        """Create an UPDATE targeting *table* with no assignments yet."""
        super().__init__(table=table, columns=())
        # Column name -> new value, in insertion order.
        self._values: dict[str, Any] = {}

    def set(self, **values: Any) -> Self:
        """Add ``column = value`` assignments to the SET clause.

        This mutates the query in place and returns ``self``. A later call
        overwrites an earlier assignment to the same column name.

        How each value is rendered by ``build_query``:
        - ``Column``: a column reference;
        - ``FunctionCall`` / ``BinaryExpression``: inline SQL;
        - nested query: a parenthesized subquery;
        - anything else: bound as a ``?`` parameter.

        Raises:
            ValueError: If no assignments are given.
        """
        if not values:
            raise ValueError("No values provided for update")
        self._values.update(values)
        return self

    @staticmethod
    def _quote_ident(name: object) -> str:
        """Wrap *name* in double quotes, doubling any embedded ``"``.

        Doubling is the standard SQL escape for a quote inside a quoted
        identifier. A name containing ``"`` therefore cannot terminate the
        identifier early and inject SQL.
        """
        return '"' + str(name).replace('"', '""') + '"'

    def build_query(
        self,
        alias_registry: AliasRegistry | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """Render the UPDATE statement and its positional parameters.

        Args:
            alias_registry: Registry used to resolve table aliases. When it
                is omitted (or falsy) the query's own registry is used.

        Returns:
            ``(sql, params)`` after ``_apply_compile_expressions``. WITH
            parameters come first, then SET parameters, then WHERE
            parameters.

        Raises:
            ValueError: If ``set()`` was never called.
        """
        if not self._values:
            raise ValueError("No values provided for update")

        registry = alias_registry or self._alias_registry
        table = self._get_table()
        with_sql, with_params = self._build_with_clause(registry)
        alias = registry.get_alias_for_table(table)
        assignments: list[str] = []
        params: list[Any] = []

        for column_name, value in self._values.items():
            column_ref = self._quote_ident(column_name)

            if isinstance(value, Column):
                assignments.append(
                    f"{column_ref} = {value.get_ref(registry)}",
                )
            elif isinstance(value, (FunctionCall, BinaryExpression)):
                value_sql, value_params = value.to_sql(registry)
                assignments.append(f"{column_ref} = {value_sql}")
                params.extend(value_params)
            elif isinstance(value, AbstractQuery):
                value_sql, value_params = value.build_query(registry)
                assignments.append(f"{column_ref} = ({value_sql})")
                params.extend(value_params)
            else:
                assignments.append(f"{column_ref} = ?")
                params.append(value)

        set_clause = self._build_clause(
            "SET",
            "SET",
            ", ".join(assignments),
        )
        target = f'{self._quote_ident(table.get_name())} AS "{alias.name}"'
        query = self._build_clause("UPDATE", "UPDATE", f"{target} {set_clause}")

        if self._where_condition:
            where_sql, where_params = self._where_condition.to_sql(registry)
            query += f" {self._build_clause('WHERE', 'WHERE', where_sql)}"
            params.extend(where_params)

        # NOTE(review): the concrete type of with_params is not visible here.
        # Unpacking both sequences works whether it is a list or a tuple,
        # whereas `with_params + params` fails for tuple + list.
        return self._apply_compile_expressions(
            f"{with_sql} {query}" if with_sql else query,
            tuple([*with_params, *params]),
        )
|