sqlpiston 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,581 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Optional, Tuple, Type, Union
3
+
4
+ from sqlpiston.builder.nodes import (
5
+ ASTNode, BetweenNode, CaseNode, ComparisonNode, ExistsNode, Field,
6
+ InNode, LogicalNode, SQLFunction, ExprValue,
7
+ )
8
+ from sqlpiston.builder.selectable import CompoundSelect, CTE, Select
9
+ from sqlpiston.builder.dml import Delete, Insert, Update, Upsert
10
+ from sqlpiston.builder.ddl import (
11
+ AlterAction, AlterTable, ColumnDef, CreateIndex, CreateTable,
12
+ CreateView, DropIndex, DropTable, DropView, Truncate,
13
+ )
14
+
15
+
16
class Compiler(ABC):
    """Base compiler that turns AST nodes into ``(sql, params)`` pairs.

    ``process()`` dispatches on the node's type to the matching ``visit_*``
    method. Subclasses implement every ``visit_*`` hook plus
    ``placeholder()`` and ``quote_identifier()``; the ``compile_*`` helpers
    below are shared by all of them.
    """

    def process(self, node: ASTNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile *node* by delegating to the visitor for its type.

        Returns:
            The ``(sql, params)`` pair produced by the matching visitor.

        Raises:
            TypeError: if *node* is not an instance of any known node type.
        """
        # Checked top to bottom; the first matching type wins, so the order
        # of this table is part of the behavior.
        dispatch = (
            (Select, "visit_select"),
            (CompoundSelect, "visit_compound_select"),
            (CTE, "visit_cte"),
            (Insert, "visit_insert"),
            (Update, "visit_update"),
            (Delete, "visit_delete"),
            (Upsert, "visit_upsert"),
            (CreateTable, "visit_create_table"),
            (AlterTable, "visit_alter_table"),
            (DropTable, "visit_drop_table"),
            (CreateIndex, "visit_create_index"),
            (DropIndex, "visit_drop_index"),
            (CreateView, "visit_create_view"),
            (DropView, "visit_drop_view"),
            (Truncate, "visit_truncate"),
            (ComparisonNode, "visit_comparison"),
            (InNode, "visit_in"),
            (BetweenNode, "visit_between"),
            (LogicalNode, "visit_logical"),
            (ExistsNode, "visit_exists"),
            (CaseNode, "visit_case"),
            (SQLFunction, "visit_function"),
        )
        for node_type, visitor_name in dispatch:
            if isinstance(node, node_type):
                return getattr(self, visitor_name)(node)
        raise TypeError(f"Unknown AST node type: {type(node).__name__}")

    # -- DQL --

    @abstractmethod
    def visit_select(self, node: Select) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover — abstract stub

    @abstractmethod
    def visit_compound_select(self, node: CompoundSelect) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_cte(self, node: CTE) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    # -- DML --

    @abstractmethod
    def visit_insert(self, node: Insert) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_update(self, node: Update) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_delete(self, node: Delete) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_upsert(self, node: Upsert) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    # -- DDL --

    @abstractmethod
    def visit_create_table(self, node: CreateTable) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_alter_table(self, node: AlterTable) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_drop_table(self, node: DropTable) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_create_index(self, node: CreateIndex) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_drop_index(self, node: DropIndex) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_create_view(self, node: CreateView) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_drop_view(self, node: DropView) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_truncate(self, node: Truncate) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    # -- Expression nodes --

    @abstractmethod
    def visit_comparison(self, node: ComparisonNode) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_in(self, node: InNode) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_between(self, node: BetweenNode) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_logical(self, node: LogicalNode) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_exists(self, node: ExistsNode) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_case(self, node: CaseNode) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    @abstractmethod
    def visit_function(self, node: SQLFunction) -> Tuple[str, Tuple[ExprValue, ...]]: ...  # pragma: no cover

    # -- Dialect syntax --

    @abstractmethod
    def placeholder(self) -> str: ...  # pragma: no cover — abstract stub

    @abstractmethod
    def quote_identifier(self, name: str) -> str: ...  # pragma: no cover — abstract stub

    # -- Shared helpers --

    def _qualified_field(self, field: Field) -> str:
        """Render ``table.name`` (or just ``name``) quoted, without any alias."""
        quoted = [self.quote_identifier(field.name)]
        if field.table:
            quoted.insert(0, self.quote_identifier(field.table))
        return ".".join(quoted)

    def compile_field(self, field: Field) -> str:
        """Render a table-qualified field name with quoting and optional alias.

        When ``field._alias_prop`` is set, `` AS <alias>`` is appended.
        """
        rendered = self._qualified_field(field)
        if field._alias_prop:
            rendered = f"{rendered} AS {self.quote_identifier(field._alias_prop)}"
        return rendered

    def compile_from(self, from_src: Union[str, Select]) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile the source of a FROM clause.

        A string is treated as a table name and quoted. Anything else is
        compiled with ``visit_select`` and wrapped as an aliased subquery; the
        alias falls back to ``"sub"`` when the select has none.
        """
        if isinstance(from_src, str):
            return self.quote_identifier(from_src), ()
        sql, params = self.visit_select(from_src)
        alias = from_src._alias or "sub"
        return f"({sql}) AS {self.quote_identifier(alias)}", params

    def compile_joins(self, joins: List[Tuple[str, str, ASTNode]]) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``(table, how, on)`` triples into space-joined JOIN clauses.

        ``how == 'CROSS'`` emits a CROSS JOIN and ignores ``on``; every other
        value is emitted verbatim before ``JOIN`` with a compiled ON condition.
        """
        clauses: List[str] = []
        params: List[ExprValue] = []
        for table, how, on in joins:
            quoted_table = self.quote_identifier(table)
            if how == 'CROSS':
                clauses.append(f"CROSS JOIN {quoted_table}")
                continue
            on_sql, on_params = self.process(on)
            clauses.append(f"{how} JOIN {quoted_table} ON {on_sql}")
            params.extend(on_params)
        return " ".join(clauses), tuple(params)

    def compile_condition(self, node: ASTNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile a WHERE/HAVING/ON condition (thin wrapper over ``process``)."""
        return self.process(node)

    def compile_order_by(self, orders: List[Tuple[Union[str, Field], str]]) -> str:
        """Compile ``(field, direction)`` pairs into an ORDER BY clause.

        String entries are quoted as identifiers; other entries are rendered
        as qualified field names. A field's alias is deliberately left out:
        ``ORDER BY col AS alias`` is not valid SQL.
        """
        terms: List[str] = []
        for field, direction in orders:
            if isinstance(field, str):
                target = self.quote_identifier(field)
            else:
                target = self._qualified_field(field)
            terms.append(f"{target} {direction}")
        return "ORDER BY " + ", ".join(terms)

    def compile_group_by(self, groups: List[Union[str, Field]]) -> str:
        """Compile a GROUP BY clause.

        String entries are quoted as identifiers; other entries are rendered
        as qualified field names without their alias, because an ``AS``
        suffix is not valid inside GROUP BY.
        """
        terms = [
            self.quote_identifier(g) if isinstance(g, str) else self._qualified_field(g)
            for g in groups
        ]
        return "GROUP BY " + ", ".join(terms)

    def collect_params(self, *results: Tuple[str, Tuple[ExprValue, ...]]) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Join non-empty SQL fragments with spaces and concatenate their params."""
        sql_parts: List[str] = []
        all_params: List[ExprValue] = []
        for sql, params in results:
            if sql:
                sql_parts.append(sql)
            if params:
                all_params.extend(params)
        return " ".join(sql_parts), tuple(all_params)
186
+
187
+
188
class GenericCompiler(Compiler):
    """Platform-agnostic compiler. Uses %s and backtick quoting. Serves as baseline.

    Every visitor returns ``(sql, params)``: the SQL text with one
    placeholder per bound value and the values in matching order.
    """

    def placeholder(self) -> str:
        """Return the bind-parameter marker (``%s``)."""
        return '%s'

    def quote_identifier(self, name: str) -> str:
        """Wrap *name* in backticks, doubling any backtick it contains.

        Without the doubling, a name containing a backtick would terminate
        the quoted identifier early and corrupt (or inject into) the SQL.
        """
        escaped = name.replace('`', '``')
        return f'`{escaped}`'

    # -- Expression nodes --

    def visit_comparison(self, node: ComparisonNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``<field> <operator> <value>``.

        The right-hand side is rendered by kind: nothing for IS [NOT] NULL,
        a field reference, a parenthesized subquery, a function call, or a
        bound placeholder for anything else.
        """
        field_str = self.compile_field(node.field)
        operator = node.operator

        # IS NULL / IS NOT NULL take no right-hand side and no parameter.
        if operator in ('IS NULL', 'IS NOT NULL'):
            return f"{field_str} {operator}", ()

        value = node.value
        if isinstance(value, Field):
            return f"{field_str} {operator} {self.compile_field(value)}", ()
        if isinstance(value, Select):
            sub_sql, sub_params = self.visit_select(value)
            return f"{field_str} {operator} ({sub_sql})", sub_params
        if isinstance(value, SQLFunction):
            func_sql, func_params = self.visit_function(value)
            return f"{field_str} {operator} {func_sql}", func_params
        return f"{field_str} {operator} {self.placeholder()}", (value,)

    def visit_in(self, node: InNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``<field> IN (...)`` for a subquery or a literal collection.

        An empty literal collection compiles to the always-false predicate
        ``1 = 0`` because ``IN ()`` is a syntax error in SQL.
        """
        field_str = self.compile_field(node.field)
        values = node.values

        if isinstance(values, Select):
            sub_sql, sub_params = self.visit_select(values)
            return f"{field_str} IN ({sub_sql})", sub_params

        # Materialize once so the placeholder count and params always agree.
        literals = tuple(values)
        if not literals:
            return "1 = 0", ()
        placeholders = ", ".join([self.placeholder()] * len(literals))
        return f"{field_str} IN ({placeholders})", literals

    def visit_between(self, node: BetweenNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``<field> BETWEEN ? AND ?`` with ``low`` and ``high`` bound."""
        field_str = self.compile_field(node.field)
        marker = self.placeholder()
        return f"{field_str} BETWEEN {marker} AND {marker}", (node.low, node.high)

    def visit_logical(self, node: LogicalNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile NOT / AND / OR over the node's children.

        NOT applies to the first child only. For AND/OR, nested AND/OR
        children are parenthesized, and the combined expression is itself
        parenthesized when it has more than one operand.
        """
        if node.operator == 'NOT':
            child_sql, child_params = self.process(node.children[0])
            return f"NOT ({child_sql})", child_params

        parts: List[str] = []
        params: List[ExprValue] = []
        for child in node.children:
            child_sql, child_params = self.process(child)
            if isinstance(child, LogicalNode) and child.operator != 'NOT':
                child_sql = f"({child_sql})"
            parts.append(child_sql)
            params.extend(child_params)

        combined = f" {node.operator} ".join(parts)
        if len(parts) > 1:
            combined = f"({combined})"
        return combined, tuple(params)

    def visit_exists(self, node: ExistsNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``[NOT] EXISTS (<subquery>)``."""
        sub_sql, sub_params = self.visit_select(node.select)
        keyword = "NOT EXISTS" if node.negated else "EXISTS"
        return f"{keyword} ({sub_sql})", sub_params

    def visit_case(self, node: CaseNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``CASE WHEN ... THEN ? [ELSE ?] END``.

        Each THEN result and the ELSE value are bound as parameters,
        interleaved with the condition parameters in textual order.
        """
        parts: List[str] = ["CASE"]
        params: List[ExprValue] = []
        for condition, result in node._whens:
            cond_sql, cond_params = self.process(condition)
            parts.append(f"WHEN {cond_sql} THEN {self.placeholder()}")
            params.extend(cond_params)
            params.append(result)
        if node._else is not None:
            parts.append(f"ELSE {self.placeholder()}")
            params.append(node._else)
        parts.append("END")
        return " ".join(parts), tuple(params)

    def visit_function(self, node: SQLFunction) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``NAME(arg, ...) [AS alias]``.

        ``"*"`` is emitted literally, fields and nested functions are
        rendered inline, and every other argument is bound as a parameter.
        The function name itself is emitted verbatim.
        """
        params: List[ExprValue] = []
        arg_parts: List[str] = []
        for arg in node.args:
            if isinstance(arg, str) and arg == "*":
                arg_parts.append("*")
            elif isinstance(arg, Field):
                arg_parts.append(self.compile_field(arg))
            elif isinstance(arg, SQLFunction):
                sub_sql, sub_params = self.visit_function(arg)
                arg_parts.append(sub_sql)
                params.extend(sub_params)
            else:
                arg_parts.append(self.placeholder())
                params.append(arg)

        func_sql = f"{node.name}({', '.join(arg_parts)})"
        if node._alias:
            func_sql = f"{func_sql} AS {self.quote_identifier(node._alias)}"
        return func_sql, tuple(params)

    # -- DQL --

    def visit_select(self, node: Select) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile a SELECT statement.

        Clauses are emitted in the order WITH, SELECT, FROM, JOIN, WHERE,
        GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET; parameters are collected
        in that same order. LIMIT and OFFSET are bound as parameters.
        """
        sql_parts: List[str] = []
        all_params: List[ExprValue] = []

        # CTEs
        if node._ctes:
            cte_parts: List[str] = []
            for cte in node._ctes:
                cte_sql, cte_params = self.visit_cte(cte)
                cte_parts.append(cte_sql)
                all_params.extend(cte_params)
            sql_parts.append("WITH " + ", ".join(cte_parts))

        # SELECT [DISTINCT]
        cols: List[str] = []
        for col in node._columns:
            if isinstance(col, str):
                cols.append("*" if col == "*" else self.quote_identifier(col))
            elif isinstance(col, Field):
                cols.append(self.compile_field(col))
            elif isinstance(col, Select):
                sub_sql, sub_params = self.process(col)
                cols.append(f"({sub_sql})")
                all_params.extend(sub_params)
            elif isinstance(col, SQLFunction):
                func_sql, func_params = self.visit_function(col)
                cols.append(func_sql)
                all_params.extend(func_params)
            elif isinstance(col, ASTNode):
                # Other expression nodes (e.g. CASE) used to be dropped from
                # the column list without any error; compile them instead.
                expr_sql, expr_params = self.process(col)
                cols.append(expr_sql)
                all_params.extend(expr_params)
            # Non-AST, non-string values are skipped, as before.
        distinct = "DISTINCT " if node._distinct else ""
        sql_parts.append(f"SELECT {distinct}{', '.join(cols)}")

        # FROM
        if node._from:
            from_sql, from_params = self.compile_from(node._from)
            sql_parts.append(f"FROM {from_sql}")
            all_params.extend(from_params)

        # JOINs
        if node._joins:
            join_sql, join_params = self.compile_joins(node._joins)
            sql_parts.append(join_sql)
            all_params.extend(join_params)

        # WHERE
        if node._where:
            where_sql, where_params = self.compile_condition(node._where)
            sql_parts.append(f"WHERE {where_sql}")
            all_params.extend(where_params)

        # GROUP BY
        if node._group_by:
            sql_parts.append(self.compile_group_by(node._group_by))

        # HAVING
        if node._having:
            having_sql, having_params = self.compile_condition(node._having)
            sql_parts.append(f"HAVING {having_sql}")
            all_params.extend(having_params)

        # ORDER BY
        if node._order_by:
            sql_parts.append(self.compile_order_by(node._order_by))

        # LIMIT / OFFSET — compared against None so that 0 is still emitted.
        if node._limit is not None:
            sql_parts.append(f"LIMIT {self.placeholder()}")
            all_params.append(node._limit)
        if node._offset is not None:
            sql_parts.append(f"OFFSET {self.placeholder()}")
            all_params.append(node._offset)

        return " ".join(sql_parts), tuple(all_params)

    def visit_compound_select(self, node: CompoundSelect) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``(<left>) <operator> (<right>)`` from two selects."""
        left_sql, left_params = self.visit_select(node.left)
        right_sql, right_params = self.visit_select(node.right)
        return (
            f"({left_sql}) {node.operator} ({right_sql})",
            left_params + right_params,
        )

    def visit_cte(self, node: CTE) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile one CTE body: ``<name> AS (<select>)`` (without ``WITH``)."""
        sub_sql, sub_params = self.visit_select(node.select)
        return f"{self.quote_identifier(node.name)} AS ({sub_sql})", sub_params

    # -- DML --

    def visit_insert(self, node: Insert) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile INSERT ... VALUES or INSERT ... SELECT.

        Raises:
            ValueError: if the table is missing, or neither values nor a
                select were provided.
        """
        if node._table is None:
            raise ValueError("INSERT requires a table name")

        table = self.quote_identifier(node._table)

        # A select source takes precedence over literal values.
        if node._select is not None:
            sub_sql, sub_params = self.visit_select(node._select)
            return f"INSERT INTO {table} {sub_sql}", sub_params

        if node._data is None:
            raise ValueError("INSERT requires values() or select()")

        cols = ", ".join(self.quote_identifier(k) for k in node._data)
        placeholders = ", ".join([self.placeholder()] * len(node._data))
        return (
            f"INSERT INTO {table} ({cols}) VALUES ({placeholders})",
            tuple(node._data.values()),
        )

    def visit_update(self, node: Update) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``UPDATE <table> SET col = ?, ... [WHERE ...]``.

        Raises:
            ValueError: if the table is missing, or the SET data is missing
                or empty (an empty SET list is not valid SQL).
        """
        if node._table is None:
            raise ValueError("UPDATE requires a table name")
        if not node._data:
            raise ValueError("UPDATE requires set() data")

        table = self.quote_identifier(node._table)
        set_parts = [
            f"{self.quote_identifier(col)} = {self.placeholder()}"
            for col in node._data
        ]
        params: List[ExprValue] = list(node._data.values())

        sql = f"UPDATE {table} SET {', '.join(set_parts)}"

        if node._where:
            where_sql, where_params = self.process(node._where)
            sql += f" WHERE {where_sql}"
            params.extend(where_params)

        return sql, tuple(params)

    def visit_delete(self, node: Delete) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``DELETE FROM <table> [WHERE ...]``.

        Raises:
            ValueError: if the table is missing.
        """
        if node._table is None:
            raise ValueError("DELETE requires a table name")

        sql = f"DELETE FROM {self.quote_identifier(node._table)}"
        params: List[ExprValue] = []

        if node._where:
            where_sql, where_params = self.process(node._where)
            sql += f" WHERE {where_sql}"
            params.extend(where_params)

        return sql, tuple(params)

    def visit_upsert(self, node: Upsert) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Always raise: upsert syntax is dialect-specific.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("UPSERT must be compiled by a dialect-specific compiler")

    # -- DDL --

    def _column_def_parts(self, col_def: ColumnDef, include_primary_key: bool) -> List[str]:
        """Return the SQL fragments for one column definition.

        A non-None default contributes a ``DEFAULT <placeholder>`` fragment;
        the caller is responsible for binding ``col_def.default``.
        """
        parts = [self.quote_identifier(col_def.name), col_def.type_]
        if not col_def.nullable:
            parts.append("NOT NULL")
        if col_def.default is not None:
            parts.append(f"DEFAULT {self.placeholder()}")
        if include_primary_key and col_def.primary_key:
            parts.append("PRIMARY KEY")
        if col_def.unique:
            parts.append("UNIQUE")
        return parts

    def visit_create_table(self, node: CreateTable) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``CREATE TABLE [IF NOT EXISTS] <table> (<column defs>)``.

        Column defaults are bound as parameters, in column order.

        Raises:
            ValueError: if the table is missing.
        """
        if node._table is None:
            raise ValueError("CREATE TABLE requires a table name")

        table = self.quote_identifier(node._table)
        if_not_exists = "IF NOT EXISTS " if node._if_not_exists else ""

        col_defs: List[str] = []
        params: List[ExprValue] = []
        for col in node._columns:
            col_defs.append(" ".join(self._column_def_parts(col, include_primary_key=True)))
            if col.default is not None:
                params.append(col.default)

        return (
            f"CREATE TABLE {if_not_exists}{table} ({', '.join(col_defs)})",
            tuple(params),
        )

    def visit_alter_table(self, node: AlterTable) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile each action into its own ALTER TABLE statement, joined by ``"; "``.

        ADD and MODIFY render a column definition and bind its default (if
        any); DROP only needs the column name. Actions of any other kind are
        ignored.

        Raises:
            ValueError: if the table is missing.
        """
        if node._table is None:
            raise ValueError("ALTER TABLE requires a table name")

        table = self.quote_identifier(node._table)
        # Keyword emitted for the two actions that carry a column definition.
        with_column_def = {AlterAction.ADD: "ADD", AlterAction.MODIFY: "MODIFY"}
        sql_parts: List[str] = []
        all_params: List[ExprValue] = []

        for action, col_name, _col_type, col_def in node._actions:
            if action == AlterAction.DROP:
                sql_parts.append(f"ALTER TABLE {table} DROP COLUMN {self.quote_identifier(col_name)}")
                continue
            verb = with_column_def.get(action)
            if verb is None:
                continue
            sql_parts.append(f"ALTER TABLE {table} {verb} COLUMN {self._compile_column_def(col_def)}")
            if col_def is not None and col_def.default is not None:
                all_params.append(col_def.default)

        return "; ".join(sql_parts), tuple(all_params)

    def _compile_column_def(self, col_def: Optional[ColumnDef]) -> str:
        """Render a column definition for ALTER TABLE (PRIMARY KEY is not emitted)."""
        if col_def is None:  # pragma: no cover — defensive guard, never called with None
            return ""
        return " ".join(self._column_def_parts(col_def, include_primary_key=False))

    def visit_drop_table(self, node: DropTable) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``DROP TABLE [IF EXISTS] <table>``.

        Raises:
            ValueError: if the table is missing.
        """
        if node._table is None:
            raise ValueError("DROP TABLE requires a table name")
        if_exists = "IF EXISTS " if node._if_exists else ""
        return f"DROP TABLE {if_exists}{self.quote_identifier(node._table)}", ()

    def visit_create_index(self, node: CreateIndex) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``CREATE [UNIQUE] INDEX [IF NOT EXISTS] <name> ON <table> (<cols>)``.

        Raises:
            ValueError: if the index name or the table is missing.
        """
        if node._name is None or node._table is None:
            raise ValueError("CREATE INDEX requires name and table")
        unique = "UNIQUE " if node._unique else ""
        if_not_exists = "IF NOT EXISTS " if node._if_not_exists else ""
        cols = ", ".join(self.quote_identifier(c) for c in node._columns)
        return (
            f"CREATE {unique}INDEX {if_not_exists}{self.quote_identifier(node._name)} "
            f"ON {self.quote_identifier(node._table)} ({cols})",
            (),
        )

    def visit_drop_index(self, node: DropIndex) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``DROP INDEX [IF EXISTS] <name> [ON <table>]``.

        Raises:
            ValueError: if the index name is missing.
        """
        if node._name is None:
            raise ValueError("DROP INDEX requires an index name")
        if_exists = "IF EXISTS " if node._if_exists else ""
        sql = f"DROP INDEX {if_exists}{self.quote_identifier(node._name)}"
        if node._table:
            sql += f" ON {self.quote_identifier(node._table)}"
        return sql, ()

    def visit_create_view(self, node: CreateView) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``CREATE VIEW [IF NOT EXISTS] <name> AS <select>``.

        Raises:
            ValueError: if the view name or the select is missing.
        """
        if node._name is None or node._select is None:
            raise ValueError("CREATE VIEW requires name and AS SELECT")
        if_not_exists = "IF NOT EXISTS " if node._if_not_exists else ""
        sub_sql, sub_params = self.visit_select(node._select)
        return (
            f"CREATE VIEW {if_not_exists}{self.quote_identifier(node._name)} AS {sub_sql}",
            sub_params,
        )

    def visit_drop_view(self, node: DropView) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``DROP VIEW [IF EXISTS] <name>``.

        Raises:
            ValueError: if the view name is missing.
        """
        if node._name is None:
            raise ValueError("DROP VIEW requires a view name")
        if_exists = "IF EXISTS " if node._if_exists else ""
        return f"DROP VIEW {if_exists}{self.quote_identifier(node._name)}", ()

    def visit_truncate(self, node: Truncate) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``TRUNCATE TABLE <table>``.

        Raises:
            ValueError: if the table is missing.
        """
        if node._table is None:
            raise ValueError("TRUNCATE requires a table name")
        return f"TRUNCATE TABLE {self.quote_identifier(node._table)}", ()
570
+
571
+
572
class Dialect:
    """Bundle a database's syntax settings with the compiler class that targets it."""

    def __init__(self, placeholder: str, quote_char: str, compiler_cls: Type[Compiler]) -> None:
        """Store the placeholder marker, the quote character and the compiler class."""
        # The class (not an instance) is stored; get_compiler() instantiates it on demand.
        self._compiler_cls = compiler_cls
        self.quote_char = quote_char
        self.placeholder = placeholder

    def get_compiler(self) -> Compiler:
        """Return a new instance of the configured compiler class."""
        compiler_factory = self._compiler_cls
        return compiler_factory()
@@ -0,0 +1,50 @@
1
+ from typing import List, Tuple
2
+
3
+ from sqlpiston.builder.dml import Upsert
4
+ from sqlpiston.compiler.base import Dialect, GenericCompiler
5
+ from sqlpiston.builder.nodes import ExprValue
6
+ from sqlpiston._types import ColumnValue
7
+
8
+
9
class MySQLCompiler(GenericCompiler):
    """MySQL dialect compiler. %s placeholders, `backtick` quoting."""

    def placeholder(self) -> str:
        """Return the bind-parameter marker (``%s``)."""
        return '%s'

    def quote_identifier(self, name: str) -> str:
        """Wrap *name* in backticks, doubling any backtick it contains.

        Doubling keeps a backtick inside a name from closing the quoted
        identifier early.
        """
        escaped = name.replace('`', '``')
        return f'`{escaped}`'

    def visit_upsert(self, node: Upsert) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile an upsert as INSERT ... ON DUPLICATE KEY UPDATE / INSERT IGNORE.

        Branches, in priority order:
          1. do-nothing with conflict columns: a no-op ``col = col``
             assignment per conflict column.
          2. update data present: ``col = VALUES(col)`` per update column.
          3. do-nothing without conflict columns: ``INSERT IGNORE``.
          4. otherwise: a plain INSERT.

        Only the inserted values are bound as parameters in every branch.

        Raises:
            ValueError: if the table or the values are missing.
        """
        if node._table is None or node._data is None:
            raise ValueError("UPSERT requires table and values")

        table = self.quote_identifier(node._table)
        cols = ", ".join(self.quote_identifier(k) for k in node._data)
        placeholders = ", ".join([self.placeholder()] * len(node._data))
        target = f"{table} ({cols}) VALUES ({placeholders})"
        params = tuple(node._data.values())

        if node._do_nothing and node._conflict_columns:
            # One `c` = `c` per column. Joining the names first produced
            # "`a`, `b` = `a`, `b`", which is not valid SQL for more than
            # one conflict column.
            quoted = [self.quote_identifier(c) for c in node._conflict_columns]
            noop = ", ".join(f"{q} = {q}" for q in quoted)
            return f"INSERT INTO {target} ON DUPLICATE KEY UPDATE {noop}", params

        if node._update_data:
            # NOTE(review): only the column names of _update_data are used;
            # each is set to the value being inserted via VALUES(col), so any
            # values supplied in _update_data are not bound — confirm intended.
            update_parts: List[str] = []
            for col in node._update_data:
                quoted_col = self.quote_identifier(col)
                update_parts.append(f"{quoted_col} = VALUES({quoted_col})")
            return (
                f"INSERT INTO {target} ON DUPLICATE KEY UPDATE " + ", ".join(update_parts),
                params,
            )

        if node._do_nothing:
            # INSERT IGNORE as fallback; built directly rather than by string
            # replacement so identifier text can never be rewritten.
            return f"INSERT IGNORE INTO {target}", params

        return f"INSERT INTO {target}", params
46
+
47
+
48
class MySQLDialect(Dialect):
    """Dialect preset for MySQL: ``%s`` placeholders, backtick quoting, MySQLCompiler."""

    def __init__(self) -> None:
        """Initialise the base Dialect with the MySQL settings."""
        settings = {
            'placeholder': '%s',
            'quote_char': '`',
            'compiler_cls': MySQLCompiler,
        }
        super().__init__(**settings)
@@ -0,0 +1,51 @@
1
+ from typing import List, Tuple
2
+
3
+ from sqlpiston.builder.dml import Upsert
4
+ from sqlpiston.compiler.base import Dialect, GenericCompiler
5
+ from sqlpiston.builder.nodes import ExprValue
6
+ from sqlpiston._types import ColumnValue
7
+
8
+
9
class SQLiteCompiler(GenericCompiler):
    """SQLite dialect compiler. ? placeholders, "double-quote" quoting."""

    def placeholder(self) -> str:
        """Return the bind-parameter marker (``?``)."""
        return '?'

    def quote_identifier(self, name: str) -> str:
        """Wrap *name* in double quotes, doubling any double quote it contains.

        Doubling keeps a quote inside a name from closing the quoted
        identifier early.
        """
        escaped = name.replace('"', '""')
        return f'"{escaped}"'

    def visit_upsert(self, node: Upsert) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile an upsert as INSERT ... ON CONFLICT.

        Branches, in priority order:
          1. do-nothing with conflict columns: ``ON CONFLICT (cols) DO NOTHING``.
          2. update data with conflict columns:
             ``ON CONFLICT (cols) DO UPDATE SET col = ?, ...``; the update
             values are bound after the inserted values.
          3. do-nothing without conflict columns: ``ON CONFLICT DO NOTHING``.
          4. otherwise: a plain INSERT. Update data supplied without
             conflict columns ends up here and is not emitted.

        Raises:
            ValueError: if the table or the values are missing.
        """
        if node._table is None or node._data is None:
            raise ValueError("UPSERT requires table and values")

        table = self.quote_identifier(node._table)
        cols = ", ".join(self.quote_identifier(k) for k in node._data)
        placeholders = ", ".join([self.placeholder()] * len(node._data))
        sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders})"
        insert_params = tuple(node._data.values())

        if node._conflict_columns:
            conflict_cols = ", ".join(self.quote_identifier(c) for c in node._conflict_columns)

            if node._do_nothing:
                return f"{sql} ON CONFLICT ({conflict_cols}) DO NOTHING", insert_params

            if node._update_data:
                update_parts: List[str] = []
                update_params: List[ColumnValue] = []
                for col, val in node._update_data.items():
                    update_parts.append(f"{self.quote_identifier(col)} = {self.placeholder()}")
                    update_params.append(val)
                return (
                    f"{sql} ON CONFLICT ({conflict_cols}) DO UPDATE SET {', '.join(update_parts)}",
                    insert_params + tuple(update_params),
                )

        if node._do_nothing:
            return f"{sql} ON CONFLICT DO NOTHING", insert_params

        return sql, insert_params
47
+
48
+
49
class SQLiteDialect(Dialect):
    """Dialect preset for SQLite: ``?`` placeholders, double-quote quoting, SQLiteCompiler."""

    def __init__(self) -> None:
        """Initialise the base Dialect with the SQLite settings."""
        settings = {
            'placeholder': '?',
            'quote_char': '"',
            'compiler_cls': SQLiteCompiler,
        }
        super().__init__(**settings)
File without changes
@@ -0,0 +1,3 @@
1
+ from sqlpiston.core.engine.base import DBEngine, DBType
2
+
3
+ __all__ = ['DBEngine', 'DBType']