dbapi-mongodb 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dbapi_mongodb-0.1.0.dist-info/METADATA +136 -0
- dbapi_mongodb-0.1.0.dist-info/RECORD +11 -0
- dbapi_mongodb-0.1.0.dist-info/WHEEL +5 -0
- dbapi_mongodb-0.1.0.dist-info/licenses/LICENSE +21 -0
- dbapi_mongodb-0.1.0.dist-info/top_level.txt +1 -0
- mongo_dbapi/__init__.py +13 -0
- mongo_dbapi/async_dbapi.py +74 -0
- mongo_dbapi/dbapi.py +493 -0
- mongo_dbapi/errors.py +28 -0
- mongo_dbapi/sqlalchemy_dialect.py +57 -0
- mongo_dbapi/translation.py +1004 -0
mongo_dbapi/dbapi.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from dataclasses import dataclass, replace
|
|
5
|
+
from typing import Any, Iterable, List, Mapping, Sequence
|
|
6
|
+
import base64
|
|
7
|
+
import decimal
|
|
8
|
+
import uuid
|
|
9
|
+
import datetime
|
|
10
|
+
|
|
11
|
+
from bson import ObjectId
|
|
12
|
+
from pymongo import MongoClient
|
|
13
|
+
from pymongo.errors import ConnectionFailure, OperationFailure
|
|
14
|
+
|
|
15
|
+
from .errors import raise_error
|
|
16
|
+
from .translation import QueryParts, parse_sql
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger("mongo_dbapi")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _convert_value(value: Any) -> Any:
    """Convert a Mongo value into a Python-friendly value.

    ``ObjectId``, ``Decimal`` and ``UUID`` instances are returned as their
    ``str()`` form, ``bytes`` as a base64-encoded ASCII string. Every other
    value, including ``{"$binary": ...}`` dicts, is returned unchanged.
    """
    if isinstance(value, ObjectId):
        return str(value)
    if isinstance(value, bytes):
        encoded = base64.b64encode(value)
        return encoded.decode("ascii")
    if isinstance(value, (decimal.Decimal, uuid.UUID)):
        return str(value)
    # Anything else (extended-JSON binary dicts included) passes through as-is.
    return value
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class CursorState:
    """Result snapshot of one executed statement, as exposed through a cursor."""

    # Buffered result rows; None when nothing has been stored yet.
    rows: List[tuple] | None = None
    # Row count reported to the caller; defaults to -1.
    rowcount: int = -1
    # Identifier of the last inserted document, when the statement produced one.
    lastrowid: Any | None = None
    # DBAPI-style 7-tuples, one per result column, or None.
    description: List[tuple] | None = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class Cursor:
    """DBAPI-like cursor that runs SQL through its parent connection.

    Each ``execute``/``executemany`` call replaces the cursor's result state;
    the ``fetch*`` methods then consume the buffered rows from the front.
    """

    # Default batch size used by fetchmany() when no size is passed.
    arraysize: int = 1

    def __init__(self, connection: "Connection"):
        self.connection = connection
        self._state = CursorState()
        self._closed = False

    def _ensure_open(self) -> None:
        """Raise an [mdb][E5] error when the cursor has already been closed."""
        if self._closed:
            # Keep the E5 code callers may match on, but say what actually went wrong.
            raise_error("[mdb][E5]", "Cursor is closed")

    def execute(self, sql: str, params: Sequence | Mapping | None = None) -> "Cursor":
        """Parse ``sql`` with ``params``, run it and store the resulting state.

        Returns the cursor itself so calls can be chained.
        Raises an [mdb][E5] error if the cursor is closed.
        """
        self._ensure_open()
        parts = parse_sql(sql, params)
        self._state = self.connection._execute_parts(parts)  # noqa: SLF001
        return self

    def executemany(self, sql: str, seq_of_params: Sequence[Sequence | Mapping]) -> "Cursor":
        """Run ``sql`` once for every parameter set in ``seq_of_params``.

        Afterwards the cursor holds no rows, ``rowcount`` is the sum of the
        individual row counts and ``lastrowid`` comes from the final run.
        Raises an [mdb][E5] error if the cursor is closed.
        """
        self._ensure_open()
        total_rows = 0
        lastrowid = None
        for params in seq_of_params:
            parts = parse_sql(sql, params)
            state = self.connection._execute_parts(parts)  # noqa: SLF001
            total_rows += state.rowcount
            lastrowid = state.lastrowid
        self._state = CursorState(rows=[], rowcount=total_rows, lastrowid=lastrowid, description=None)
        return self

    def fetchone(self) -> tuple | None:
        """Remove and return the next buffered row, or None when none are left."""
        rows = self._state.rows
        if not rows:
            return None
        return rows.pop(0)

    def fetchmany(self, size: int | None = None) -> List[tuple]:
        """Remove and return up to ``size`` buffered rows (``arraysize`` by default)."""
        count = self.arraysize if size is None else size
        count = max(count, 0)  # a negative size would otherwise slice from the end
        rows = self._state.rows or []
        batch = rows[:count]
        self._state.rows = rows[count:]
        return batch

    def fetchall(self) -> List[tuple]:
        """Return every remaining buffered row and leave the buffer empty."""
        rows = self._state.rows or []
        self._state.rows = []
        return rows

    @property
    def rowcount(self) -> int:
        """Row count of the most recent execution."""
        return self._state.rowcount

    @property
    def lastrowid(self) -> Any:
        """``lastrowid`` recorded by the most recent execution."""
        return self._state.lastrowid

    @property
    def description(self) -> List[tuple] | None:
        """Column description of the most recent execution, or None."""
        return self._state.description

    def close(self) -> None:
        """Mark the cursor closed; later execute calls raise an error."""
        self._closed = True
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class Connection:
|
|
99
|
+
"""DBAPI-like connection / DBAPI 風接続"""
|
|
100
|
+
|
|
101
|
+
def __init__(self, uri: str, db_name: str):
    """Create a client for ``uri`` and select the database ``db_name``.

    Raises an [mdb][E1] error for an empty URI, [mdb][E7] when client setup
    raises ``ConnectionFailure`` and [mdb][E8] when it raises
    ``OperationFailure``. Server capabilities are probed once at the end.
    """
    if not uri:
        raise_error("[mdb][E1]")
    self._uri, self._db_name = uri, db_name
    # NOTE(review): MongoClient appears to connect lazily, so these handlers may
    # rarely fire at this point - confirm against the pymongo version in use.
    try:
        self._client = MongoClient(uri)
        self._db = self._client[db_name]
    except ConnectionFailure as exc:
        raise_error("[mdb][E7]", cause=exc)
    except OperationFailure as exc:
        raise_error("[mdb][E8]", cause=exc)
    # No transaction session exists until begin() creates one.
    self._session = None
    self._transactions_supported = self._detect_transactions()
    self._window_supported = self._detect_window()
|
|
116
|
+
|
|
117
|
+
def _detect_transactions(self) -> bool:
    """Report whether the connected deployment can run transactions.

    The server version alone is not enough: a standalone server rejects
    transactions whatever its version. Support is reported only for a
    replica-set member on 4.0+ or a mongos router on 4.2+. Any error while
    probing the server yields False.
    """
    try:
        version = self._client.server_info().get("version", "0.0")
        numbers = [int(piece) for piece in version.split(".")[:2]]
        major = numbers[0]
        minor = numbers[1] if len(numbers) > 1 else 0
        if major < 4:
            return False
        admin = self._client.admin
        try:
            topology = admin.command("hello")
        except Exception:
            # Servers that predate the "hello" command answer the legacy name.
            topology = admin.command("isMaster")
        if topology.get("setName"):
            return True
        if topology.get("msg") == "isdbgrid":
            return (major, minor) >= (4, 2)
        return False
    except Exception:
        return False
|
|
125
|
+
|
|
126
|
+
def _detect_window(self) -> bool:
    """Report whether the server's major version is 5 or newer.

    The result gates window-function queries in ``_execute_parts``. Only the
    major component is parsed, so version strings without a numeric minor
    part (for example "5" or "5.0-rc1") are still classified correctly.
    Any error while reading or parsing the version yields False.
    """
    try:
        version = self._client.server_info().get("version", "0.0")
        return int(version.split(".")[0]) >= 5
    except Exception:
        return False
|
|
135
|
+
|
|
136
|
+
def cursor(self) -> Cursor:
    """Create a new cursor that executes statements through this connection."""
    new_cursor = Cursor(self)
    return new_cursor
|
|
138
|
+
|
|
139
|
+
def begin(self) -> None:
    """Start a transaction on a new session when the server supports it.

    Does nothing when transactions are unsupported, or when a transaction
    session is already open (starting another would orphan the open one).
    If starting the transaction fails, the new session is ended and the
    error is re-raised, so no half-initialised session is kept.
    """
    if not self._transactions_supported:
        logger.debug("Transaction not supported; no-op / トランザクション非対応のため no-op")
        return
    if self._session:
        logger.debug("Transaction already active; begin ignored")
        return
    logger.debug("Starting transaction session / トランザクションセッション開始")
    session = self._client.start_session()
    try:
        session.start_transaction()
    except Exception:
        session.end_session()
        raise
    # Publish the session only once its transaction is really open.
    self._session = session
|
|
146
|
+
|
|
147
|
+
def commit(self) -> None:
    """Commit the open transaction, if any, and release its session.

    The session is ended and cleared even when the commit raises, so a
    failed commit cannot leave a dead session attached to the connection.
    The commit error itself still propagates to the caller.
    """
    session = self._session
    if not session:
        return
    logger.debug("Commit transaction / トランザクションコミット")
    try:
        session.commit_transaction()
    finally:
        # Clear first so the connection is usable even if end_session() fails.
        self._session = None
        session.end_session()
|
|
153
|
+
|
|
154
|
+
def rollback(self) -> None:
    """Abort the open transaction, if any, and release its session.

    The session is ended and cleared even when the abort raises, so a failed
    rollback cannot leave a dead session attached to the connection.
    The abort error itself still propagates to the caller.
    """
    session = self._session
    if not session:
        return
    logger.debug("Abort transaction / トランザクションアボート")
    try:
        session.abort_transaction()
    finally:
        # Clear first so the connection is usable even if end_session() fails.
        self._session = None
        session.end_session()
|
|
160
|
+
|
|
161
|
+
def close(self) -> None:
    """End any open session, then close the underlying client.

    The client is closed even if ending the session raises.
    """
    session = getattr(self, "_session", None)
    self._session = None
    try:
        if session:
            # NOTE(review): pymongo documents end_session() as aborting an
            # in-progress transaction - confirm for the pinned version.
            session.end_session()
    finally:
        self._client.close()
|
|
163
|
+
|
|
164
|
+
def list_tables(self) -> list[str]:
    """Return the collection names of the selected database."""
    names = self._db.list_collection_names()
    return names
|
|
166
|
+
|
|
167
|
+
def _execute_parts(self, parts: QueryParts) -> CursorState:
    """Run a parsed statement by dispatching on ``parts.operation``.

    Subqueries are materialised first. A window-function query on a server
    without window support raises an [mdb][E2] error, and so does an
    operation name that has no handler.
    """
    if parts.subqueries:
        parts = self._materialize_subqueries(parts)
    if parts.uses_window and not self._window_supported:
        raise_error("[mdb][E2]", "Unsupported SQL construct: WINDOW_FUNCTION")

    # Operation name -> name of the method that executes it.
    handler_names = {
        "from_subquery": "_execute_from_subquery",
        "find": "_execute_find",
        "insert": "_execute_insert",
        "update": "_execute_update",
        "delete": "_execute_delete",
        "aggregate": "_execute_aggregate",
        "create": "_execute_create",
        "drop": "_execute_drop",
        "create_index": "_execute_create_index",
        "drop_index": "_execute_drop_index",
        "union_all": "_execute_union_all",
    }
    handler_name = handler_names.get(parts.operation)
    if handler_name is None:
        raise_error("[mdb][E2]")
    handler = getattr(self, handler_name)
    return handler(parts)
|
|
195
|
+
|
|
196
|
+
def _materialize_subqueries(self, parts: QueryParts) -> QueryParts:
    """Execute every subquery and substitute its placeholder token.

    The stored value depends on the subquery's ``mode``: ``"values"`` keeps
    the first column of each row, ``"exists"`` a boolean, ``"from"`` a list
    of dicts keyed by column name, and anything else the raw rows. Returns a
    copy of ``parts`` with the tokens replaced and ``subqueries`` cleared.
    """
    resolved: dict[str, Any] = {}
    for token, spec in (parts.subqueries or {}).items():
        sub_parts: QueryParts = spec["parts"]
        mode = spec.get("mode")
        state = self._execute_parts(sub_parts)
        rows = state.rows or []
        if mode == "values":
            resolved[token] = [row[0] for row in rows]
        elif mode == "exists":
            resolved[token] = bool(state.rows)
        elif mode == "from":
            names = [col[0] for col in state.description] if state.description else []
            resolved[token] = [{name: row[pos] for pos, name in enumerate(names)} for row in rows]
        else:
            resolved[token] = rows

    def _substitute(node: Any) -> Any:
        """Swap placeholder strings for their results, walking lists and dicts."""
        if isinstance(node, str) and node in resolved:
            return resolved[node]
        if isinstance(node, list):
            return [_substitute(item) for item in node]
        if isinstance(node, dict):
            return {key: _substitute(item) for key, item in node.items()}
        return node

    token = parts.inline_token
    if token and token in resolved:
        remaining_token = None
    else:
        remaining_token = token
    inline_rows = _substitute(resolved.get(token)) if token else None
    collection = parts.collection
    if isinstance(collection, dict):
        collection = _substitute(collection)

    return replace(
        parts,
        filter=_substitute(parts.filter),
        pipeline=_substitute(parts.pipeline),
        values=_substitute(parts.values),
        update=_substitute(parts.update),
        subqueries=None,
        inline_token=remaining_token,
        collection=collection,
        inline_rows=inline_rows,
    )
|
|
236
|
+
|
|
237
|
+
def _match_filter(self, doc: dict, flt: Any) -> bool:
    """Evaluate a Mongo-style filter against an in-memory row dict.

    Supports ``$and``, ``$or``, ``$expr`` (already reduced to a ``$literal``),
    plain equality, and the field operators ``$in``, ``$nin``, ``$eq``,
    ``$ne``, ``$gt``, ``$gte``, ``$lt``, ``$lte`` and ``$regex`` with an
    optional ``$options`` modifier. Unknown operators never match.

    Returns True for a ``None`` filter and False for a filter that is not a dict.
    """
    import re

    if flt is None:
        return True
    if not isinstance(flt, dict):
        return False

    for key, val in flt.items():
        if key == "$and":
            if not all(self._match_filter(doc, sub) for sub in val):
                return False
            continue
        if key == "$or":
            if not any(self._match_filter(doc, sub) for sub in val):
                return False
            continue
        if key == "$expr":
            # Already reduced to a literal; a truthy literal must not
            # short-circuit the remaining keys of the filter.
            if not val.get("$literal"):
                return False
            continue

        actual = doc.get(key)
        if not isinstance(val, dict):
            if actual != val:
                return False
            continue

        for op, expected in val.items():
            if op == "$options":
                # Modifier consumed by $regex below, not an operator of its own.
                continue
            try:
                if op == "$in":
                    matched = actual in expected
                elif op == "$nin":
                    matched = actual not in expected
                elif op == "$eq":
                    matched = actual == expected
                elif op == "$ne":
                    matched = actual != expected
                elif op == "$gte":
                    matched = actual >= expected
                elif op == "$lte":
                    matched = actual <= expected
                elif op == "$gt":
                    matched = actual > expected
                elif op == "$lt":
                    matched = actual < expected
                elif op == "$regex":
                    flags = re.I if "i" in (val.get("$options") or "") else 0
                    # search(), not match(): MongoDB regexes are unanchored
                    # unless the pattern anchors itself.
                    matched = isinstance(actual, str) and re.search(expected, actual, flags) is not None
                else:
                    matched = False
            except TypeError:
                # Incomparable values (e.g. a missing field against a number) do not match.
                matched = False
            if not matched:
                return False
    return True
|
|
289
|
+
|
|
290
|
+
def _execute_from_subquery(self, parts: QueryParts) -> CursorState:
    """Run a SELECT over rows already materialised from a FROM-subquery.

    Rows in ``parts.inline_rows`` are filtered in memory. With
    ``parts.inline_aggregates`` a single aggregate row (count/sum/avg/min/max)
    is returned; otherwise sort, skip, limit and projection are applied.

    A limit of 0 yields no rows, matching the ``is not None`` check used by
    ``_execute_union_all``. Sorting tolerates missing/None values by ordering
    them before every other value when ascending.
    """
    rows = parts.inline_rows or []
    filtered = [row for row in rows if self._match_filter(row, parts.filter)]

    if parts.inline_aggregates:
        agg_result: dict[str, Any] = {}
        for alias, op, field in parts.inline_aggregates:
            if op == "count":
                agg_result[alias] = len(filtered)
            elif op == "sum":
                agg_result[alias] = sum((row.get(field) or 0) for row in filtered)
            elif op in ("avg", "min", "max"):
                # These three ignore rows where the field is missing or None.
                vals = [row.get(field) for row in filtered if row.get(field) is not None]
                if not vals:
                    agg_result[alias] = None
                elif op == "avg":
                    agg_result[alias] = sum(vals) / len(vals)
                elif op == "min":
                    agg_result[alias] = min(vals)
                else:
                    agg_result[alias] = max(vals)
        aliases = [alias for alias, _, _ in parts.inline_aggregates]
        result_rows = [tuple(agg_result.get(alias) for alias in aliases)]
        description = [(alias, None, None, None, None, None, None) for alias in aliases]
        return CursorState(rows=result_rows, rowcount=len(result_rows), description=description)

    if parts.sort:
        # Stable sorts applied from the least to the most significant key.
        for field, direction in reversed(parts.sort):
            filtered.sort(
                # The leading flag keeps None from being compared with real values.
                key=lambda row, f=field: (row.get(f) is not None, row.get(f)),
                reverse=direction == -1,
            )
    if parts.skip:
        filtered = filtered[parts.skip :]
    if parts.limit is not None:
        filtered = filtered[: parts.limit]

    columns = parts.projection or (sorted(filtered[0].keys()) if filtered else [])
    result_rows = [tuple(_convert_value(row.get(col)) for col in columns) for row in filtered]
    description = [(col, None, None, None, None, None, None) for col in columns] if columns else None
    return CursorState(rows=result_rows, rowcount=len(result_rows), description=description)
|
|
323
|
+
|
|
324
|
+
def _execute_find(self, parts: QueryParts) -> CursorState:
    """Run a ``find`` on ``parts.collection`` and shape the result as rows.

    Columns come from ``parts.projection_paths`` (dotted path, alias) or
    ``parts.projection``; without either, the sorted keys of the first
    returned document are used. A limit of 0 returns no rows: pymongo treats
    ``limit(0)`` as "no limit", so that case is handled before fetching.
    """

    def _get_path(doc: dict, path: str) -> Any:
        """Follow a dotted path through nested dicts; None once a level is not a dict."""
        current = doc
        for seg in path.split("."):
            current = current.get(seg) if isinstance(current, dict) else None
        return _convert_value(current)

    proj = None
    columns: list[str] | None = None
    if parts.projection_paths:
        proj = {path: 1 for path, _ in parts.projection_paths}
        columns = [alias for _, alias in parts.projection_paths]
    elif parts.projection:
        proj = {field: 1 for field in parts.projection}
        columns = parts.projection
    logger.debug(
        "Executing find / find 実行: collection=%s filter=%s projection=%s sort=%s limit=%s",
        parts.collection,
        parts.filter,
        proj,
        parts.sort,
        parts.limit,
    )
    cursor = self._db[parts.collection].find(parts.filter or {}, projection=proj, session=self._session)
    if parts.sort:
        cursor = cursor.sort(parts.sort)
    if parts.skip:
        cursor = cursor.skip(parts.skip)
    if parts.limit:
        cursor = cursor.limit(parts.limit)
    # An explicit LIMIT 0 means "no rows"; the cursor is simply never iterated.
    docs = [] if parts.limit == 0 else list(cursor)
    if columns is None and docs:
        columns = sorted(docs[0].keys())

    rows = []
    for doc in docs:
        if parts.projection_paths:
            row = tuple(_get_path(doc, path) for path, _ in parts.projection_paths)
        else:
            row = tuple(_convert_value(doc.get(col)) for col in columns or [])
        rows.append(row)

    description = None
    if columns:
        description = [(col, None, None, None, None, None, None) for col in columns]
    return CursorState(rows=rows, rowcount=len(rows), description=description)
|
|
371
|
+
|
|
372
|
+
def _execute_insert(self, parts: QueryParts) -> CursorState:
    """Insert one document built from ``parts.values``.

    Every value goes through ``_convert_value`` before it is stored. The
    returned state has rowcount 1 and the converted inserted id as lastrowid.
    """
    logger.debug("Executing insert_one / insert_one 実行: collection=%s values=%s", parts.collection, parts.values)
    document = {}
    for field, raw in (parts.values or {}).items():
        document[field] = _convert_value(raw)
    collection = self._db[parts.collection]
    outcome = collection.insert_one(document, session=self._session)
    inserted_id = _convert_value(outcome.inserted_id)
    return CursorState(rows=[], rowcount=1, lastrowid=inserted_id)
|
|
377
|
+
|
|
378
|
+
def _execute_update(self, parts: QueryParts) -> CursorState:
    """Apply ``parts.update`` to every document matching ``parts.filter``.

    The returned rowcount is the number of documents actually modified.
    """
    logger.debug(
        "Executing update_many / update_many 実行: collection=%s filter=%s update=%s",
        parts.collection,
        parts.filter,
        parts.update,
    )
    collection = self._db[parts.collection]
    criteria = parts.filter or {}
    changes = parts.update or {}
    outcome = collection.update_many(criteria, changes, session=self._session)
    return CursorState(rows=[], rowcount=outcome.modified_count)
|
|
387
|
+
|
|
388
|
+
def _execute_delete(self, parts: QueryParts) -> CursorState:
    """Delete every document matching ``parts.filter``.

    The returned rowcount is the number of deleted documents.
    """
    logger.debug(
        "Executing delete_many / delete_many 実行: collection=%s filter=%s",
        parts.collection,
        parts.filter,
    )
    collection = self._db[parts.collection]
    criteria = parts.filter or {}
    outcome = collection.delete_many(criteria, session=self._session)
    return CursorState(rows=[], rowcount=outcome.deleted_count)
|
|
396
|
+
|
|
397
|
+
def _execute_create_index(self, parts: QueryParts) -> CursorState:
    """Create an index on ``parts.collection`` on a best-effort basis.

    Failures never propagate, but they are logged as warnings instead of
    being silently discarded. Always returns an empty state with rowcount 0.
    """
    logger.debug(
        "Executing create_index / インデックス作成: collection=%s name=%s keys=%s unique=%s",
        parts.collection,
        parts.index_name,
        parts.index_keys,
        parts.unique,
    )
    try:
        self._db[parts.collection].create_index(parts.index_keys or [], name=parts.index_name, unique=parts.unique)
    except Exception as exc:
        # Still best-effort, but the failure has to be visible somewhere.
        logger.warning(
            "create_index failed and was ignored: collection=%s name=%s error=%s",
            parts.collection,
            parts.index_name,
            exc,
        )
    return CursorState(rows=[], rowcount=0)
|
|
410
|
+
|
|
411
|
+
def _execute_drop_index(self, parts: QueryParts) -> CursorState:
    """Drop the index ``parts.index_name`` on a best-effort basis.

    Failures never propagate, but they are logged as warnings instead of
    being silently discarded. Always returns an empty state with rowcount 0.
    """
    logger.debug("Executing drop_index / インデックス削除: collection=%s name=%s", parts.collection, parts.index_name)
    try:
        self._db[parts.collection].drop_index(parts.index_name)
    except Exception as exc:
        # Still best-effort, but the failure has to be visible somewhere.
        logger.warning(
            "drop_index failed and was ignored: collection=%s name=%s error=%s",
            parts.collection,
            parts.index_name,
            exc,
        )
    return CursorState(rows=[], rowcount=0)
|
|
418
|
+
|
|
419
|
+
def _execute_union_all(self, parts: QueryParts) -> CursorState:
    """Concatenate the rows of every sub-select, then sort and limit them.

    The description of the first sub-select that provides one is used for
    the combined result. ``parts.sort`` entries are ``(field, direction)``
    pairs; each field is resolved to its column position through that
    description and sorted in its requested direction. If a field cannot be
    resolved, the rows are ordered ascending by their leading columns.
    """
    rows: list[tuple] = []
    description = None
    for sub in parts.union_parts or []:
        state = self._execute_parts(sub)
        if description is None:
            description = state.description
        rows.extend(state.rows or [])

    if parts.sort:
        names = [col[0] for col in description or []]
        try:
            sort_keys = [(names.index(field), direction) for field, direction in parts.sort]
        except (ValueError, TypeError):
            sort_keys = None
        if sort_keys is None:
            # Fallback: order by the first len(sort) columns, ascending.
            rows.sort(key=lambda row: tuple(row[0 : len(parts.sort)]), reverse=False)
        else:
            # Stable sorts applied from the least to the most significant key.
            for position, direction in reversed(sort_keys):
                rows.sort(
                    # The leading flag keeps None from being compared with real values.
                    key=lambda row, i=position: (row[i] is not None, row[i]),
                    reverse=direction == -1,
                )
    if parts.limit is not None:
        rows = rows[: parts.limit]
    return CursorState(rows=rows, rowcount=len(rows), description=description)
|
|
432
|
+
|
|
433
|
+
def _execute_aggregate(self, parts: QueryParts) -> CursorState:
    """Run an aggregation pipeline and flatten the documents into rows.

    With ``parts.projection_paths`` each (dotted path, alias) pair yields one
    column. Otherwise the columns are the sorted non-``__join`` keys of the
    first document, followed by ``<joinkey>.<field>`` columns for every
    ``__join*`` sub-document. The join column list is fixed once and reused
    for every row, so all rows have the same width as ``description``.
    """
    logger.debug("Executing aggregate / aggregate 実行: collection=%s pipeline=%s", parts.collection, parts.pipeline)
    cursor = self._db[parts.collection].aggregate(parts.pipeline or [], session=self._session)
    docs = list(cursor)
    projection_paths = parts.projection_paths
    rows: list[tuple] = []

    def _get_path(doc: dict, path: str) -> Any:
        """Follow a dotted path through nested dicts; None once a level is not a dict."""
        current = doc
        for seg in path.split("."):
            current = current.get(seg) if isinstance(current, dict) else None
        return _convert_value(current)

    if projection_paths:
        columns = [out for _, out in projection_paths]
        rows = [tuple(_get_path(doc, path) for path, _ in projection_paths) for doc in docs]
        description = [(col, None, None, None, None, None, None) for col in columns]
        return CursorState(rows=rows, rowcount=len(rows), description=description)

    if not docs:
        return CursorState(rows=rows, rowcount=len(rows), description=None)

    first = docs[0]
    join_keys = sorted(key for key in first.keys() if key.startswith("__join"))
    columns_left = sorted(key for key in first.keys() if not key.startswith("__join"))

    # (join key, field) pairs, taken from the first document whose join value
    # is a dict; the first document itself may have no match for that join.
    join_fields: list[tuple[str, str]] = []
    for jk in join_keys:
        sample = next((doc[jk] for doc in docs if isinstance(doc.get(jk), dict)), None)
        if sample is not None:
            join_fields.extend((jk, name) for name in sorted(sample.keys()))
    columns = columns_left + [f"{jk}.{name}" for jk, name in join_fields]

    for doc in docs:
        values = [_convert_value(doc.get(key)) for key in columns_left]
        for jk, name in join_fields:
            joined = doc.get(jk)
            # A row without this joined sub-document gets None in that column.
            values.append(_convert_value(joined.get(name)) if isinstance(joined, dict) else None)
        rows.append(tuple(values))

    description = [(col, None, None, None, None, None, None) for col in columns] if columns else None
    return CursorState(rows=rows, rowcount=len(rows), description=description)
|
|
476
|
+
|
|
477
|
+
def _execute_create(self, parts: QueryParts) -> CursorState:
    """Create the collection ``parts.collection`` on a best-effort basis.

    Failures never propagate, but they are logged as warnings instead of
    being silently discarded. Always returns an empty state with rowcount 0.
    """
    logger.debug("Executing create_collection / コレクション作成: %s", parts.collection)
    try:
        self._db.create_collection(parts.collection)
    except Exception as exc:
        # Still best-effort, but the failure has to be visible somewhere.
        logger.warning("create_collection failed and was ignored: collection=%s error=%s", parts.collection, exc)
    return CursorState(rows=[], rowcount=0)
|
|
484
|
+
|
|
485
|
+
def _execute_drop(self, parts: QueryParts) -> CursorState:
    """Drop the collection ``parts.collection``; returns an empty state with rowcount 0."""
    target = parts.collection
    logger.debug("Executing drop_collection / コレクション削除: %s", target)
    self._db.drop_collection(target)
    return CursorState(rows=[], rowcount=0)
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def connect(uri: str, db_name: str, **_: Any) -> Connection:
    """DBAPI entry point: open a Connection to ``db_name`` on ``uri``.

    Any extra keyword arguments are accepted and ignored.
    """
    connection = Connection(uri, db_name)
    return connection
|
mongo_dbapi/errors.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
class MongoDbApiError(Exception):
|
|
2
|
+
"""Mongo DBAPI error with Error ID prefix / Mongo DBAPI エラー(Error ID 付き)"""
|
|
3
|
+
|
|
4
|
+
def __init__(self, code: str, message: str, cause: Exception | None = None):
|
|
5
|
+
full = f"{code} {message}"
|
|
6
|
+
super().__init__(full)
|
|
7
|
+
self.code = code
|
|
8
|
+
self.message = message
|
|
9
|
+
if cause:
|
|
10
|
+
self.__cause__ = cause
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
ERROR_MESSAGES: dict[str, str] = {
|
|
14
|
+
"[mdb][E1]": "Invalid connection URI",
|
|
15
|
+
"[mdb][E2]": "Unsupported SQL construct: <keyword>",
|
|
16
|
+
"[mdb][E3]": "Unsafe operation without WHERE",
|
|
17
|
+
"[mdb][E4]": "Parameter count mismatch",
|
|
18
|
+
"[mdb][E5]": "Failed to parse SQL",
|
|
19
|
+
"[mdb][E6]": "Transactions not supported on this server",
|
|
20
|
+
"[mdb][E7]": "Connection failed",
|
|
21
|
+
"[mdb][E8]": "Authentication failed",
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def raise_error(code: str, message: str | None = None, cause: Exception | None = None) -> None:
|
|
26
|
+
"""Raise MongoDbApiError with code / コード付き例外を送出"""
|
|
27
|
+
msg = message or ERROR_MESSAGES.get(code, "")
|
|
28
|
+
raise MongoDbApiError(code, msg, cause)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Tuple
|
|
4
|
+
|
|
5
|
+
from sqlalchemy.engine import default, url
|
|
6
|
+
from sqlalchemy import pool
|
|
7
|
+
|
|
8
|
+
import mongo_dbapi
|
|
9
|
+
|
|
10
|
+
# Register dialect entry point style
|
|
11
|
+
from sqlalchemy.dialects import registry # noqa: E402
|
|
12
|
+
|
|
13
|
+
registry.register("mongodb+dbapi", "mongo_dbapi.sqlalchemy_dialect", "MongoDBAPIDialect")
|
|
14
|
+
registry.register("mongodb.dbapi", "mongo_dbapi.sqlalchemy_dialect", "MongoDBAPIDialect")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MongoDBAPIDialect(default.DefaultDialect):
    """SQLAlchemy dialect that routes statements through the ``mongo_dbapi`` driver."""

    name = "mongodb+dbapi"
    driver = "mongo-dbapi"
    paramstyle = "pyformat"
    default_paramstyle = "pyformat"
    supports_native_boolean = True
    supports_sane_rowcount = False
    supports_native_decimal = False
    poolclass = pool.SingletonThreadPool
    supports_statement_cache = False
    requires_name_normalize = True

    @classmethod
    def dbapi(cls):
        """Return the DBAPI module used by this dialect."""
        return mongo_dbapi

    @classmethod
    def import_dbapi(cls):
        """Return the DBAPI module used by this dialect."""
        return mongo_dbapi

    def has_table(self, connection, table_name, schema=None, **kw):
        """Report whether a collection named ``table_name`` exists.

        ``schema`` and extra keyword arguments are accepted but not used.
        """
        db = connection.connection._db  # pymongo database
        return table_name in db.list_collection_names()

    def get_driver_connection(self, connection):
        """Return the ``connection`` attribute of the given connection object."""
        return connection.connection

    def create_connect_args(self, url_obj: url.URL) -> Tuple[tuple, Dict[str, Any]]:
        """Translate a SQLAlchemy URL into keyword arguments for ``connect``.

        Builds ``mongodb://[user:password@]host[:port]``; the host defaults to
        127.0.0.1 and the database name to an empty string. User name and
        password are percent-escaped so characters such as ``@``, ``:`` or
        ``/`` cannot break the URI.
        """
        from urllib.parse import quote_plus

        host = url_obj.host or "127.0.0.1"
        location = f"{host}:{url_obj.port}" if url_obj.port else host
        credentials = ""
        if url_obj.username:
            user = quote_plus(str(url_obj.username))
            secret = quote_plus(str(url_obj.password or ""))
            credentials = f"{user}:{secret}@"
        uri = f"mongodb://{credentials}{location}"
        db_name = url_obj.database or ""
        return (), {"uri": uri, "db_name": db_name}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def dialect(**kwargs):
    """Build a MongoDBAPIDialect, forwarding all keyword arguments to its constructor."""
    instance = MongoDBAPIDialect(**kwargs)
    return instance
|