sqlpiston 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlpiston/__init__.py +27 -0
- sqlpiston/_types.py +5 -0
- sqlpiston/builder/__init__.py +0 -0
- sqlpiston/builder/ddl.py +228 -0
- sqlpiston/builder/dml.py +124 -0
- sqlpiston/builder/nodes.py +248 -0
- sqlpiston/builder/selectable.py +153 -0
- sqlpiston/compiler/__init__.py +0 -0
- sqlpiston/compiler/base.py +581 -0
- sqlpiston/compiler/mysql.py +50 -0
- sqlpiston/compiler/sqlite.py +51 -0
- sqlpiston/core/__init__.py +0 -0
- sqlpiston/core/engine/__init__.py +3 -0
- sqlpiston/core/engine/base.py +99 -0
- sqlpiston/core/engine/mysql.py +80 -0
- sqlpiston/core/engine/sqlite.py +76 -0
- sqlpiston/core/pool.py +49 -0
- sqlpiston/core/session.py +61 -0
- sqlpiston/orm/__init__.py +0 -0
- sqlpiston/orm/mapper.py +68 -0
- sqlpiston-0.1.0.dist-info/METADATA +180 -0
- sqlpiston-0.1.0.dist-info/RECORD +25 -0
- sqlpiston-0.1.0.dist-info/WHEEL +5 -0
- sqlpiston-0.1.0.dist-info/licenses/LICENSE +21 -0
- sqlpiston-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,581 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List, Optional, Tuple, Type, Union
|
|
3
|
+
|
|
4
|
+
from sqlpiston.builder.nodes import (
|
|
5
|
+
ASTNode, BetweenNode, CaseNode, ComparisonNode, ExistsNode, Field,
|
|
6
|
+
InNode, LogicalNode, SQLFunction, ExprValue,
|
|
7
|
+
)
|
|
8
|
+
from sqlpiston.builder.selectable import CompoundSelect, CTE, Select
|
|
9
|
+
from sqlpiston.builder.dml import Delete, Insert, Update, Upsert
|
|
10
|
+
from sqlpiston.builder.ddl import (
|
|
11
|
+
AlterAction, AlterTable, ColumnDef, CreateIndex, CreateTable,
|
|
12
|
+
CreateView, DropIndex, DropTable, DropView, Truncate,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Compiler(ABC):
    """Abstract base for dialect compilers.

    Every compile step yields a ``(sql, params)`` pair: the SQL text plus the
    values bound to its placeholders, in placeholder order. ``process()``
    routes an AST node to the matching ``visit_*`` hook; concrete subclasses
    implement every hook plus ``placeholder()`` and ``quote_identifier()``.
    The ``compile_*`` helpers are shared rendering utilities.
    """

    def process(self, node: ASTNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile *node* through the ``visit_*`` method registered for its type.

        Candidate types are tried in a fixed order with ``isinstance``; the
        first match wins.

        Raises:
            TypeError: if *node* is none of the supported node types.
        """
        routes = (
            (Select, "visit_select"),
            (CompoundSelect, "visit_compound_select"),
            (CTE, "visit_cte"),
            (Insert, "visit_insert"),
            (Update, "visit_update"),
            (Delete, "visit_delete"),
            (Upsert, "visit_upsert"),
            (CreateTable, "visit_create_table"),
            (AlterTable, "visit_alter_table"),
            (DropTable, "visit_drop_table"),
            (CreateIndex, "visit_create_index"),
            (DropIndex, "visit_drop_index"),
            (CreateView, "visit_create_view"),
            (DropView, "visit_drop_view"),
            (Truncate, "visit_truncate"),
            (ComparisonNode, "visit_comparison"),
            (InNode, "visit_in"),
            (BetweenNode, "visit_between"),
            (LogicalNode, "visit_logical"),
            (ExistsNode, "visit_exists"),
            (CaseNode, "visit_case"),
            (SQLFunction, "visit_function"),
        )
        for node_type, handler_name in routes:
            if isinstance(node, node_type):
                return getattr(self, handler_name)(node)
        raise TypeError(f"Unknown AST node type: {type(node).__name__}")

    # -- DQL --

    @abstractmethod
    def visit_select(self, node: Select) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover — abstract stub
        """Compile a SELECT statement."""

    @abstractmethod
    def visit_compound_select(self, node: CompoundSelect) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile a compound (set-operation) SELECT."""

    @abstractmethod
    def visit_cte(self, node: CTE) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile one common table expression."""

    @abstractmethod
    def visit_insert(self, node: Insert) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile an INSERT statement."""

    @abstractmethod
    def visit_update(self, node: Update) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile an UPDATE statement."""

    @abstractmethod
    def visit_delete(self, node: Delete) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile a DELETE statement."""

    @abstractmethod
    def visit_upsert(self, node: Upsert) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile an UPSERT statement."""

    @abstractmethod
    def visit_create_table(self, node: CreateTable) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile CREATE TABLE."""

    @abstractmethod
    def visit_alter_table(self, node: AlterTable) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile ALTER TABLE."""

    @abstractmethod
    def visit_drop_table(self, node: DropTable) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile DROP TABLE."""

    @abstractmethod
    def visit_create_index(self, node: CreateIndex) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile CREATE INDEX."""

    @abstractmethod
    def visit_drop_index(self, node: DropIndex) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile DROP INDEX."""

    @abstractmethod
    def visit_create_view(self, node: CreateView) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile CREATE VIEW."""

    @abstractmethod
    def visit_drop_view(self, node: DropView) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile DROP VIEW."""

    @abstractmethod
    def visit_truncate(self, node: Truncate) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile TRUNCATE."""

    @abstractmethod
    def visit_comparison(self, node: ComparisonNode) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile a binary comparison."""

    @abstractmethod
    def visit_in(self, node: InNode) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile an IN predicate."""

    @abstractmethod
    def visit_between(self, node: BetweenNode) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile a BETWEEN predicate."""

    @abstractmethod
    def visit_logical(self, node: LogicalNode) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile NOT / AND / OR."""

    @abstractmethod
    def visit_exists(self, node: ExistsNode) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile [NOT] EXISTS."""

    @abstractmethod
    def visit_case(self, node: CaseNode) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile a CASE expression."""

    @abstractmethod
    def visit_function(self, node: SQLFunction) -> Tuple[str, Tuple[ExprValue, ...]]:  # pragma: no cover
        """Compile an SQL function call."""

    @abstractmethod
    def placeholder(self) -> str:  # pragma: no cover — abstract stub
        """Return the dialect's bound-parameter marker."""

    @abstractmethod
    def quote_identifier(self, name: str) -> str:  # pragma: no cover — abstract stub
        """Return *name* quoted as an identifier for this dialect."""

    # -- Shared helpers --

    def compile_field(self, field: Field) -> str:
        """Render ``[table.]name [AS alias]`` with every identifier quoted."""
        prefix = f"{self.quote_identifier(field.table)}." if field.table else ""
        rendered = prefix + self.quote_identifier(field.name)
        alias = field._alias_prop
        return f"{rendered} AS {self.quote_identifier(alias)}" if alias else rendered

    def compile_from(self, from_src: Union[str, Select]) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Render a FROM source: a quoted table name or an aliased sub-SELECT.

        A sub-SELECT without its own ``_alias`` is aliased as ``sub``.
        """
        if not isinstance(from_src, str):
            sub_sql, sub_params = self.visit_select(from_src)
            alias_name = from_src._alias or "sub"
            return f"({sub_sql}) AS {self.quote_identifier(alias_name)}", sub_params
        return self.quote_identifier(from_src), ()

    def compile_joins(self, joins: List[Tuple[str, str, ASTNode]]) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Render ``(table, how, on)`` join specs, space-separated.

        ``CROSS`` joins take no ON clause, so their condition is not compiled.
        """
        fragments: List[str] = []
        collected: List[ExprValue] = []
        for table, how, on in joins:
            if how != 'CROSS':
                on_sql, on_params = self.process(on)
                fragments.append(f"{how} JOIN {self.quote_identifier(table)} ON {on_sql}")
                collected.extend(on_params)
                continue
            fragments.append(f"CROSS JOIN {self.quote_identifier(table)}")
        return " ".join(fragments), tuple(collected)

    def compile_condition(self, node: ASTNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile a WHERE / HAVING / ON condition (plain dispatch via ``process``)."""
        return self.process(node)

    def compile_order_by(self, orders: List[Tuple[Union[str, Field], str]]) -> str:
        """Render ``ORDER BY`` from ``(column, direction)`` pairs."""
        rendered = [
            f"{self.quote_identifier(item) if isinstance(item, str) else self.compile_field(item)} {direction}"
            for item, direction in orders
        ]
        return "ORDER BY " + ", ".join(rendered)

    def compile_group_by(self, groups: List[Union[str, Field]]) -> str:
        """Render ``GROUP BY`` from plain column names or Fields."""
        rendered = [
            self.quote_identifier(item) if isinstance(item, str) else self.compile_field(item)
            for item in groups
        ]
        return "GROUP BY " + ", ".join(rendered)

    def collect_params(self, *results: Tuple[str, Tuple[ExprValue, ...]]) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Merge ``(sql, params)`` fragments into one pair.

        Empty SQL fragments are skipped; parameters keep their input order.
        """
        fragments = [sql for sql, _ in results if sql]
        merged = [value for _, params in results if params for value in params]
        return " ".join(fragments), tuple(merged)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class GenericCompiler(Compiler):
    """Platform-agnostic compiler. Uses %s and backtick quoting. Serves as baseline.

    Implements every ``visit_*`` hook of ``Compiler``. Each returns a
    ``(sql, params)`` pair whose ``params`` line up, in order, with the
    placeholders emitted into ``sql``. Dialect compilers subclass this and
    override ``placeholder``/``quote_identifier`` plus any statement whose
    syntax differs. UPSERT has no generic form and always raises here.
    """

    # Largest signed 64-bit integer. Both SQLite and MySQL accept it as a LIMIT
    # row count, and it stands in for "no limit" when only OFFSET is requested,
    # because neither database accepts OFFSET outside a LIMIT clause.
    _NO_LIMIT = 9223372036854775807

    def placeholder(self) -> str:
        """Return the bound-parameter marker (``%s``)."""
        return '%s'

    def quote_identifier(self, name: str) -> str:
        """Quote *name* with backticks, doubling any embedded backtick."""
        # An unescaped embedded backtick would end the quoted identifier early
        # and corrupt the statement; doubling it is the standard SQL escape.
        escaped = name.replace('`', '``')
        return f'`{escaped}`'

    # -- Expression nodes --

    def visit_comparison(self, node: ComparisonNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``field <op> value``.

        The right-hand side may be another Field, a scalar sub-SELECT, a
        SQLFunction, or a literal (bound as a parameter). ``IS NULL`` and
        ``IS NOT NULL`` take no right-hand side.
        """
        field_str = self.compile_field(node.field)

        # IS NULL / IS NOT NULL — no parameter
        if node.operator in ('IS NULL', 'IS NOT NULL'):
            return f"{field_str} {node.operator}", ()

        value = node.value
        # Field vs Field — table-qualified comparison
        if isinstance(value, Field):
            return f"{field_str} {node.operator} {self.compile_field(value)}", ()

        # Field vs Select — scalar subquery
        if isinstance(value, Select):
            sub_sql, sub_params = self.visit_select(value)
            return f"{field_str} {node.operator} ({sub_sql})", sub_params

        # Field vs SQLFunction
        if isinstance(value, SQLFunction):
            func_sql, func_params = self.visit_function(value)
            return f"{field_str} {node.operator} {func_sql}", func_params

        # Field vs literal
        return f"{field_str} {node.operator} {self.placeholder()}", (value,)

    def visit_in(self, node: InNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``field IN (...)`` against a sub-SELECT or a literal list.

        An empty literal list compiles to the always-false ``1 = 0``. An empty
        list matches no rows, and ``IN ()`` is a syntax error in MySQL.
        """
        field_str = self.compile_field(node.field)
        values = node.values

        # Subquery
        if isinstance(values, Select):
            sub_sql, sub_params = self.visit_select(values)
            return f"{field_str} IN ({sub_sql})", sub_params

        # Literal list (materialised so any iterable is accepted)
        values = tuple(values)
        if not values:
            return "1 = 0", ()
        placeholders = ", ".join([self.placeholder()] * len(values))
        return f"{field_str} IN ({placeholders})", values

    def visit_between(self, node: BetweenNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``field BETWEEN ? AND ?``; both bounds are bound parameters."""
        field_str = self.compile_field(node.field)
        return (
            f"{field_str} BETWEEN {self.placeholder()} AND {self.placeholder()}",
            (node.low, node.high),
        )

    def visit_logical(self, node: LogicalNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile NOT / AND / OR.

        NOT renders ``NOT (<child>)`` from ``children[0]``. AND/OR join all
        children with the operator, and a result with more than one part is
        wrapped in parentheses.
        """
        if node.operator == 'NOT':
            child_sql, child_params = self.process(node.children[0])
            return f"NOT ({child_sql})", child_params

        # AND / OR
        parts: List[str] = []
        params: List[ExprValue] = []
        for child in node.children:
            child_sql, child_params = self.process(child)
            # Nested AND/OR children are parenthesised here. A multi-part child
            # already returns parenthesised SQL, so this yields doubled parens.
            # That is redundant but valid, and kept so generated SQL is unchanged.
            if isinstance(child, LogicalNode) and child.operator != 'NOT':
                child_sql = f"({child_sql})"
            parts.append(child_sql)
            params.extend(child_params)

        separator = f" {node.operator} "
        combined = separator.join(parts)
        if len(parts) > 1:
            combined = f"({combined})"
        return combined, tuple(params)

    def visit_exists(self, node: ExistsNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``[NOT] EXISTS (subquery)``."""
        sub_sql, sub_params = self.visit_select(node.select)
        keyword = "NOT EXISTS" if node.negated else "EXISTS"
        return f"{keyword} ({sub_sql})", sub_params

    def visit_case(self, node: CaseNode) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile a searched CASE expression.

        Each WHEN condition goes through ``process``. Every THEN result is
        bound as a parameter. The ELSE value is also bound, and its branch is
        emitted only when it is not None.
        """
        parts: List[str] = ["CASE"]
        params: List[ExprValue] = []
        for condition, result in node._whens:
            cond_sql, cond_params = self.process(condition)
            parts.append(f"WHEN {cond_sql} THEN {self.placeholder()}")
            params.extend(cond_params)
            params.append(result)
        if node._else is not None:
            parts.append(f"ELSE {self.placeholder()}")
            params.append(node._else)
        parts.append("END")
        return " ".join(parts), tuple(params)

    def visit_function(self, node: SQLFunction) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``NAME(arg, ...) [AS alias]``.

        ``"*"`` is emitted verbatim and Fields are quoted. Nested SQLFunctions
        are compiled recursively, and any other argument is bound as a
        parameter.
        """
        params: List[ExprValue] = []
        arg_parts: List[str] = []
        for arg in node.args:
            if isinstance(arg, str) and arg == "*":
                arg_parts.append("*")
            elif isinstance(arg, Field):
                arg_parts.append(self.compile_field(arg))
            elif isinstance(arg, SQLFunction):
                sub_sql, sub_params = self.visit_function(arg)
                arg_parts.append(sub_sql)
                params.extend(sub_params)
            else:
                arg_parts.append(self.placeholder())
                params.append(arg)

        func_sql = f"{node.name}({', '.join(arg_parts)})"
        if node._alias:
            func_sql = f"{func_sql} AS {self.quote_identifier(node._alias)}"
        return func_sql, tuple(params)

    # -- DQL --

    def _compile_select_column(self, col: Union[str, ASTNode]) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Render one SELECT-list entry and the parameters it binds."""
        if isinstance(col, str):
            return ("*" if col == "*" else self.quote_identifier(col)), ()
        if isinstance(col, Field):
            return self.compile_field(col), ()
        if isinstance(col, Select):
            sub_sql, sub_params = self.process(col)
            return f"({sub_sql})", sub_params
        if isinstance(col, SQLFunction):
            return self.visit_function(col)
        # Any other node (e.g. CaseNode) goes through normal dispatch. An
        # unsupported type now raises TypeError instead of silently vanishing
        # from the column list.
        return self.process(col)

    def visit_select(self, node: Select) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile a SELECT statement clause by clause.

        Clauses are emitted in SQL order (WITH, SELECT, FROM, JOIN, WHERE,
        GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET). Parameters are collected in
        the same order, so they match the placeholders.
        """
        sql_parts: List[str] = []
        all_params: List[ExprValue] = []

        # CTEs
        if node._ctes:
            cte_parts: List[str] = []
            for cte in node._ctes:
                cte_sql, cte_params = self.visit_cte(cte)
                cte_parts.append(cte_sql)
                all_params.extend(cte_params)
            sql_parts.append("WITH " + ", ".join(cte_parts))

        # SELECT [DISTINCT]
        cols: List[str] = []
        for col in node._columns:
            col_sql, col_params = self._compile_select_column(col)
            cols.append(col_sql)
            all_params.extend(col_params)
        distinct = "DISTINCT " if node._distinct else ""
        sql_parts.append(f"SELECT {distinct}{', '.join(cols)}")

        # FROM
        if node._from:
            from_sql, from_params = self.compile_from(node._from)
            sql_parts.append(f"FROM {from_sql}")
            all_params.extend(from_params)

        # JOINs
        if node._joins:
            join_sql, join_params = self.compile_joins(node._joins)
            sql_parts.append(join_sql)
            all_params.extend(join_params)

        # WHERE
        if node._where:
            where_sql, where_params = self.compile_condition(node._where)
            sql_parts.append(f"WHERE {where_sql}")
            all_params.extend(where_params)

        # GROUP BY
        if node._group_by:
            sql_parts.append(self.compile_group_by(node._group_by))

        # HAVING
        if node._having:
            having_sql, having_params = self.compile_condition(node._having)
            sql_parts.append(f"HAVING {having_sql}")
            all_params.extend(having_params)

        # ORDER BY
        if node._order_by:
            sql_parts.append(self.compile_order_by(node._order_by))

        # LIMIT. SQLite and MySQL only accept OFFSET as part of a LIMIT clause,
        # so an OFFSET-only query gets an effectively unbounded LIMIT.
        limit = node._limit
        if limit is None and node._offset is not None:
            limit = self._NO_LIMIT
        if limit is not None:
            sql_parts.append(f"LIMIT {self.placeholder()}")
            all_params.append(limit)

        # OFFSET
        if node._offset is not None:
            sql_parts.append(f"OFFSET {self.placeholder()}")
            all_params.append(node._offset)

        return " ".join(sql_parts), tuple(all_params)

    def visit_compound_select(self, node: CompoundSelect) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``(left) <op> (right)``.

        Both sides go through ``process``, so a side may itself be a compound
        select. The parentheses preserve the nesting.
        """
        left_sql, left_params = self.process(node.left)
        right_sql, right_params = self.process(node.right)
        return (
            f"({left_sql}) {node.operator} ({right_sql})",
            left_params + right_params,
        )

    def visit_cte(self, node: CTE) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile one CTE as ``name AS (select)``."""
        sub_sql, sub_params = self.visit_select(node.select)
        return f"{self.quote_identifier(node.name)} AS ({sub_sql})", sub_params

    # -- DML --

    def visit_insert(self, node: Insert) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``INSERT INTO t SELECT ...`` or ``INSERT INTO t (cols) VALUES (...)``.

        A ``_select`` source takes precedence over ``_data``.

        Raises:
            ValueError: if the table is missing, or neither a SELECT nor data is given.
        """
        if node._table is None:
            raise ValueError("INSERT requires a table name")

        table = self.quote_identifier(node._table)

        if node._select is not None:
            sub_sql, sub_params = self.visit_select(node._select)
            return f"INSERT INTO {table} {sub_sql}", sub_params

        if node._data is None:
            raise ValueError("INSERT requires values() or select()")

        cols = ", ".join(self.quote_identifier(k) for k in node._data.keys())
        placeholders = ", ".join([self.placeholder()] * len(node._data))
        return (
            f"INSERT INTO {table} ({cols}) VALUES ({placeholders})",
            tuple(node._data.values()),
        )

    def visit_update(self, node: Update) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``UPDATE t SET col = ?, ... [WHERE ...]``.

        Raises:
            ValueError: if the table is missing or there is nothing to SET.
        """
        if node._table is None:
            raise ValueError("UPDATE requires a table name")
        if not node._data:
            # None and an empty mapping both fail: "UPDATE t SET" is not valid SQL.
            raise ValueError("UPDATE requires set() data")

        table = self.quote_identifier(node._table)
        set_parts: List[str] = []
        params: List[ExprValue] = []
        for col, val in node._data.items():
            set_parts.append(f"{self.quote_identifier(col)} = {self.placeholder()}")
            params.append(val)

        sql = f"UPDATE {table} SET {', '.join(set_parts)}"

        if node._where:
            where_sql, where_params = self.process(node._where)
            sql += f" WHERE {where_sql}"
            params.extend(where_params)

        return sql, tuple(params)

    def visit_delete(self, node: Delete) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``DELETE FROM t [WHERE ...]``; without a WHERE, every row is deleted."""
        if node._table is None:
            raise ValueError("DELETE requires a table name")

        table = self.quote_identifier(node._table)
        sql = f"DELETE FROM {table}"
        params: List[ExprValue] = []

        if node._where:
            where_sql, where_params = self.process(node._where)
            sql += f" WHERE {where_sql}"
            params.extend(where_params)

        return sql, tuple(params)

    def visit_upsert(self, node: Upsert) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Always raise: UPSERT syntax is dialect-specific."""
        raise NotImplementedError("UPSERT must be compiled by a dialect-specific compiler")

    # -- DDL --

    def visit_create_table(self, node: CreateTable) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``CREATE TABLE [IF NOT EXISTS] t (col defs...)``.

        Column defaults are emitted as placeholders and returned as parameters.
        """
        if node._table is None:
            raise ValueError("CREATE TABLE requires a table name")

        table = self.quote_identifier(node._table)
        if_not_exists = "IF NOT EXISTS " if node._if_not_exists else ""

        col_defs: List[str] = []
        params: List[ExprValue] = []
        for col in node._columns:
            col_defs.append(self._compile_column_def(col, include_primary_key=True))
            if col.default is not None:
                params.append(col.default)

        return (
            f"CREATE TABLE {if_not_exists}{table} ({', '.join(col_defs)})",
            tuple(params),
        )

    def visit_alter_table(self, node: AlterTable) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ALTER TABLE actions (ADD / DROP / MODIFY COLUMN).

        Each action becomes its own ``ALTER TABLE`` statement, and the
        statements are joined with ``"; "``.
        """
        # NOTE(review): several actions yield several statements in one string,
        # and column DEFAULTs are bound parameters inside DDL. Confirm the
        # executing engine/driver accepts both.
        if node._table is None:
            raise ValueError("ALTER TABLE requires a table name")

        table = self.quote_identifier(node._table)
        sql_parts: List[str] = []
        all_params: List[ExprValue] = []

        for action, col_name, col_type, col_def in node._actions:
            if action == AlterAction.ADD:
                sql_parts.append(f"ALTER TABLE {table} ADD COLUMN {self._compile_column_def(col_def)}")
                if col_def is not None and col_def.default is not None:
                    all_params.append(col_def.default)
            elif action == AlterAction.DROP:
                sql_parts.append(f"ALTER TABLE {table} DROP COLUMN {self.quote_identifier(col_name)}")
            elif action == AlterAction.MODIFY:
                sql_parts.append(f"ALTER TABLE {table} MODIFY COLUMN {self._compile_column_def(col_def)}")
                if col_def is not None and col_def.default is not None:
                    all_params.append(col_def.default)

        return "; ".join(sql_parts), tuple(all_params)

    def _compile_column_def(self, col_def: Optional[ColumnDef], include_primary_key: bool = False) -> str:
        """Render ``name TYPE [NOT NULL] [DEFAULT ?] [PRIMARY KEY] [UNIQUE]``.

        Only the DEFAULT placeholder is emitted; callers append
        ``col_def.default`` to their own parameter list. ``PRIMARY KEY`` is
        rendered only when *include_primary_key* is true (CREATE TABLE); ALTER
        TABLE column definitions omit it.
        """
        if col_def is None:  # pragma: no cover — defensive guard
            return ""
        parts = [self.quote_identifier(col_def.name), col_def.type_]
        if not col_def.nullable:
            parts.append("NOT NULL")
        if col_def.default is not None:
            parts.append(f"DEFAULT {self.placeholder()}")
        if include_primary_key and col_def.primary_key:
            parts.append("PRIMARY KEY")
        if col_def.unique:
            parts.append("UNIQUE")
        return " ".join(parts)

    def visit_drop_table(self, node: DropTable) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``DROP TABLE [IF EXISTS] t``."""
        if node._table is None:
            raise ValueError("DROP TABLE requires a table name")
        if_exists = "IF EXISTS " if node._if_exists else ""
        return f"DROP TABLE {if_exists}{self.quote_identifier(node._table)}", ()

    def visit_create_index(self, node: CreateIndex) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``CREATE [UNIQUE] INDEX [IF NOT EXISTS] name ON t (cols)``."""
        if node._name is None or node._table is None:
            raise ValueError("CREATE INDEX requires name and table")
        unique = "UNIQUE " if node._unique else ""
        if_not_exists = "IF NOT EXISTS " if node._if_not_exists else ""
        cols = ", ".join(self.quote_identifier(c) for c in node._columns)
        return (
            f"CREATE {unique}INDEX {if_not_exists}{self.quote_identifier(node._name)} "
            f"ON {self.quote_identifier(node._table)} ({cols})",
            (),
        )

    def visit_drop_index(self, node: DropIndex) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``DROP INDEX [IF EXISTS] name [ON t]``; ON is added only when a table is set."""
        if node._name is None:
            raise ValueError("DROP INDEX requires an index name")
        if_exists = "IF EXISTS " if node._if_exists else ""
        sql = f"DROP INDEX {if_exists}{self.quote_identifier(node._name)}"
        if node._table:
            sql += f" ON {self.quote_identifier(node._table)}"
        return sql, ()

    def visit_create_view(self, node: CreateView) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``CREATE VIEW [IF NOT EXISTS] name AS <select>``."""
        if node._name is None or node._select is None:
            raise ValueError("CREATE VIEW requires name and AS SELECT")
        if_not_exists = "IF NOT EXISTS " if node._if_not_exists else ""
        sub_sql, sub_params = self.visit_select(node._select)
        return (
            f"CREATE VIEW {if_not_exists}{self.quote_identifier(node._name)} AS {sub_sql}",
            sub_params,
        )

    def visit_drop_view(self, node: DropView) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``DROP VIEW [IF EXISTS] name``."""
        if node._name is None:
            raise ValueError("DROP VIEW requires a view name")
        if_exists = "IF EXISTS " if node._if_exists else ""
        return f"DROP VIEW {if_exists}{self.quote_identifier(node._name)}", ()

    def visit_truncate(self, node: Truncate) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``TRUNCATE TABLE t``."""
        if node._table is None:
            raise ValueError("TRUNCATE requires a table name")
        return f"TRUNCATE TABLE {self.quote_identifier(node._table)}", ()
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
class Dialect:
    """Describe one database: its placeholder token, quote character, and compiler class.

    ``placeholder`` and ``quote_char`` are exposed as public attributes;
    ``get_compiler()`` builds a compiler from the stored class.
    """

    def __init__(self, placeholder: str, quote_char: str, compiler_cls: Type[Compiler]) -> None:
        self.placeholder, self.quote_char = placeholder, quote_char
        self._compiler_cls = compiler_cls

    def get_compiler(self) -> Compiler:
        """Instantiate and return a new compiler for this dialect (no arguments)."""
        factory = self._compiler_cls
        return factory()
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from typing import List, Tuple
|
|
2
|
+
|
|
3
|
+
from sqlpiston.builder.dml import Upsert
|
|
4
|
+
from sqlpiston.compiler.base import Dialect, GenericCompiler
|
|
5
|
+
from sqlpiston.builder.nodes import ExprValue
|
|
6
|
+
from sqlpiston._types import ColumnValue
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MySQLCompiler(GenericCompiler):
    """MySQL dialect compiler. %s placeholders, `backtick` quoting."""

    def placeholder(self) -> str:
        """Return the MySQL driver parameter marker (``%s``)."""
        return '%s'

    def quote_identifier(self, name: str) -> str:
        """Quote *name* with backticks, doubling any embedded backtick."""
        escaped = name.replace('`', '``')
        return f'`{escaped}`'

    def visit_upsert(self, node: Upsert) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile an UPSERT using MySQL's ``ON DUPLICATE KEY UPDATE`` / ``INSERT IGNORE``.

        Branches, in order of precedence:

        * do-nothing with conflict columns: ``ON DUPLICATE KEY UPDATE c = c``
          for each conflict column (a no-op assignment).
        * update data: ``ON DUPLICATE KEY UPDATE col = %s, ...``, binding the
          ``_update_data`` values after the inserted values.
        * do-nothing without conflict columns: ``INSERT IGNORE``.
        * otherwise a plain INSERT.

        Raises:
            ValueError: if the table or the inserted values are missing.
        """
        if node._table is None or node._data is None:
            raise ValueError("UPSERT requires table and values")

        table = self.quote_identifier(node._table)
        cols = ", ".join(self.quote_identifier(k) for k in node._data.keys())
        placeholders = ", ".join([self.placeholder()] * len(node._data))
        target_and_values = f"{table} ({cols}) VALUES ({placeholders})"
        params: List[ColumnValue] = list(node._data.values())

        if node._do_nothing and node._conflict_columns:
            # One "c = c" pair per column. Joining the names first would render
            # "`a`, `b` = `a`, `b`", which is invalid for more than one column.
            quoted = [self.quote_identifier(c) for c in node._conflict_columns]
            noop = ", ".join(f"{q} = {q}" for q in quoted)
            return f"INSERT INTO {target_and_values} ON DUPLICATE KEY UPDATE {noop}", tuple(params)

        if node._update_data:
            # Bind the requested update values (matching the SQLite compiler)
            # rather than echoing the inserted row via the deprecated VALUES().
            update_parts: List[str] = []
            for col, val in node._update_data.items():
                update_parts.append(f"{self.quote_identifier(col)} = {self.placeholder()}")
                params.append(val)
            sql = f"INSERT INTO {target_and_values} ON DUPLICATE KEY UPDATE " + ", ".join(update_parts)
            return sql, tuple(params)

        if node._do_nothing:
            # INSERT IGNORE as fallback when no conflict columns are given.
            return f"INSERT IGNORE INTO {target_and_values}", tuple(params)

        return f"INSERT INTO {target_and_values}", tuple(params)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class MySQLDialect(Dialect):
    """MySQL dialect: ``%s`` placeholders, backtick quoting, ``MySQLCompiler``."""

    def __init__(self) -> None:
        super().__init__(
            compiler_cls=MySQLCompiler,
            quote_char='`',
            placeholder='%s',
        )
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from typing import List, Tuple

from sqlpiston._types import ColumnValue
from sqlpiston.builder.ddl import DropIndex, Truncate
from sqlpiston.builder.dml import Upsert
from sqlpiston.builder.nodes import ExprValue
from sqlpiston.builder.selectable import CompoundSelect
from sqlpiston.compiler.base import Dialect, GenericCompiler
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SQLiteCompiler(GenericCompiler):
    """SQLite dialect compiler. ? placeholders, "double-quote" quoting.

    Overrides the generic forms that SQLite cannot parse: parenthesised
    compound-select members, TRUNCATE, and ``DROP INDEX ... ON``.
    """

    def placeholder(self) -> str:
        """Return the sqlite3 qmark parameter marker (``?``)."""
        return '?'

    def quote_identifier(self, name: str) -> str:
        """Quote *name* with double quotes, doubling any embedded double quote."""
        escaped = name.replace('"', '""')
        return f'"{escaped}"'

    def visit_upsert(self, node: Upsert) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile an UPSERT as ``INSERT ... ON CONFLICT [(cols)] DO ...``.

        ``_do_nothing`` takes precedence over ``_update_data``. The conflict
        target ``(cols)`` is emitted only when ``_conflict_columns`` is set. A
        target-less ``ON CONFLICT DO UPDATE`` requires SQLite 3.35.0 or later.
        Parameters are the inserted values followed by any DO UPDATE SET
        values.

        Raises:
            ValueError: if the table or the inserted values are missing.
        """
        if node._table is None or node._data is None:
            raise ValueError("UPSERT requires table and values")

        table = self.quote_identifier(node._table)
        cols = ", ".join(self.quote_identifier(k) for k in node._data.keys())
        placeholders = ", ".join([self.placeholder()] * len(node._data))
        sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders})"
        params: List[ColumnValue] = list(node._data.values())

        target = ""
        if node._conflict_columns:
            conflict_cols = ", ".join(self.quote_identifier(c) for c in node._conflict_columns)
            target = f" ({conflict_cols})"

        if node._do_nothing:
            return f"{sql} ON CONFLICT{target} DO NOTHING", tuple(params)

        if node._update_data:
            # Previously update data without conflict columns was silently
            # dropped, turning the upsert into a plain INSERT.
            update_parts: List[str] = []
            for col, val in node._update_data.items():
                update_parts.append(f"{self.quote_identifier(col)} = {self.placeholder()}")
                params.append(val)
            return (
                f"{sql} ON CONFLICT{target} DO UPDATE SET {', '.join(update_parts)}",
                tuple(params),
            )

        return sql, tuple(params)

    def visit_compound_select(self, node: CompoundSelect) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``left <op> right`` without parenthesising the members.

        SQLite's grammar does not accept a parenthesised SELECT as a compound
        member. Compound members group left to right, so a compound on the left
        may be flattened safely. The right side is compiled as a plain SELECT.
        """
        left_sql, left_params = self.process(node.left)
        right_sql, right_params = self.visit_select(node.right)
        return f"{left_sql} {node.operator} {right_sql}", left_params + right_params

    def visit_drop_index(self, node: DropIndex) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile ``DROP INDEX [IF EXISTS] name``.

        SQLite's DROP INDEX has no ``ON table`` clause, so ``_table`` is ignored.
        """
        if node._name is None:
            raise ValueError("DROP INDEX requires an index name")
        if_exists = "IF EXISTS " if node._if_exists else ""
        return f"DROP INDEX {if_exists}{self.quote_identifier(node._name)}", ()

    def visit_truncate(self, node: Truncate) -> Tuple[str, Tuple[ExprValue, ...]]:
        """Compile TRUNCATE as ``DELETE FROM t``, because SQLite has no TRUNCATE statement.

        A DELETE with no WHERE clause lets SQLite apply its truncate optimization.
        """
        if node._table is None:
            raise ValueError("TRUNCATE requires a table name")
        return f"DELETE FROM {self.quote_identifier(node._table)}", ()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SQLiteDialect(Dialect):
    """SQLite dialect: ``?`` placeholders, double-quote quoting, ``SQLiteCompiler``."""

    def __init__(self) -> None:
        super().__init__(
            compiler_cls=SQLiteCompiler,
            quote_char='"',
            placeholder='?',
        )
|
|
File without changes
|