dbapi-mongodb 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1004 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Sequence, Mapping
6
+ from dataclasses import replace
7
+
8
+ from sqlglot import parse_one, exp
9
+
10
+ from .errors import raise_error
11
+
12
+
13
# Positional placeholder as written by callers (``%s``).
PLACEHOLDER_PATTERN = re.compile(r"%s")
# Named placeholder as written by callers (``%(name)s``); the name is captured in group "name".
NAMED_PLACEHOLDER_PATTERN = re.compile(r"%\((?P<name>[^)]+)\)s")
# Identifier-shaped tokens substituted for placeholders before the SQL is handed to the parser.
PARAM_TOKEN_TEMPLATE = "__param_{index}__"
PARAM_NAMED_TEMPLATE = "__param_{name}__"

# CREATE [UNIQUE] INDEX <name> ON <collection> (<columns>)
# Groups: 1 = optional "unique", 2 = index name, 3 = collection, 4 = raw column list.
CREATE_INDEX_RE = re.compile(
    r"^create\s+(unique\s+)?index\s+([A-Za-z_][\w-]*)\s+on\s+([A-Za-z_][\w-]*)\s*\(([^)]+)\)",
    flags=re.IGNORECASE,
)
# DROP INDEX <name> ON <collection>
# Groups: 1 = index name, 2 = collection.
DROP_INDEX_RE = re.compile(
    r"^drop\s+index\s+([A-Za-z_][\w-]*)\s+on\s+([A-Za-z_][\w-]*)",
    flags=re.IGNORECASE,
)
25
+
26
+
27
@dataclass
class QueryParts:
    """Parsed SQL statement expressed as the pieces of a MongoDB operation."""

    # Operation kind; values built in this module include "find", "aggregate",
    # "from_subquery" and "union_all".
    operation: str
    # Target collection name ("" when no single collection applies, e.g. union_all).
    collection: str
    # Mongo filter document derived from WHERE.
    filter: dict[str, Any] | None = None
    # Output column names (set to the aggregate aliases for from_subquery aggregates).
    projection: list[str] | None = None
    # (source path, output name) pairs for the selected columns.
    projection_paths: list[tuple[str, str]] | None = None
    # (field, direction) pairs; direction is 1 for ascending and -1 for descending.
    sort: list[tuple[str, int]] | None = None
    limit: int | None = None
    skip: int | None = None
    # NOTE(review): presumably the document for INSERT and the update document for
    # UPDATE; the code that fills these two fields is outside this view.
    values: dict[str, Any] | None = None
    update: dict[str, Any] | None = None
    # Aggregation pipeline stages for "aggregate" operations.
    pipeline: list[dict[str, Any]] | None = None
    # NOTE(review): index_keys / index_name / unique look like CREATE/DROP INDEX
    # details; confirm against the index parsers.
    index_keys: list[tuple[str, int]] | None = None
    index_name: str | None = None
    unique: bool = False
    # Parsed branches of a UNION ALL.
    union_parts: list[QueryParts] | None = None
    # Subquery token -> {"parts": QueryParts, "mode": str}.
    subqueries: dict[str, dict[str, Any]] | None = None
    # Token of the subquery used as the FROM source, when there is one.
    inline_token: str | None = None
    # NOTE(review): not populated in the visible code; presumably rows produced by
    # the inline subquery.
    inline_rows: list[dict[str, Any]] | None = None
    # (alias, operation, field or None) for aggregates computed over an inline subquery.
    inline_aggregates: list[tuple[str, str, str | None]] | None = None
    # True when the pipeline was built for a window function.
    uses_window: bool = False
51
+
52
+
53
def preprocess_sql(sql: str, params: Sequence[Any] | Mapping[str, Any] | None) -> tuple[str, list[Any], list[str]]:
    """Replace placeholders with parameter tokens and validate the bound parameters.

    Named placeholders (``%(name)s``) take precedence: when at least one is present,
    ``params`` must be a mapping, every referenced name must exist in it, and every
    key must be referenced at least once. The same name may appear several times.
    Otherwise each ``%s`` is matched positionally and the number of placeholders
    must equal the number of parameters.

    Returns:
        ``(new_sql, values, tokens)`` where ``new_sql`` has each placeholder replaced
        by its token, and ``values[i]`` is the value bound to ``tokens[i]`` (one entry
        per placeholder occurrence, in order of appearance).

    Raises:
        Whatever ``raise_error("[mdb][E4]")`` raises when the parameters do not match
        the placeholders.
    """
    values: list[Any] = []
    tokens: list[str] = []

    named_matches = list(NAMED_PLACEHOLDER_PATTERN.finditer(sql))
    if named_matches:
        if not isinstance(params, Mapping):
            raise_error("[mdb][E4]")
        referenced: set[str] = set()
        for match in named_matches:
            name = match.group("name")
            if name not in params:
                raise_error("[mdb][E4]")
            values.append(params[name])
            tokens.append(PARAM_NAMED_TEMPLATE.format(name=name))
            referenced.add(name)
        # Compare distinct names, not occurrences, so a name reused in the SQL is
        # accepted while an unused key in ``params`` is still rejected.
        if len(referenced) != len(params):
            raise_error("[mdb][E4]")
        pending = iter(tokens)
        new_sql = NAMED_PLACEHOLDER_PATTERN.sub(lambda _m: next(pending), sql)
        return new_sql, values, tokens

    # Iterating a mapping yields its keys, which must never be bound as values.
    if isinstance(params, Mapping) and params:
        raise_error("[mdb][E4]")
    values = list(params or [])
    count = len(PLACEHOLDER_PATTERN.findall(sql))
    if count != len(values):
        raise_error("[mdb][E4]")
    tokens = [PARAM_TOKEN_TEMPLATE.format(index=idx) for idx in range(count)]
    pending = iter(tokens)
    new_sql = PLACEHOLDER_PATTERN.sub(lambda _m: next(pending), sql)
    return new_sql, values, tokens
86
+
87
+
88
def _register_subquery(
    sub_expr: exp.Expression, params_map: dict[str, Any], parent_subqueries: dict[str, dict[str, Any]], mode: str
) -> str:
    """Parse a subquery, store it in ``parent_subqueries`` and return its token.

    The subquery's own nested subqueries are collected in a separate dict and
    attached to the parsed parts, so each scope keeps its own token namespace.
    The stored entry is ``{"parts": <QueryParts>, "mode": mode}``.
    """
    select_node = sub_expr.this if isinstance(sub_expr, exp.Subquery) else sub_expr
    if not isinstance(select_node, exp.Select):
        # One more level of wrapping is tolerated when it directly holds a SELECT.
        wrapped = getattr(select_node, "this", None)
        if isinstance(wrapped, exp.Select):
            select_node = wrapped
    if not isinstance(select_node, exp.Select):
        raise_error("[mdb][E2]", "Unsupported SQL construct: SUBQUERY")

    nested: dict[str, dict[str, Any]] = {}
    parsed = _parse_select_like(select_node, params_map, nested)
    parsed.subqueries = nested or None

    token = f"__subquery_{len(parent_subqueries)}__"
    parent_subqueries[token] = {"parts": parsed, "mode": mode}
    return token
106
+
107
+
108
def _literal_value(
    node: exp.Expression, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]
) -> Any:
    """Turn a value-position expression into a Python value.

    * string literal -> its text
    * other literal -> first successful conversion, else the raw text
    * column whose dotted name is a parameter token -> the bound parameter value
    * subquery / SELECT -> a subquery token registered with mode "values"
    * tuple -> list of converted elements

    Anything else is reported through ``raise_error("[mdb][E2]")``.
    """
    if isinstance(node, exp.Literal):
        if node.is_string:
            return node.this
        # Try the converters in order; the first one that does not raise wins.
        converters = (
            lambda: node.to_python(),
            lambda: int(node.this),
            lambda: float(node.this),
        )
        for convert in converters:
            try:
                return convert()
            except Exception:
                continue
        return node.this

    if isinstance(node, exp.Column):
        dotted = ".".join(part.name for part in node.parts if hasattr(part, "name"))
        if dotted in params_map:
            return params_map[dotted]
        raise_error("[mdb][E2]", "Unsupported SQL construct: COLUMN_AS_VALUE")

    if isinstance(node, (exp.Subquery, exp.Select)):
        return _register_subquery(node, params_map, subqueries, mode="values")

    if isinstance(node, exp.Tuple):
        return [_literal_value(element, params_map, subqueries) for element in node.expressions]

    raise_error("[mdb][E2]")
135
+
136
+
137
def _field_name(node: exp.Expression, params_map: dict[str, Any]) -> str:
    """Extract the Mongo field name from a column-like expression.

    A table-qualified column yields only its column name, so the table prefix is
    not carried into the Mongo field. An unqualified column yields its parts
    joined with ".". Identifiers and string literals yield their text.

    Raises (via ``raise_error``):
        E2 / PARAM_AS_FIELD when the column is actually a bound-parameter token,
        and plain E2 for any other expression type.
    """
    if isinstance(node, exp.Column):
        # This check must run before the generic column handling; placed after it,
        # the check could never be reached and a parameter token would be returned
        # as if it were a field name.
        if node.name in params_map:
            raise_error("[mdb][E2]", "Unsupported SQL construct: PARAM_AS_FIELD")
        if node.table:
            return node.name
        return ".".join(part.name for part in node.parts if hasattr(part, "name"))
    if isinstance(node, exp.Identifier):
        return node.name
    if isinstance(node, exp.Literal) and node.is_string:
        return node.this
    raise_error("[mdb][E2]")
151
+
152
+
153
def _column_table_field(node: exp.Expression) -> tuple[str | None, str]:
    """Return ``(table, field)`` for a column; any other node is reported as E2."""
    if isinstance(node, exp.Column):
        return node.table, node.name
    raise_error("[mdb][E2]")
160
+
161
+
162
def _field_with_alias(node: exp.Expression, alias_map: dict[str, str]) -> str:
    """Return the column's field name prefixed according to ``alias_map``.

    ``alias_map`` maps a table name or alias to the path prefix to prepend. An
    unqualified column uses the prefix stored under the empty key, or no prefix
    when that key is absent. A column qualified by an unknown table, or a node
    that is not a column, is reported as E2.
    """
    if isinstance(node, exp.Column):
        table = node.table or ""
        column = node.name
        if not table:
            return f"{alias_map.get('', '')}{column}"
        if table in alias_map:
            return f"{alias_map[table]}{column}"
    raise_error("[mdb][E2]")
174
+
175
+
176
def _like_to_regex(pattern: str) -> str:
    """Convert a SQL LIKE pattern into an anchored regular expression string.

    ``%`` becomes ``.*`` and ``_`` becomes ``.``. A backslash makes the following
    character literal; a trailing backslash is itself treated as a literal. Every
    other character is passed through ``re.escape``. The result is wrapped in
    ``^...$``.
    """
    # Collect the pieces and join once, rather than growing a string with "+=".
    pieces: list[str] = []
    size = len(pattern)
    i = 0
    while i < size:
        ch = pattern[i]
        if ch == "%":
            pieces.append(".*")
        elif ch == "_":
            pieces.append(".")
        elif ch == "\\" and i + 1 < size:
            # Escaped character: consume the backslash and emit the next char literally.
            i += 1
            pieces.append(re.escape(pattern[i]))
        else:
            pieces.append(re.escape(ch))
        i += 1
    return f"^{''.join(pieces)}$"
193
+
194
+
195
def _condition_to_filter(
    node: exp.Expression, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]
) -> Dict[str, Any]:
    """Convert a WHERE expression tree into a Mongo filter document.

    Handles AND / OR, BETWEEN, LIKE / ILIKE, regex matches, IN (value list or
    subquery), EXISTS, parentheses and the six comparison operators. Any other
    node is reported through ``raise_error("[mdb][E2]")``.
    """

    def _convert(child: exp.Expression) -> Dict[str, Any]:
        return _condition_to_filter(child, params_map, subqueries)

    def _value(child: Any) -> Any:
        return _literal_value(child, params_map, subqueries)

    def _strip_slashes(val: Any) -> str:
        # Accept /pattern/ notation by dropping one pair of enclosing slashes.
        text = str(val)
        if len(text) >= 2 and text.startswith("/") and text.endswith("/"):
            return text[1:-1]
        return text

    for connector, mongo_op in ((exp.And, "$and"), (exp.Or, "$or")):
        if isinstance(node, connector):
            # Use the n-ary operand list when present, else the binary operands.
            operands = node.expressions or [node.this, node.expression]
            return {mongo_op: [_convert(operand) for operand in operands]}

    if isinstance(node, exp.Between):
        field = _field_name(node.this, params_map)
        lower = _value(node.args["low"])
        upper = _value(node.args["high"])
        return {field: {"$gte": lower, "$lte": upper}}

    if isinstance(node, exp.Like):
        field = _field_name(node.this, params_map)
        like_pattern = _value(node.expression)
        if not isinstance(like_pattern, str):
            raise_error("[mdb][E2]", "Unsupported SQL construct: LIKE")
        return {field: {"$regex": _like_to_regex(like_pattern)}}

    ilike_cls = getattr(exp, "ILike", None)
    if ilike_cls is not None and isinstance(node, ilike_cls):
        field = _field_name(node.this, params_map)
        like_pattern = _value(node.expression)
        return {field: {"$regex": _like_to_regex(str(like_pattern)), "$options": "i"}}

    # Both regex node types (where the installed sqlglot defines them) map the same way.
    for regex_name in ("Regex", "RegexpLike"):
        regex_cls = getattr(exp, regex_name, None)
        if regex_cls is not None and isinstance(node, regex_cls):
            field = _field_name(node.this, params_map)
            raw_pattern = _strip_slashes(_value(node.expression))
            return {field: {"$regex": str(raw_pattern)}}

    if isinstance(node, exp.In):
        field = _field_name(node.this, params_map)
        operand = node.expression or node.args.get("query") or node.args.get("expressions")
        holds_select = isinstance(operand, (exp.Subquery, exp.Select)) or (
            isinstance(operand, exp.Expression) and operand.find(exp.Select)
        )
        if holds_select:
            members = _register_subquery(operand, params_map, subqueries, mode="values")
        elif isinstance(operand, list):
            members = [_value(item) for item in operand]
        else:
            members = _value(operand)
        return {field: {"$in": members}}

    if isinstance(node, exp.Exists):
        token = _register_subquery(node.this, params_map, subqueries, mode="exists")
        return {"$expr": {"$literal": token}}

    if isinstance(node, exp.Paren):
        return _convert(node.this)

    # Comparison node type -> Mongo operator; None means plain equality.
    comparisons = (
        (exp.EQ, None),
        (exp.NEQ, "$ne"),
        (exp.GT, "$gt"),
        (exp.GTE, "$gte"),
        (exp.LT, "$lt"),
        (exp.LTE, "$lte"),
    )
    if isinstance(node, tuple(cls for cls, _ in comparisons)):
        field = _field_name(node.left, params_map)
        operand_value = _value(node.right)
        for cls, mongo_op in comparisons:
            if isinstance(node, cls):
                if mongo_op is None:
                    return {field: operand_value}
                return {field: {mongo_op: operand_value}}

    raise_error("[mdb][E2]")
288
+
289
+
290
def _condition_to_filter_join(
    node: exp.Expression, params_map: dict[str, Any], allowed_table: str, subqueries: dict[str, dict[str, Any]]
) -> Dict[str, Any]:
    """Convert a JOIN query's WHERE into a Mongo filter, allowing one table only.

    Columns may be unqualified or qualified with ``allowed_table``; a column
    qualified with any other table is reported as E2 / JOIN_WHERE_RIGHT_TABLE.
    Supports AND / OR, IN, the six comparison operators, parentheses, BETWEEN,
    LIKE and ILIKE. Any other node is reported as E2.
    """

    def _convert(child: exp.Expression) -> Dict[str, Any]:
        return _condition_to_filter_join(child, params_map, allowed_table, subqueries)

    def _value(child: Any) -> Any:
        return _literal_value(child, params_map, subqueries)

    def _own_field(column: exp.Expression) -> str:
        # Validate the table qualifier before any value is evaluated.
        tbl, field_name = _column_table_field(column)
        if tbl and tbl != allowed_table:
            raise_error("[mdb][E2]", "Unsupported SQL construct: JOIN_WHERE_RIGHT_TABLE")
        return field_name

    for connector, mongo_op in ((exp.And, "$and"), (exp.Or, "$or")):
        if isinstance(node, connector):
            operands = node.expressions or [node.this, node.expression]
            return {mongo_op: [_convert(operand) for operand in operands]}

    if isinstance(node, exp.In):
        field = _own_field(node.this)
        # Read the operand the same way _condition_to_filter does: it may sit under
        # "expression", "query" (subquery) or "expressions" (value list). Reading
        # only "expression" left lists and subqueries unreachable.
        operand = node.expression or node.args.get("query") or node.args.get("expressions")
        if isinstance(operand, list):
            members = [_value(item) for item in operand]
        else:
            members = _value(operand)
        return {field: {"$in": members}}

    comparisons = (
        (exp.EQ, None),
        (exp.NEQ, "$ne"),
        (exp.GT, "$gt"),
        (exp.GTE, "$gte"),
        (exp.LT, "$lt"),
        (exp.LTE, "$lte"),
    )
    if isinstance(node, tuple(cls for cls, _ in comparisons)):
        field = _own_field(node.left)
        operand_value = _value(node.right)
        for cls, mongo_op in comparisons:
            if isinstance(node, cls):
                if mongo_op is None:
                    return {field: operand_value}
                return {field: {mongo_op: operand_value}}

    if isinstance(node, exp.Paren):
        return _convert(node.this)

    if isinstance(node, exp.Between):
        field = _own_field(node.this)
        lower = _value(node.args["low"])
        upper = _value(node.args["high"])
        return {field: {"$gte": lower, "$lte": upper}}

    if isinstance(node, exp.Like):
        field = _own_field(node.this)
        return {field: {"$regex": _like_to_regex(str(_value(node.expression)))}}

    ilike_cls = getattr(exp, "ILike", None)
    if ilike_cls is not None and isinstance(node, ilike_cls):
        field = _own_field(node.this)
        return {field: {"$regex": _like_to_regex(str(_value(node.expression))), "$options": "i"}}

    raise_error("[mdb][E2]")
361
+
362
+
363
def _condition_to_filter_alias(
    node: exp.Expression, params_map: dict[str, Any], alias_map: dict[str, str], subqueries: dict[str, dict[str, Any]]
) -> Dict[str, Any]:
    """Convert a WHERE expression into a Mongo filter using alias-prefixed fields.

    Each column is resolved through ``_field_with_alias`` so that its path carries
    the prefix ``alias_map`` assigns to its table. Supports AND / OR, BETWEEN,
    LIKE / ILIKE, IN, parentheses and the six comparison operators; any other
    node is reported as E2.
    """

    def _convert(child: exp.Expression) -> Dict[str, Any]:
        return _condition_to_filter_alias(child, params_map, alias_map, subqueries)

    def _value(child: Any) -> Any:
        return _literal_value(child, params_map, subqueries)

    for connector, mongo_op in ((exp.And, "$and"), (exp.Or, "$or")):
        if isinstance(node, connector):
            operands = node.expressions or [node.this, node.expression]
            return {mongo_op: [_convert(operand) for operand in operands]}

    if isinstance(node, exp.Between):
        field = _field_with_alias(node.this, alias_map)
        lower = _value(node.args["low"])
        upper = _value(node.args["high"])
        return {field: {"$gte": lower, "$lte": upper}}

    if isinstance(node, exp.Like):
        field = _field_with_alias(node.this, alias_map)
        return {field: {"$regex": _like_to_regex(str(_value(node.expression)))}}

    ilike_cls = getattr(exp, "ILike", None)
    if ilike_cls is not None and isinstance(node, ilike_cls):
        field = _field_with_alias(node.this, alias_map)
        return {field: {"$regex": _like_to_regex(str(_value(node.expression))), "$options": "i"}}

    if isinstance(node, exp.In):
        field = _field_with_alias(node.this, alias_map)
        # Read the operand the same way _condition_to_filter does: it may sit under
        # "expression", "query" (subquery) or "expressions" (value list). Reading
        # only "expression" left lists and subqueries unreachable.
        operand = node.expression or node.args.get("query") or node.args.get("expressions")
        if isinstance(operand, list):
            members = [_value(item) for item in operand]
        else:
            members = _value(operand)
        return {field: {"$in": members}}

    if isinstance(node, exp.Paren):
        return _convert(node.this)

    comparisons = (
        (exp.EQ, None),
        (exp.NEQ, "$ne"),
        (exp.GT, "$gt"),
        (exp.GTE, "$gte"),
        (exp.LT, "$lt"),
        (exp.LTE, "$lte"),
    )
    if isinstance(node, tuple(cls for cls, _ in comparisons)):
        field = _field_with_alias(node.left, alias_map)
        operand_value = _value(node.right)
        for cls, mongo_op in comparisons:
            if isinstance(node, cls):
                if mongo_op is None:
                    return {field: operand_value}
                return {field: {mongo_op: operand_value}}

    raise_error("[mdb][E2]")
424
+
425
+
426
def _ensure_supported(expr: exp.Expression) -> None:
    """Walk the expression tree and report the first OR / BETWEEN / LIKE / OFFSET node as E2."""
    rejected = (exp.Or, exp.Between, exp.Like, exp.Offset)
    for candidate in expr.walk():
        if not isinstance(candidate, rejected):
            continue
        # Prefer the node's own key for the message; fall back to the class name.
        label = candidate.key.upper() if hasattr(candidate, "key") else type(candidate).__name__
        raise_error("[mdb][E2]", f"Unsupported SQL construct: {label}")
433
+
434
+
435
def parse_sql(sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None) -> QueryParts:
    """Parse a SQL statement into ``QueryParts``.

    Placeholders are first replaced by tokens (``preprocess_sql``). CREATE / DROP
    INDEX statements are recognised by the dedicated index parsers; everything
    else is parsed with sqlglot and dispatched on the statement type: UNION,
    SELECT (window / join / group / plain), INSERT, UPDATE, DELETE, CREATE, DROP.

    Subqueries registered while parsing are attached to the returned parts as
    ``subqueries`` (``None`` when there are none).

    Raises (via ``raise_error``):
        E5 when sqlglot cannot parse the statement, E2 for UNION without ALL and
        for unsupported statement types.
    """
    normalized_sql, param_values, tokens = preprocess_sql(sql, params)
    params_map = dict(zip(tokens, param_values))
    subqueries: dict[str, dict[str, Any]] = {}

    # CREATE/DROP INDEX are handled by simple dedicated parsers.
    index_parts = _parse_create_index_sql(normalized_sql) or _parse_drop_index_sql(normalized_sql)
    if index_parts:
        return index_parts

    try:
        expr = parse_one(normalized_sql)
    except Exception as exc:
        raise_error("[mdb][E5]", cause=exc)

    if isinstance(expr, exp.Union):
        if expr.args.get("distinct"):
            raise_error("[mdb][E2]", "Unsupported SQL construct: UNION DISTINCT")
        left = _parse_select_like(expr.left, params_map, subqueries)
        right = _parse_select_like(expr.right, params_map, subqueries)

        order = None
        order_clause = expr.args.get("order")
        if order_clause:
            order = [
                (_field_name(item.this, params_map), -1 if item.args.get("desc") else 1)
                for item in order_clause.expressions
            ]

        limit = None
        limit_clause = expr.args.get("limit")
        if limit_clause:
            operand = limit_clause.expression
            try:
                name = operand.name
                # A bound parameter arrives as a token; resolve it to its value
                # instead of trying to convert the token text to int.
                limit = int(params_map[name]) if name in params_map else int(name)
            except Exception:
                limit = int(operand.this)

        parts = QueryParts(operation="union_all", collection="", union_parts=[left, right], sort=order, limit=limit)
        parts.subqueries = subqueries or None
        return parts

    if isinstance(expr, exp.Select):
        uses_window = (
            expr.find(exp.Window) or expr.find(exp.RowNumber) or expr.find(exp.Rank) or expr.find(exp.DenseRank)
        )
        if uses_window:
            parts = _parse_window_select(expr, params_map, subqueries)
        else:
            parts = _parse_select_like(expr, params_map, subqueries)
        # Attach subqueries on every SELECT path, the window path included, so
        # tokens registered by a WHERE clause stay resolvable.
        parts.subqueries = subqueries or None
        return parts

    dml_handlers = (
        (exp.Insert, _parse_insert),
        (exp.Update, _parse_update),
        (exp.Delete, _parse_delete),
    )
    for node_type, handler in dml_handlers:
        if isinstance(expr, node_type):
            parts = handler(expr, params_map, subqueries)
            parts.subqueries = subqueries or None
            return parts

    if isinstance(expr, exp.Create):
        return _parse_create(expr)
    if isinstance(expr, exp.Drop):
        return _parse_drop(expr)
    raise_error("[mdb][E2]", "Unsupported SQL construct: STATEMENT")
503
+
504
+
505
def _parse_select(expr: exp.Select, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]) -> QueryParts:
    """Translate a SELECT without JOIN / GROUP BY into ``QueryParts``.

    The FROM source is either a collection or an inline subquery (registered with
    mode "from"). The result is one of:

    * ``find`` / ``from_subquery`` for plain column selection,
    * ``aggregate`` with an optional ``$match`` followed by a single ``$group``
      (``_id: None``) when the select list contains COUNT / SUM / AVG / MIN / MAX
      over a collection,
    * ``from_subquery`` carrying ``inline_aggregates`` when those aggregates are
      computed over an inline subquery.

    LIMIT and OFFSET accept either a literal or a bound parameter.

    Raises (via ``raise_error``):
        E5 when no usable FROM source can be determined.
    """

    def _int_clause(key: str) -> int | None:
        # Read LIMIT / OFFSET as an int; a bound parameter arrives as a token and
        # is resolved through params_map rather than converted from its text.
        clause = expr.args.get(key)
        if not clause:
            return None
        operand = clause.expression
        try:
            name = operand.name
            if name in params_map:
                return int(params_map[name])
            return int(name)
        except Exception:
            return int(operand.this)

    from_expr = expr.args.get("from_")
    collection = None
    inline_token = None
    if from_expr:
        source = getattr(from_expr, "this", None)
        if isinstance(source, exp.Table) and source.name:
            collection = source.name
        elif isinstance(source, (exp.Subquery, exp.Select)):
            inline_token = _register_subquery(source, params_map, subqueries, mode="from")
        else:
            raise_error("[mdb][E5]", "Failed to parse SQL")
    if not collection and not inline_token:
        # No FROM entry found: fall back to the first table anywhere in the statement.
        table = expr.find(exp.Table)
        if table and table.name:
            collection = table.name
        else:
            raise_error("[mdb][E5]", "Failed to parse SQL")

    projection_paths: list[tuple[str, str]] | None = None
    if not expr.is_star:
        projection_paths = []
        for item in expr.expressions:
            target = item.this if isinstance(item, exp.Alias) else item
            alias = item.alias_or_name
            # Non-column expressions are addressed by their output name.
            source_path = _field_name(target, params_map) if isinstance(target, exp.Column) else alias
            projection_paths.append((source_path, alias))

    field_aggregates = ((exp.Sum, "sum"), (exp.Avg, "avg"), (exp.Min, "min"), (exp.Max, "max"))
    aggregates: list[tuple[str, str, str | None]] = []
    for item in expr.expressions:
        target = item.this if isinstance(item, exp.Alias) else item
        alias = item.alias_or_name
        if isinstance(target, exp.Count):
            aggregates.append((alias, "count", None))
            continue
        for agg_cls, agg_name in field_aggregates:
            if isinstance(target, agg_cls):
                aggregates.append((alias, agg_name, _field_name(target.this, params_map)))
                break

    mongo_filter = None
    if expr.args.get("where"):
        mongo_filter = _condition_to_filter(expr.args["where"].this, params_map, subqueries)

    sort_items = None
    if expr.args.get("order"):
        sort_items = [
            (_field_name(item.this, params_map), -1 if item.args.get("desc") else 1)
            for item in expr.args["order"].expressions
        ]

    limit_val = _int_clause("limit")
    skip_val = _int_clause("offset")

    if aggregates:
        aggregate_paths = [(alias, alias) for alias, _, _ in aggregates]
        if inline_token:
            return QueryParts(
                operation="from_subquery",
                collection=collection or "",
                filter=mongo_filter or {},
                projection=[alias for alias, _, _ in aggregates],
                sort=sort_items,
                limit=limit_val,
                skip=skip_val,
                inline_token=inline_token,
                inline_aggregates=aggregates,
                projection_paths=aggregate_paths,
            )
        pipeline: list[dict[str, Any]] = []
        if mongo_filter:
            pipeline.append({"$match": mongo_filter})
        group_doc: dict[str, Any] = {"_id": None}
        for alias, op, field in aggregates:
            # COUNT sums 1 per document; the other operators apply to the field.
            group_doc[alias] = {"$sum": 1} if op == "count" else {f"${op}": f"${field}"}
        pipeline.append({"$group": group_doc})
        return QueryParts(
            operation="aggregate",
            collection=collection or "",
            pipeline=pipeline,
            projection_paths=aggregate_paths,
        )

    return QueryParts(
        operation="from_subquery" if inline_token else "find",
        collection=collection or "",
        filter=mongo_filter or {},
        projection=None,
        projection_paths=projection_paths,
        sort=sort_items,
        limit=limit_val,
        skip=skip_val,
        inline_token=inline_token,
    )
622
+
623
+
624
def _parse_select_like(expr: exp.Select, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]) -> QueryParts:
    """Dispatch a SELECT to the join, group or plain parser, checked in that order."""
    if expr.args.get("joins"):
        handler = _parse_join_select
    elif expr.args.get("group"):
        handler = _parse_group_select
    else:
        handler = _parse_select
    return handler(expr, params_map, subqueries)
630
+
631
+
632
def _parse_window_select(expr: exp.Select, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]) -> QueryParts:
    """Translate a SELECT containing ROW_NUMBER() OVER (...) into an aggregate pipeline.

    The select list may contain plain columns and a ROW_NUMBER window; any other
    item, a JOIN, a GROUP BY, or more than one PARTITION BY expression is reported
    as E2 / WINDOW_FUNCTION. The pipeline is an optional ``$match``, then
    ``$setWindowFields`` with ``$documentNumber``, then a ``$project`` of the
    selected columns plus the window output.
    """
    unsupported = "Unsupported SQL construct: WINDOW_FUNCTION"

    from_expr = expr.args.get("from_")
    if not from_expr or not hasattr(from_expr, "this") or not from_expr.this.name:
        raise_error("[mdb][E5]", "Failed to parse SQL")
    if expr.args.get("joins"):
        raise_error("[mdb][E2]", unsupported)
    collection = from_expr.this.name
    if expr.args.get("group"):
        raise_error("[mdb][E2]", unsupported)

    where_filter = None
    if expr.args.get("where"):
        where_filter = _condition_to_filter(expr.args["where"].this, params_map, subqueries)

    # Split the select list into plain columns and the ROW_NUMBER window.
    window_expr = None
    output_alias = None
    base_columns: list[tuple[str, str]] = []
    for item in expr.expressions:
        target = item.this if isinstance(item, exp.Alias) else item
        alias = item.alias_or_name
        if isinstance(target, exp.Window) and isinstance(target.this, exp.RowNumber):
            window_expr, output_alias = target, alias
        elif isinstance(target, exp.Column):
            base_columns.append((_field_name(target, params_map), alias))
        else:
            raise_error("[mdb][E2]", unsupported)
    if not window_expr or not output_alias:
        raise_error("[mdb][E2]", unsupported)

    partition = window_expr.args.get("partition_by")
    order = window_expr.args.get("order")
    if partition and isinstance(partition, list) and len(partition) > 1:
        raise_error("[mdb][E2]", unsupported)

    window_doc: dict[str, Any] = {"output": {output_alias: {"$documentNumber": {}}}}
    if partition:
        first_partition = partition[0] if isinstance(partition, list) else partition.expressions[0]
        partition_expr = f"${_field_name(first_partition, params_map)}"
        if partition_expr:
            window_doc["partitionBy"] = partition_expr
    sort_doc: dict[str, int] = {}
    if order and order.expressions:
        sort_doc = {
            _field_name(item.this, params_map): (-1 if item.args.get("desc") else 1)
            for item in order.expressions
        }
    if sort_doc:
        window_doc["sortBy"] = sort_doc

    pipeline: list[dict[str, Any]] = []
    if where_filter:
        pipeline.append({"$match": where_filter})
    pipeline.append({"$setWindowFields": window_doc})

    # The window output is always projected, so the $project stage is never empty.
    project_doc: dict[str, Any] = {alias: f"${path}" for path, alias in base_columns}
    project_doc[output_alias] = f"${output_alias}"
    pipeline.append({"$project": project_doc})

    projection_paths = [(alias, alias) for _, alias in base_columns]
    projection_paths.append((output_alias, output_alias))
    return QueryParts(
        operation="aggregate",
        collection=collection,
        pipeline=pipeline,
        projection_paths=projection_paths,
        uses_window=True,
    )
698
+
699
+
700
def _parse_join_select(expr: exp.Select, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]) -> QueryParts:
    """Translate a SELECT with up to three JOINs into an aggregate pipeline.

    Each join becomes a ``$lookup`` into ``__join<idx>`` followed by a ``$unwind``
    (unmatched rows are preserved for LEFT joins). Columns of a joined table are
    addressed as ``__join<idx>.<field>``; columns of the base table keep their
    bare name. WHERE becomes ``$match``, ORDER BY ``$sort``, OFFSET ``$skip`` and
    LIMIT ``$limit``, emitted in that order.

    Each ON clause must be a single equality between columns, with the joined
    table's column on the right-hand side. The select list must be ``*`` or plain
    columns.

    Raises (via ``raise_error``):
        E2 with a JOIN / JOIN_DEPTH / JOIN_ON / JOIN_TABLE / JOIN_COLUMN /
        JOIN_PROJECTION detail for unsupported shapes.
    """

    def _int_clause(key: str) -> int | None:
        # Read LIMIT / OFFSET as an int; a bound parameter arrives as a token and
        # is resolved through params_map rather than converted from its text.
        clause = expr.args.get(key)
        if not clause:
            return None
        operand = clause.expression
        try:
            name = operand.name
            if name in params_map:
                return int(params_map[name])
            return int(name)
        except Exception:
            return int(operand.this)

    def _sort_path(target: exp.Expression) -> str:
        # A column qualified with a known table must carry that table's prefix, so
        # sorting on a joined column targets __join<idx>.<field>, not the base document.
        if isinstance(target, exp.Column) and target.table and target.table in alias_map:
            return _field_with_alias(target, alias_map)
        return _field_name(target, params_map)

    from_expr = expr.args.get("from_")
    joins = expr.args.get("joins") or []
    if not from_expr or not hasattr(from_expr, "this") or not from_expr.this.name or len(joins) < 1:
        raise_error("[mdb][E2]", "Unsupported SQL construct: JOIN")
    base_collection = from_expr.this.name
    base_alias = from_expr.this.alias_or_name or base_collection

    # Table name / alias -> path prefix; the base table has no prefix.
    alias_map = {base_alias: "", base_collection: ""}
    pipeline: list[dict] = []
    # (prefix, joined collection, foreign field, join node, local column expression)
    join_specs: list[tuple[str, str, str, exp.Join, exp.Expression]] = []

    # Prepare joins (up to 3).
    if len(joins) > 3:
        raise_error("[mdb][E2]", "Unsupported SQL construct: JOIN_DEPTH")
    for idx, join_expr in enumerate(joins):
        if join_expr.kind and join_expr.kind.upper() not in ("INNER", "LEFT"):
            raise_error("[mdb][E2]", "Unsupported SQL construct: JOIN")
        on_expr = join_expr.args.get("on")
        if not on_expr or not isinstance(on_expr, exp.EQ):
            raise_error("[mdb][E2]", "Unsupported SQL construct: JOIN_ON")
        left_tbl, _ = _column_table_field(on_expr.left)
        right_tbl, right_field = _column_table_field(on_expr.right)
        joined = join_expr.this
        join_table = joined.this.name if hasattr(joined, "this") and hasattr(joined.this, "name") else None
        join_alias = joined.alias_or_name or join_table
        left_unknown = bool(left_tbl) and left_tbl not in alias_map
        right_mismatch = bool(right_tbl) and right_tbl not in (join_table, join_alias)
        if not join_table or left_unknown or right_mismatch:
            raise_error("[mdb][E2]", "Unsupported SQL construct: JOIN_TABLE")
        prefix = f"__join{idx}"
        alias_map[join_alias] = f"{prefix}."
        alias_map[join_table] = f"{prefix}."
        join_specs.append((prefix, join_table, right_field, join_expr, on_expr.left))

    where_filter = None
    if expr.args.get("where"):
        where_filter = _condition_to_filter_alias(expr.args["where"].this, params_map, alias_map, subqueries)

    for prefix, join_table, right_field, join_expr, left_expr in join_specs:
        join_side = (join_expr.args.get("side") or "").upper()
        preserve_null = bool(join_side == "LEFT" or (join_expr.kind and join_expr.kind.upper() == "LEFT"))
        pipeline.append(
            {
                "$lookup": {
                    "from": join_table,
                    "localField": _field_with_alias(left_expr, alias_map),
                    "foreignField": right_field,
                    "as": prefix,
                }
            }
        )
        pipeline.append({"$unwind": {"path": f"${prefix}", "preserveNullAndEmptyArrays": preserve_null}})
    if where_filter:
        pipeline.append({"$match": where_filter})

    if expr.args.get("order"):
        sort_doc: dict[str, int] = {}
        for item in expr.args["order"].expressions:
            sort_doc[_sort_path(item.this)] = -1 if item.args.get("desc") else 1
        if sort_doc:
            pipeline.append({"$sort": sort_doc})

    limit_val = _int_clause("limit")
    skip_val = _int_clause("offset")
    # Pipeline stages run in order: OFFSET must be applied before LIMIT, otherwise
    # "LIMIT n OFFSET m" would skip rows out of the already-truncated result.
    if skip_val is not None:
        pipeline.append({"$skip": skip_val})
    if limit_val is not None:
        pipeline.append({"$limit": limit_val})

    projection_paths: list[tuple[str, str]] | None = None
    if not expr.is_star:
        projection_paths = []
        for c in expr.expressions:
            if isinstance(c, exp.Column):
                tbl, fld = _column_table_field(c)
                if tbl and tbl not in alias_map:
                    raise_error("[mdb][E2]", "Unsupported SQL construct: JOIN_COLUMN")
                path = _field_with_alias(c, alias_map)
                out_name = c.alias_or_name or (f"{tbl}.{fld}" if tbl and tbl != base_collection else fld)
                projection_paths.append((path, out_name))
            else:
                raise_error("[mdb][E2]", "Unsupported SQL construct: JOIN_PROJECTION")
    return QueryParts(
        operation="aggregate",
        collection=base_collection,
        pipeline=pipeline,
        projection_paths=projection_paths,
    )
796
+
797
+
798
def _parse_group_select(expr: exp.Select, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]) -> QueryParts:
    """Translate a ``SELECT ... GROUP BY`` statement into an aggregation pipeline.

    Stages are emitted in this order: optional ``$match`` (WHERE), ``$group``,
    optional ``$match`` (HAVING), ``$project``, then optional ``$sort``,
    ``$skip`` and ``$limit``.

    Args:
        expr: Parsed SELECT expression carrying a ``group`` clause.
        params_map: Placeholder-token mapping handed to the field/filter helpers.
        subqueries: Subquery registry handed to the filter helpers.

    Returns:
        A ``QueryParts`` with ``operation="aggregate"``, the source table as
        collection, the built pipeline, and one ``(name, name)`` projection
        path per distinct output column.

    Raises:
        Whatever ``raise_error`` raises: ``[mdb][E5]`` when the table or the
        GROUP BY clause is missing, ``[mdb][E2]`` for a select item that is
        neither a plain column nor COUNT/SUM/AVG/MIN/MAX.
    """

    def _clause_int(clause: Any) -> int:
        # LIMIT/OFFSET operand: try the node's ``name`` first, then its ``this``.
        try:
            return int(clause.expression.name)
        except Exception:
            return int(clause.expression.this)

    source_table = expr.find(exp.Table)
    if not (source_table and source_table.name):
        raise_error("[mdb][E5]", "Failed to parse SQL")

    stages: list[dict] = []

    where_node = expr.args.get("where")
    if where_node:
        stages.append({"$match": _condition_to_filter(where_node.this, params_map, subqueries)})

    group_clause = expr.args.get("group")
    if not group_clause:
        raise_error("[mdb][E5]", "Failed to parse SQL")

    # Each GROUP BY column becomes a key of the compound ``_id``.
    group_cols: list[str] = [_field_name(g, params_map) for g in group_clause.expressions]
    group_id: dict[str, str] = {name: f"${name}" for name in group_cols}

    # Aggregates that take a single column argument, mapped to their Mongo accumulator.
    column_aggregates = (
        (exp.Sum, "$sum"),
        (exp.Avg, "$avg"),
        (exp.Min, "$min"),
        (exp.Max, "$max"),
    )

    accumulators: dict[str, dict] = {}
    output_names: list[str] = []  # distinct output columns, in SELECT order
    for item in expr.expressions:
        node = item.this if isinstance(item, exp.Alias) else item
        out_name = item.alias_or_name
        if out_name in output_names:
            # A repeated output name keeps only its first definition.
            continue
        output_names.append(out_name)

        if isinstance(node, exp.Column):
            source_field = _field_name(node, params_map)
            accumulators[out_name] = {"$first": f"${source_field}"}
        elif isinstance(node, exp.Count):
            # NOTE(review): every COUNT form is emitted as a plain row count;
            # the COUNT argument is not inspected here.
            accumulators[out_name] = {"$sum": 1}
        else:
            for agg_cls, mongo_op in column_aggregates:
                if isinstance(node, agg_cls):
                    source_field = _field_name(node.this, params_map)
                    accumulators[out_name] = {mongo_op: f"${source_field}"}
                    break
            else:
                raise_error("[mdb][E2]", "Unsupported SQL construct: GROUP_SELECT")

    stages.append({"$group": {"_id": group_id, **accumulators}})

    having_node = expr.args.get("having")
    if having_node:
        # Group columns and output names are registered with an empty prefix
        # for the HAVING translation.
        having_aliases = dict.fromkeys([*group_cols, *accumulators], "")
        having_filter = _condition_to_filter_alias(having_node.this, params_map, having_aliases, subqueries)
        if having_filter:
            stages.append({"$match": having_filter})

    # Outputs named like a group column are read back from ``_id``; the rest
    # are the accumulator results themselves.
    project_doc: dict[str, str] = {
        name: (f"$_id.{name}" if name in group_cols else f"${name}") for name in output_names
    }
    projection_paths: list[tuple[str, str]] = [(name, name) for name in output_names]
    stages.append({"$project": project_doc})

    order_node = expr.args.get("order")
    if order_node:
        sort_doc: dict[str, int] = {
            _field_name(ordered.this, params_map): (-1 if ordered.args.get("desc") else 1)
            for ordered in order_node.expressions
        }
        if sort_doc:
            stages.append({"$sort": sort_doc})

    offset_node = expr.args.get("offset")
    if offset_node:
        stages.append({"$skip": _clause_int(offset_node)})

    limit_node = expr.args.get("limit")
    if limit_node:
        stages.append({"$limit": _clause_int(limit_node)})

    return QueryParts(
        operation="aggregate",
        collection=source_table.name,
        pipeline=stages,
        projection_paths=projection_paths,
    )
894
+
895
+
896
def _parse_insert(expr: exp.Insert, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]) -> QueryParts:
    """Translate a single-row ``INSERT ... VALUES (...)`` into an insert operation.

    Args:
        expr: Parsed INSERT expression.
        params_map: Placeholder-token mapping handed to ``_literal_value``.
        subqueries: Subquery registry handed to ``_literal_value``.

    Returns:
        A ``QueryParts`` with ``operation="insert"`` whose ``values`` maps
        column names to values, or positions to values when no column list
        was given.

    Raises:
        Whatever ``raise_error`` raises: ``[mdb][E5]`` when the table name is
        missing, the source is not a VALUES clause, or it does not hold exactly
        one row; ``[mdb][E4]`` when the column and value counts differ.
    """
    target = expr.this
    columns: List[str] = []
    if isinstance(target, exp.Schema):
        # ``INSERT INTO t (a, b)``: the schema node wraps the table and its columns.
        wrapped = target.this
        table_name = wrapped.name if wrapped else None
        columns = [column.name for column in target.expressions]
    elif target and target.name:
        table_name = target.name
    else:
        table_name = None
    if not table_name:
        raise_error("[mdb][E5]", "Failed to parse SQL")

    source = expr.expression
    if not isinstance(source, exp.Values):
        raise_error("[mdb][E5]", "Failed to parse SQL")
    rows = source.expressions
    if len(rows) != 1:
        raise_error("[mdb][E5]", "Failed to parse SQL")

    values = [_literal_value(cell, params_map, subqueries) for cell in rows[0].expressions]
    if columns and len(columns) != len(values):
        raise_error("[mdb][E4]")

    if columns:
        doc = dict(zip(columns, values))
    else:
        # NOTE(review): without a column list the keys are integer positions
        # (0, 1, ...); confirm the executor accepts non-string keys.
        doc = {position: value for position, value in enumerate(values)}

    return QueryParts(
        operation="insert",
        collection=table_name,
        values=doc,
    )
922
+
923
+
924
def _parse_update(expr: exp.Update, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]) -> QueryParts:
    """Translate ``UPDATE ... SET ... WHERE ...`` into a ``$set`` update.

    Args:
        expr: Parsed UPDATE expression.
        params_map: Placeholder-token mapping handed to the field/value helpers.
        subqueries: Subquery registry handed to the value/filter helpers.

    Returns:
        A ``QueryParts`` with ``operation="update"``, the translated WHERE
        filter, and ``update={"$set": {...}}`` built from the SET assignments.

    Raises:
        Whatever ``raise_error`` raises: ``[mdb][E5]`` when the table is missing
        or a SET item is not an equality, ``[mdb][E3]`` when there is no WHERE
        clause.
    """
    target = expr.this
    if not (target and target.name):
        raise_error("[mdb][E5]", "Failed to parse SQL")

    set_values = {}
    for item in expr.args.get("expressions") or []:
        if not isinstance(item, exp.EQ):
            raise_error("[mdb][E5]")
        # Resolve the column before the value, one statement each.
        column = _field_name(item.left, params_map)
        new_value = _literal_value(item.right, params_map, subqueries)
        set_values[column] = new_value

    # An UPDATE without WHERE is rejected rather than applied to every document.
    condition = expr.args.get("where")
    if not condition:
        raise_error("[mdb][E3]")

    return QueryParts(
        operation="update",
        collection=target.name,
        filter=_condition_to_filter(condition.this, params_map, subqueries),
        update={"$set": set_values},
    )
946
+
947
+
948
def _parse_delete(expr: exp.Delete, params_map: dict[str, Any], subqueries: dict[str, dict[str, Any]]) -> QueryParts:
    """Translate ``DELETE FROM ... WHERE ...`` into a delete operation.

    Args:
        expr: Parsed DELETE expression.
        params_map: Placeholder-token mapping handed to ``_condition_to_filter``.
        subqueries: Subquery registry handed to ``_condition_to_filter``.

    Returns:
        A ``QueryParts`` with ``operation="delete"`` and the translated filter.

    Raises:
        Whatever ``raise_error`` raises: ``[mdb][E5]`` when the table is
        missing, ``[mdb][E3]`` when there is no WHERE clause.
    """
    target = expr.this if hasattr(expr, "this") else None
    if not (target and target.name):
        raise_error("[mdb][E5]", "Failed to parse SQL")

    # A DELETE without WHERE is rejected rather than applied to every document.
    condition = expr.args.get("where")
    if not condition:
        raise_error("[mdb][E3]")

    return QueryParts(
        operation="delete",
        collection=target.name,
        filter=_condition_to_filter(condition.this, params_map, subqueries),
    )
961
+
962
+
963
def _parse_create(expr: exp.Create) -> QueryParts:
    """Translate a CREATE statement into a ``create`` operation.

    The collection name is read from ``expr.this.this.name``; if any link of
    that attribute chain is absent the statement is rejected.

    Raises:
        Whatever ``raise_error`` raises (``[mdb][E5]``) when no name is found.
    """
    inner = getattr(expr.this, "this", None)
    collection_name = getattr(inner, "name", None)
    if not collection_name:
        raise_error("[mdb][E5]", "Failed to parse SQL")
    return QueryParts(operation="create", collection=collection_name)
968
+
969
+
970
def _parse_drop(expr: exp.Drop) -> QueryParts:
    """Translate a DROP statement into a ``drop`` operation.

    The collection name is read from ``expr.this.this.name``; if any link of
    that attribute chain is absent the statement is rejected.

    Raises:
        Whatever ``raise_error`` raises (``[mdb][E5]``) when no name is found.
    """
    inner = getattr(expr.this, "this", None)
    collection_name = getattr(inner, "name", None)
    if not collection_name:
        raise_error("[mdb][E5]", "Failed to parse SQL")
    return QueryParts(operation="drop", collection=collection_name)
975
+
976
+
977
def _parse_create_index_sql(sql: str) -> QueryParts | None:
    """Recognise ``CREATE [UNIQUE] INDEX name ON table (col [DESC], ...)``.

    Args:
        sql: Raw SQL text; surrounding whitespace is ignored.

    Returns:
        A ``QueryParts`` with ``operation="create_index"``, or ``None`` when
        the text does not match ``CREATE_INDEX_RE``. Each key is
        ``(column, direction)`` where direction is ``-1`` if the token after
        the column is ``desc`` (case-insensitive) and ``1`` otherwise.
    """
    match = CREATE_INDEX_RE.match(sql.strip())
    if match is None:
        return None

    unique_kw, index_name, table, cols_raw = match.groups()

    # Split the column list on commas, then each entry on whitespace; entries
    # that are empty after splitting are skipped.
    tokenized = (entry.split() for entry in cols_raw.split(","))
    keys: list[tuple[str, int]] = [
        (tokens[0], -1 if len(tokens) > 1 and tokens[1].lower() == "desc" else 1)
        for tokens in tokenized
        if tokens
    ]

    return QueryParts(
        operation="create_index",
        collection=table,
        index_keys=keys,
        index_name=index_name,
        unique=bool(unique_kw),
    )
996
+
997
+
998
def _parse_drop_index_sql(sql: str) -> QueryParts | None:
    """Recognise ``DROP INDEX name ON table``.

    Args:
        sql: Raw SQL text; surrounding whitespace is ignored.

    Returns:
        A ``QueryParts`` with ``operation="drop_index"``, or ``None`` when the
        text does not match ``DROP_INDEX_RE``.
    """
    match = DROP_INDEX_RE.match(sql.strip())
    if match is None:
        return None
    index_name, table = match.groups()
    return QueryParts(operation="drop_index", collection=table, index_name=index_name)