sql_fusion 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,527 @@
1
+ from copy import copy
2
+ from typing import Any, Self
3
+
4
+ from sql_fusion.composite_table import (
5
+ AbstractQuery,
6
+ Alias,
7
+ AliasRegistry,
8
+ Column,
9
+ Condition,
10
+ FunctionCall,
11
+ Table,
12
+ )
13
+ from sql_fusion.operators import EqualOperator
14
+
15
+
16
class select(AbstractQuery):
    """Builder for ``SELECT`` statements.

    Builder methods (joins, ``limit``, ``order_by``, ``group_by`` ...) do
    not modify the receiver's clause state: each returns a shallow ``copy``
    of the query with the change applied, so a partially built query can be
    reused as the base for several derived queries.
    """

    def __init__(self, *columns: Column | Alias | FunctionCall) -> None:
        """Create a query selecting ``columns`` (``*`` when none are given).

        The source table is supplied later via :meth:`from_`.
        """
        super().__init__(table=None, columns=columns)
        self._having_condition: Condition | None = None
        self._group_by_columns: tuple[Column, ...] = ()
        # One of "normal", "all", "rollup", "cube" or "grouping_sets".
        self._group_by_type: str = "normal"
        self._grouping_sets: tuple[tuple[Column, ...], ...] = ()
        # (expression, descending) pairs in the order they were added.
        self._order_by_columns: tuple[
            tuple[Column | Alias | FunctionCall, bool],
            ...,
        ] = ()
        # (join_type, table, condition or None for CROSS JOIN)
        self._joins: list[tuple[str, Table, Condition | None]] = []
        self._limit: int | None = None
        self._offset: int | None = None
        self._distinct: bool = False

    def build_query(  # noqa: C901, PLR0912, PLR0915
        self,
        alias_registry: AliasRegistry | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile the query into SQL text plus positional parameters.

        Args:
            alias_registry: Registry used to allocate table aliases. When
                ``None``, the query's own registry is used.

        Returns:
            ``(sql, params)``. ``params`` are ordered to match the order in
            which their placeholders appear in ``sql``.
        """
        # Use ``is not None`` rather than ``or``: an explicitly supplied
        # registry must be honoured even if it would evaluate as falsy.
        registry = (
            alias_registry
            if alias_registry is not None
            else self._alias_registry
        )
        params: list[Any] = []

        with_sql, with_params = self._build_with_clause(registry)
        params.extend(with_params)

        # Allocate aliases for the FROM table and then every joined table
        # before any column reference is rendered. Alias allocation
        # presumably depends on registration order, so keep this sequence.
        table = self._get_table()
        table_sql, table_params, alias = self._prepare_table_entry(
            table,
            registry,
        )
        joins_data = self._prepare_join_entries(registry)
        col_part, col_params = self._render_select_list(registry)

        query_parts: list[str] = [with_sql] if with_sql else []

        distinct_part = "SELECT DISTINCT" if self._distinct else "SELECT"
        query_parts.append(
            self._build_clause("SELECT", distinct_part, col_part),
        )
        params.extend(col_params)

        query_parts.append(
            self._build_clause(
                "FROM",
                "FROM",
                f'{table_sql} AS "{alias.name}"',
            ),
        )
        params.extend(table_params)

        if joins_data:
            joins_sql, joins_params = self._build_joins_from_entries(
                registry,
                joins_data,
            )
            query_parts.append(joins_sql)
            params.extend(joins_params)

        if self._where_condition:
            where_sql, where_params = self._where_condition.to_sql(registry)
            query_parts.append(self._build_clause("WHERE", "WHERE", where_sql))
            params.extend(where_params)

        if (
            self._group_by_columns
            or self._grouping_sets
            or self._group_by_type == "all"
        ):
            group_by_sql, group_by_params = self._build_group_by_clause(
                registry,
            )
            query_parts.append(group_by_sql)
            params.extend(group_by_params)

        if self._having_condition:
            having_sql, having_params = self._having_condition.to_sql(registry)
            query_parts.append(
                self._build_clause("HAVING", "HAVING", having_sql),
            )
            params.extend(having_params)

        if self._order_by_columns:
            order_sql, order_params = self._render_order_by(registry)
            query_parts.append(order_sql)
            params.extend(order_params)

        # LIMIT/OFFSET are inlined as literals rather than bound; the
        # limit()/offset() setters only accept non-negative ints.
        if self._limit is not None:
            query_parts.append(
                self._build_clause("LIMIT", "LIMIT", str(self._limit)),
            )

        if self._offset is not None:
            query_parts.append(
                self._build_clause("OFFSET", "OFFSET", str(self._offset)),
            )

        return self._apply_compile_expressions(
            " ".join(query_parts),
            tuple(params),
        )

    def _render_select_list(
        self,
        registry: AliasRegistry,
    ) -> tuple[str, list[Any]]:
        """Render the SELECT list and the parameters it references.

        Returns ``("*", [])`` when no columns were requested.
        """
        if not self._columns:
            return "*", []

        parts: list[str] = []
        params: list[Any] = []
        for col in self._columns:
            if isinstance(col, FunctionCall):
                # Only the SELECT list renders function calls with their
                # alias; ORDER BY renders the bare expression.
                func_sql, func_params = col.to_sql(
                    registry,
                    include_alias=True,
                )
                parts.append(func_sql)
                params.extend(func_params)
            elif isinstance(col, Alias):
                parts.append(col.to_sql(registry))
            else:
                # Regular column from any table in scope.
                parts.append(col.get_ref(registry))
        return ", ".join(parts), params

    def _render_order_by(
        self,
        registry: AliasRegistry,
    ) -> tuple[str, list[Any]]:
        """Render the ORDER BY clause and the parameters it references."""
        parts: list[str] = []
        params: list[Any] = []
        for col, descending in self._order_by_columns:
            if isinstance(col, FunctionCall):
                col_sql, col_params = col.to_sql(registry)
                params.extend(col_params)
            elif isinstance(col, Alias):
                col_sql = col.to_sql(registry)
            else:
                col_sql = col.get_ref(registry)
            parts.append(f"{col_sql} DESC" if descending else col_sql)

        clause = self._build_clause("ORDER BY", "ORDER BY", ", ".join(parts))
        return clause, params

    def _prepare_table_entry(
        self,
        table: Table,
        alias_registry: AliasRegistry,
    ) -> tuple[str, tuple[Any, ...], Alias]:
        """Render ``table`` and allocate its alias.

        Returns:
            ``(table_sql, table_params, alias)``.
        """
        # A subquery is rendered before its own alias is allocated. A plain
        # table gets its alias first. The ordering is intentional, since it
        # decides the order in which the shared registry sees tables.
        if table._subquery is not None:  # pyright: ignore[reportPrivateUsage]
            table_sql, table_params = table.to_sql(alias_registry)
            alias = alias_registry.get_alias_for_table(table)
        else:
            alias = alias_registry.get_alias_for_table(table)
            table_sql, table_params = table.to_sql(alias_registry)
        return table_sql, table_params, alias

    def _prepare_join_entries(
        self,
        alias_registry: AliasRegistry,
    ) -> list[
        tuple[str, Table, Condition | None, str, tuple[Any, ...], Alias]
    ]:
        """Render every joined table, in join order, allocating aliases.

        Each entry is ``(join_type, table, condition, table_sql,
        table_params, alias)``.
        """
        join_entries: list[
            tuple[str, Table, Condition | None, str, tuple[Any, ...], Alias]
        ] = []

        for join_type, join_table, condition in self._joins:
            join_sql, join_params, alias = self._prepare_table_entry(
                join_table,
                alias_registry,
            )
            join_entries.append(
                (
                    join_type,
                    join_table,
                    condition,
                    join_sql,
                    join_params,
                    alias,
                ),
            )

        return join_entries

    def _build_joins_from_entries(
        self,
        alias_registry: AliasRegistry,
        join_entries: list[
            tuple[str, Table, Condition | None, str, tuple[Any, ...], Alias]
        ],
    ) -> tuple[str, list[Any]]:
        """Build JOIN clauses and return SQL string and parameters.

        Table parameters precede ON-condition parameters for each join,
        matching their placeholder order in the SQL.
        """
        joins_sql_parts: list[str] = []
        joins_params: list[Any] = []

        for (
            join_type,
            _join_table,
            condition,
            join_sql,
            join_params,
            alias,
        ) in join_entries:
            join_body = f'{join_sql} AS "{alias.name}"'
            joins_params.extend(join_params)

            # CROSS JOIN is stored with a None condition and has no ON clause.
            if condition is not None:
                condition_sql, condition_params = condition.to_sql(
                    alias_registry,
                )
                join_body += f" ON {condition_sql}"
                joins_params.extend(condition_params)

            joins_sql_parts.append(
                self._build_clause(
                    "JOIN",
                    f"{join_type} JOIN",
                    join_body,
                ),
            )

        return " ".join(joins_sql_parts), joins_params

    def _append_join(
        self,
        join_type: str,
        table: Table | AbstractQuery,
        condition: Condition | None,
    ) -> Self:
        """Return a copy of this query with one more JOIN appended.

        A query passed as ``table`` is wrapped in :class:`Table`. A new join
        list is built so the receiver's list is never mutated.
        """
        qs = copy(self)
        if isinstance(table, AbstractQuery):
            table = Table(table)
        qs._joins = [*self._joins, (join_type, table, condition)]
        return qs

    def join(
        self,
        table: Table | AbstractQuery,
        condition: Condition,
    ) -> Self:
        """Add an INNER JOIN clause."""
        return self._append_join("INNER", table, condition)

    def left_join(
        self,
        table: Table | AbstractQuery,
        condition: Condition,
    ) -> Self:
        """Add a LEFT JOIN clause."""
        return self._append_join("LEFT", table, condition)

    def right_join(
        self,
        table: Table | AbstractQuery,
        condition: Condition,
    ) -> Self:
        """Add a RIGHT JOIN clause."""
        return self._append_join("RIGHT", table, condition)

    def full_join(
        self,
        table: Table | AbstractQuery,
        condition: Condition,
    ) -> Self:
        """Add a FULL OUTER JOIN clause."""
        return self._append_join("FULL OUTER", table, condition)

    def cross_join(self, table: Table | AbstractQuery) -> Self:
        """Add a CROSS JOIN clause (cartesian product)."""
        return self._append_join("CROSS", table, None)

    def semi_join(
        self,
        table: Table | AbstractQuery,
        condition: Condition,
    ) -> Self:
        """Add a SEMI JOIN clause (exists check)."""
        return self._append_join("SEMI", table, condition)

    def anti_join(
        self,
        table: Table | AbstractQuery,
        condition: Condition,
    ) -> Self:
        """Add an ANTI JOIN clause (not exists check)."""
        return self._append_join("ANTI", table, condition)

    @staticmethod
    def _require_non_negative_int(n: int, clause: str) -> int:
        """Validate a LIMIT/OFFSET count and return it as a plain ``int``.

        Raises:
            TypeError: if ``n`` is not an int (``bool`` is rejected too).
            ValueError: if ``n`` is negative.
        """
        # The count is rendered with str() straight into the SQL text, so
        # anything other than a genuine integer would produce invalid SQL.
        if isinstance(n, bool) or not isinstance(n, int):
            raise TypeError(f"{clause} must be an integer")
        if n < 0:
            raise ValueError(f"{clause} must be non-negative")
        return int(n)

    def limit(self, n: int) -> Self:
        """Return a copy limited to at most ``n`` rows."""
        count = self._require_non_negative_int(n, "LIMIT")
        qs = copy(self)
        qs._limit = count
        return qs

    def offset(self, n: int) -> Self:
        """Return a copy that skips the first ``n`` rows."""
        count = self._require_non_negative_int(n, "OFFSET")
        qs = copy(self)
        qs._offset = count
        return qs

    def order_by(
        self,
        *columns: Column | Alias | FunctionCall,
        descending: bool = False,
    ) -> Self:
        """Return a copy ordered additionally by ``columns``.

        ``descending`` applies to every column in this call. Columns from
        earlier ``order_by`` calls keep their position and direction.

        Raises:
            ValueError: if no columns are given.
        """
        if not columns:
            raise ValueError("order_by() requires at least one column")

        qs = copy(self)
        qs._order_by_columns = self._order_by_columns + tuple(
            (column, descending) for column in columns
        )
        return qs

    def distinct(self) -> Self:
        """Add DISTINCT clause to select only unique rows."""
        qs = copy(self)
        qs._distinct = True
        return qs

    def _build_group_by_clause(
        self,
        alias_registry: AliasRegistry | None = None,
    ) -> tuple[str, list[Any]]:
        """Render the GROUP BY clause for the current grouping mode.

        Returns:
            ``(sql, params)``. No grouping mode binds parameters, so
            ``params`` is always empty.
        """
        registry = (
            alias_registry
            if alias_registry is not None
            else self._alias_registry
        )
        if self._group_by_type == "all":
            return self._build_clause("GROUP BY", "GROUP BY", "ALL"), []

        col_refs: str = ", ".join(
            col.get_ref(registry) for col in self._group_by_columns
        )

        if self._group_by_type == "rollup":
            body = f"ROLLUP ({col_refs})"
        elif self._group_by_type == "cube":
            body = f"CUBE ({col_refs})"
        elif self._group_by_type == "grouping_sets":
            # An empty set renders as "()".
            sets_sql = ", ".join(
                self._extract_col_set_with_registry(col_set, registry)
                if col_set
                else "()"
                for col_set in self._grouping_sets
            )
            body = f"GROUPING SETS ({sets_sql})"
        else:
            body = col_refs

        return self._build_clause("GROUP BY", "GROUP BY", body), []

    def _extract_col_set(self, col_set: tuple[Column, ...]) -> str:
        """Render ``col_set`` as ``(a, b, ...)`` using the query's registry."""
        return self._extract_col_set_with_registry(
            col_set,
            self._alias_registry,
        )

    def _extract_col_set_with_registry(
        self,
        col_set: tuple[Column, ...],
        alias_registry: AliasRegistry,
    ) -> str:
        """Render ``col_set`` as a parenthesised list of column references."""
        refs = ", ".join(col.get_ref(alias_registry) for col in col_set)
        return f"({refs})"

    def _merge_having_conditions(
        self,
        conditions: tuple[Condition, ...] | list[Condition],
    ) -> None:
        """AND ``conditions`` onto this query's HAVING clause in place.

        ``conditions`` are AND-ed together in order, and the result is AND-ed
        after any existing HAVING condition. Only called on a fresh copy
        created by a builder method, never on the receiver.
        """
        combined: Condition | None = None
        for condition in conditions:
            combined = condition if combined is None else combined & condition

        if combined:
            if self._having_condition is None:
                self._having_condition = combined
            else:
                self._having_condition = self._having_condition & combined

    def having(self, *conditions: Condition) -> Self:
        """Return a copy with ``conditions`` AND-ed into the HAVING clause.

        Raises:
            ValueError: if no grouping has been configured.
        """
        if not self._group_by_columns and self._group_by_type == "normal":
            raise ValueError("Cannot use having() without group_by()")

        qs = copy(self)
        qs._merge_having_conditions(conditions)
        return qs

    def having_by(self, **kwargs: Any) -> Self:
        """Return a copy with ``column = value`` HAVING filters added.

        Each keyword becomes an equality condition on a column of the FROM
        table. All of them are AND-ed into the HAVING clause.

        Raises:
            ValueError: if no grouping has been configured.
        """
        if not self._group_by_columns and self._group_by_type == "normal":
            raise ValueError("Cannot use having_by() without group_by()")

        qs = copy(self)
        table = self._get_table()
        # Register the FROM table with the copy's registry before columns
        # are attached to it.
        qs._alias_registry.get_alias_for_table(table)

        conditions: list[Condition] = []
        for key, value in kwargs.items():
            col: Column = Column(key)
            col._attach_table(table)  # pyright: ignore[reportPrivateUsage]
            conditions.append(
                Condition(
                    column=col,
                    operator=EqualOperator,
                    value=value,
                ),
            )

        qs._merge_having_conditions(conditions)
        return qs

    def group_by(self, *columns: Column) -> Self:
        """Return a copy grouped by ``columns``.

        With no columns, the query renders ``GROUP BY ALL``.
        """
        qs = copy(self)
        if not columns:
            qs._group_by_type = "all"
        else:
            qs._group_by_columns = columns
            qs._group_by_type = "normal"
        return qs

    def group_by_rollup(self, *columns: Column) -> Self:
        """Return a copy grouped by ``ROLLUP (columns)``.

        Raises:
            ValueError: if no columns are given.
        """
        if not columns:
            raise ValueError("group_by_rollup() requires at least one column")

        qs = copy(self)
        qs._group_by_columns = columns
        qs._group_by_type = "rollup"
        return qs

    def group_by_cube(self, *columns: Column) -> Self:
        """Return a copy grouped by ``CUBE (columns)``.

        Raises:
            ValueError: if no columns are given.
        """
        if not columns:
            raise ValueError("group_by_cube() requires at least one column")

        qs = copy(self)
        qs._group_by_columns = columns
        qs._group_by_type = "cube"
        return qs

    def group_by_grouping_sets(self, *column_sets: tuple[Column, ...]) -> Self:
        """Return a copy grouped by ``GROUPING SETS``.

        Each argument is one set of columns. An empty tuple renders as
        ``()``.

        Raises:
            ValueError: if no sets are given.
        """
        if not column_sets:
            raise ValueError(
                "group_by_grouping_sets() requires at least one set",
            )

        qs = copy(self)
        qs._group_by_type = "grouping_sets"
        qs._grouping_sets = column_sets
        return qs

    def from_(self, table: Table | AbstractQuery) -> Self:
        """Return a copy selecting from ``table``.

        A query passed as ``table`` is wrapped in :class:`Table` so it is
        rendered as a subquery.
        """
        qs = copy(self)
        qs._table = table if isinstance(table, Table) else Table(table)
        return qs
@@ -0,0 +1,76 @@
1
+ from typing import Any, Self
2
+
3
+ from sql_fusion.composite_table import (
4
+ AbstractQuery,
5
+ AliasRegistry,
6
+ BinaryExpression,
7
+ Column,
8
+ FunctionCall,
9
+ Table,
10
+ )
11
+
12
+
13
class update(AbstractQuery):
    """Builder for ``UPDATE`` statements against a single table."""

    def __init__(self, table: Table) -> None:
        """Create an UPDATE for ``table`` with no assignments yet."""
        super().__init__(table=table, columns=())
        # Column name -> new value. A value may be a Column, a FunctionCall
        # or BinaryExpression, a subquery, or a literal bound as a parameter.
        self._values: dict[str, Any] = {}

    def set(self, **values: Any) -> Self:
        """Add column assignments and return this same query object.

        Later calls override earlier assignments to the same column.

        Raises:
            ValueError: if no assignments are given.
        """
        if not values:
            raise ValueError("No values provided for update")
        # Rebind to a new dict instead of mutating in place. A shallow copy of
        # this query taken earlier (e.g. by a builder method returning
        # ``copy(self)``) would otherwise share the dict and silently pick up
        # these assignments as well.
        self._values = {**self._values, **values}
        return self

    def build_query(
        self,
        alias_registry: AliasRegistry | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile the UPDATE into SQL text plus positional parameters.

        Args:
            alias_registry: Registry used to allocate table aliases. When
                ``None``, the query's own registry is used.

        Returns:
            ``(sql, params)``. Parameters are ordered WITH, then SET, then
            WHERE, matching their placeholder order in ``sql``.

        Raises:
            ValueError: if no assignments were added with :meth:`set`.
        """
        if not self._values:
            raise ValueError("No values provided for update")

        # Use ``is not None`` rather than ``or`` so an explicitly supplied
        # registry is honoured even if it would evaluate as falsy.
        registry = (
            alias_registry
            if alias_registry is not None
            else self._alias_registry
        )
        table = self._get_table()
        with_sql, with_params = self._build_with_clause(registry)
        alias = registry.get_alias_for_table(table)
        assignments: list[str] = []
        params: list[Any] = []

        for column_name, value in self._values.items():
            # Column names come from **kwargs keys, which can be arbitrary
            # strings via ``set(**mapping)``. Double embedded quotes so a name
            # cannot break out of the quoted identifier.
            escaped_name = column_name.replace('"', '""')
            column_ref = f'"{escaped_name}"'

            if isinstance(value, Column):
                value_sql = value.get_ref(registry)
            elif isinstance(value, (FunctionCall, BinaryExpression)):
                value_sql, value_params = value.to_sql(registry)
                params.extend(value_params)
            elif isinstance(value, AbstractQuery):
                sub_sql, sub_params = value.build_query(registry)
                value_sql = f"({sub_sql})"
                params.extend(sub_params)
            else:
                # Plain literal: bind it as a positional parameter.
                value_sql = "?"
                params.append(value)

            assignments.append(f"{column_ref} = {value_sql}")

        set_clause = self._build_clause(
            "SET",
            "SET",
            ", ".join(assignments),
        )
        query = self._build_clause(
            "UPDATE",
            "UPDATE",
            f'"{table.get_name()}" AS "{alias.name}" {set_clause}',
        )

        if self._where_condition:
            where_sql, where_params = self._where_condition.to_sql(registry)
            query += f" {self._build_clause('WHERE', 'WHERE', where_sql)}"
            params.extend(where_params)

        # Unpack both sequences so this works whether the WITH builder returns
        # a list or a tuple (``tuple + list`` would raise TypeError).
        return self._apply_compile_expressions(
            f"{with_sql} {query}" if with_sql else query,
            (*with_params, *params),
        )