dbapi-mongodb 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mongo_dbapi/dbapi.py ADDED
@@ -0,0 +1,493 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass, replace
5
+ from typing import Any, Iterable, List, Mapping, Sequence
6
+ import base64
7
+ import decimal
8
+ import uuid
9
+ import datetime
10
+
11
+ from bson import ObjectId
12
+ from pymongo import MongoClient
13
+ from pymongo.errors import ConnectionFailure, OperationFailure
14
+
15
+ from .errors import raise_error
16
+ from .translation import QueryParts, parse_sql
17
+
18
# Package-wide logger; statement execution in this module is traced at DEBUG level.
logger: logging.Logger = logging.getLogger("mongo_dbapi")
19
+
20
+
21
+ def _convert_value(value: Any) -> Any:
22
+ """Convert Mongo value to Python-friendly value / Mongo の値を Python 向けに変換"""
23
+ if isinstance(value, ObjectId):
24
+ return str(value)
25
+ if isinstance(value, bytes):
26
+ return base64.b64encode(value).decode("ascii")
27
+ if isinstance(value, decimal.Decimal):
28
+ return str(value)
29
+ if isinstance(value, uuid.UUID):
30
+ return str(value)
31
+ if isinstance(value, dict) and "$binary" in value:
32
+ return value # leave as-is for now
33
+ return value
34
+
35
+
36
@dataclass
class CursorState:
    """Outcome of one executed statement, as read back through ``Cursor``.

    The defaults describe "no result yet": no rows, an unknown row count,
    no last row id and no column description.
    """

    # Buffered result rows; each tuple is ordered like ``description``.
    rows: List[tuple] | None = None
    # Number of rows returned or affected; -1 means "not known".
    rowcount: int = -1
    # Id of the last inserted document, when the statement inserted one.
    lastrowid: Any | None = None
    # DB-API style 7-item column descriptors, or None when there are no columns.
    description: List[tuple] | None = None
42
+
43
+
44
class Cursor:
    """DB-API style cursor.

    ``execute``/``executemany`` parse SQL with ``parse_sql`` and run it through the
    owning :class:`Connection`; result rows are buffered in memory and handed out
    by the ``fetch*`` methods.
    """

    # Default batch size for ``fetchmany`` (DB-API 2.0 ``arraysize``).
    arraysize: int = 1

    def __init__(self, connection: "Connection"):
        self.connection = connection
        self._state = CursorState()
        # Index of the next unread row in ``self._state.rows``. Advancing an index
        # keeps each fetch O(1) instead of shifting the list with ``pop(0)``.
        self._pos = 0
        self._closed = False

    def _ensure_open(self) -> None:
        """Raise ``MongoDbApiError`` (``[mdb][E5]``) if the cursor has been closed."""
        if self._closed:
            raise_error("[mdb][E5]", "Cursor is closed")

    def _load(self, state: CursorState) -> None:
        """Install a new result and rewind the read position."""
        self._state = state
        self._pos = 0

    def execute(self, sql: str, params: Sequence | Mapping | None = None) -> "Cursor":
        """Parse and run one statement; returns ``self`` so calls can be chained."""
        self._ensure_open()
        parts = parse_sql(sql, params)
        self._load(self.connection._execute_parts(parts))  # noqa: SLF001
        return self

    def executemany(self, sql: str, seq_of_params: Sequence[Sequence | Mapping]) -> "Cursor":
        """Run ``sql`` once per parameter set.

        Afterwards ``rowcount`` is the sum of the individual row counts,
        ``lastrowid`` is the value reported by the last execution, and no rows
        are available for fetching.
        """
        self._ensure_open()
        total_rows = 0
        lastrowid = None
        for params in seq_of_params:
            state = self.connection._execute_parts(parse_sql(sql, params))  # noqa: SLF001
            total_rows += state.rowcount
            lastrowid = state.lastrowid
        self._load(CursorState(rows=[], rowcount=total_rows, lastrowid=lastrowid, description=None))
        return self

    def fetchone(self) -> tuple | None:
        """Return the next row, or ``None`` when the result is exhausted."""
        rows = self._state.rows
        if not rows or self._pos >= len(rows):
            return None
        row = rows[self._pos]
        self._pos += 1
        return row

    def fetchmany(self, size: int | None = None) -> List[tuple]:
        """Return up to ``size`` further rows (default: ``self.arraysize``)."""
        if size is None:
            size = self.arraysize
        rows = self._state.rows or []
        chunk = rows[self._pos : self._pos + max(size, 0)]
        self._pos += len(chunk)
        return chunk

    def fetchall(self) -> List[tuple]:
        """Return all remaining rows and release the buffered result."""
        rows = self._state.rows or []
        remaining = rows[self._pos :] if self._pos else rows
        self._state.rows = []
        self._pos = 0
        return remaining

    @property
    def rowcount(self) -> int:
        """Rows returned/affected by the last operation (-1 if unknown)."""
        return self._state.rowcount

    @property
    def lastrowid(self) -> Any:
        """Id reported by the last insert, if any."""
        return self._state.lastrowid

    @property
    def description(self) -> List[tuple] | None:
        """DB-API column descriptors of the current result, or ``None``."""
        return self._state.description

    def setinputsizes(self, sizes: Any) -> None:
        """DB-API 2.0 hook; accepted and ignored."""

    def setoutputsize(self, size: Any, column: Any = None) -> None:
        """DB-API 2.0 hook; accepted and ignored."""

    def close(self) -> None:
        """Mark the cursor closed; later ``execute`` calls raise ``[mdb][E5]``."""
        self._closed = True
96
+
97
+
98
class Connection:
    """DB-API style connection backed by a :class:`pymongo.MongoClient`.

    Cursors parse SQL into ``QueryParts`` and pass them to ``_execute_parts``,
    which dispatches on ``parts.operation`` and returns a :class:`CursorState`.
    """

    def __init__(self, uri: str, db_name: str):
        """Create the client for ``uri`` and select database ``db_name``.

        Raises:
            MongoDbApiError: ``[mdb][E1]`` if ``uri`` is empty; ``[mdb][E7]`` or
                ``[mdb][E8]`` if pymongo raises ``ConnectionFailure`` or
                ``OperationFailure`` while the client is being created.
        """
        if not uri:
            raise_error("[mdb][E1]")
        self._uri = uri
        self._db_name = db_name
        try:
            self._client = MongoClient(uri)
            self._db = self._client[db_name]
        except ConnectionFailure as exc:
            raise_error("[mdb][E7]", cause=exc)
        except OperationFailure as exc:
            raise_error("[mdb][E8]", cause=exc)
        self._session = None
        # Ask the server for its version once and derive both capability flags
        # from that answer; each probe may block for server selection.
        version = self._server_version()
        self._transactions_supported = self._version_at_least(version, (4, 0))
        self._window_supported = self._version_at_least(version, (5, 0))

    # ------------------------------------------------------------------ #
    # Server capability detection
    # ------------------------------------------------------------------ #

    @staticmethod
    def _parse_version(version_str: Any) -> tuple[int, int] | None:
        """Parse ``"MAJOR.MINOR[...]"`` into ``(major, minor)``.

        A missing minor component (``"5"``) counts as ``0``; anything that does
        not start with digits yields ``None``.
        """
        import re

        match = re.match(r"(\d+)(?:\.(\d+))?", str(version_str))
        if match is None:
            return None
        return int(match.group(1)), int(match.group(2) or 0)

    @staticmethod
    def _version_at_least(version: tuple[int, int] | None, minimum: tuple[int, int]) -> bool:
        """True when ``version`` is known and is ``>= minimum``."""
        return version is not None and version >= minimum

    def _server_version(self) -> tuple[int, int] | None:
        """Return the server's ``(major, minor)`` version, or ``None`` if unknown."""
        try:
            info = self._client.server_info()
        except Exception:  # best-effort probe: any failure means "unknown"
            return None
        return self._parse_version(info.get("version", "0.0"))

    def _detect_transactions(self) -> bool:
        """Whether the server version (>= 4.0) allows multi-document transactions.

        NOTE(review): MongoDB additionally requires a replica set or sharded
        cluster for transactions; this check only looks at the version number.
        """
        return self._version_at_least(self._server_version(), (4, 0))

    def _detect_window(self) -> bool:
        """Whether the server version (>= 5.0) allows window-function queries."""
        return self._version_at_least(self._server_version(), (5, 0))

    # ------------------------------------------------------------------ #
    # DB-API surface
    # ------------------------------------------------------------------ #

    def cursor(self) -> Cursor:
        """Return a new cursor bound to this connection."""
        return Cursor(self)

    def begin(self) -> None:
        """Start a transaction, or do nothing if the server does not support them.

        The session is stored only after ``start_transaction`` succeeds, so a
        failure here does not leave a half-initialised session attached.
        """
        if not self._transactions_supported:
            logger.debug("Transaction not supported; no-op / トランザクション非対応のため no-op")
            return
        logger.debug("Starting transaction session / トランザクションセッション開始")
        session = self._client.start_session()
        try:
            session.start_transaction()
        except Exception:
            session.end_session()
            raise
        self._session = session

    def commit(self) -> None:
        """Commit the current transaction, if any, and release its session.

        The session is ended and detached even when the commit raises, so later
        statements never run inside a dead transaction.
        """
        session = self._session
        if session is None:
            return
        logger.debug("Commit transaction / トランザクションコミット")
        try:
            session.commit_transaction()
        finally:
            session.end_session()
            self._session = None

    def rollback(self) -> None:
        """Abort the current transaction, if any, and release its session."""
        session = self._session
        if session is None:
            return
        logger.debug("Abort transaction / トランザクションアボート")
        try:
            session.abort_transaction()
        finally:
            session.end_session()
            self._session = None

    def close(self) -> None:
        """Discard any open transaction and close the underlying client."""
        if self._session is not None:
            try:
                self.rollback()
            except Exception:
                # Closing must still release the client even if the abort failed.
                logger.debug("Rollback during close failed / close 時のロールバック失敗", exc_info=True)
        self._client.close()

    def list_tables(self) -> list[str]:
        """Return the collection names of the selected database."""
        return self._db.list_collection_names()

    # ------------------------------------------------------------------ #
    # Dispatch
    # ------------------------------------------------------------------ #

    def _execute_parts(self, parts: QueryParts) -> CursorState:
        """Execute a parsed statement and return its result state.

        Subqueries are materialised first. Raises ``MongoDbApiError``
        (``[mdb][E2]``) for window functions on servers older than 5.0 and for
        operations without a handler.
        """
        if parts.subqueries:
            parts = self._materialize_subqueries(parts)
        if parts.uses_window and not self._window_supported:
            raise_error("[mdb][E2]", "Unsupported SQL construct: WINDOW_FUNCTION")
        handlers = {
            "from_subquery": self._execute_from_subquery,
            "find": self._execute_find,
            "insert": self._execute_insert,
            "update": self._execute_update,
            "delete": self._execute_delete,
            "aggregate": self._execute_aggregate,
            "create": self._execute_create,
            "drop": self._execute_drop,
            "create_index": self._execute_create_index,
            "drop_index": self._execute_drop_index,
            "union_all": self._execute_union_all,
        }
        handler = handlers.get(parts.operation)
        if handler is None:
            # raise_error always raises; name the operation instead of "<keyword>".
            raise_error("[mdb][E2]", f"Unsupported SQL construct: {parts.operation}")
        return handler(parts)

    def _materialize_subqueries(self, parts: QueryParts) -> QueryParts:
        """Execute every subquery and substitute its result for its placeholder token.

        Per subquery ``mode``: ``"values"`` -> first column of each row,
        ``"exists"`` -> whether any row came back, ``"from"`` -> rows as dicts keyed
        by column name, anything else -> the raw row tuples.
        """
        resolved: dict[str, Any] = {}
        for token, spec in (parts.subqueries or {}).items():
            sub_parts: QueryParts = spec["parts"]
            mode = spec.get("mode")
            state = self._execute_parts(sub_parts)
            rows = state.rows or []
            if mode == "values":
                resolved[token] = [row[0] for row in rows]
            elif mode == "exists":
                resolved[token] = bool(rows)
            elif mode == "from":
                cols = [c[0] for c in state.description or []]
                resolved[token] = [dict(zip(cols, row)) for row in rows]
            else:
                resolved[token] = rows

        def _replace(obj: Any) -> Any:
            # Recursively swap any string equal to a resolved token for its value.
            if isinstance(obj, str) and obj in resolved:
                return resolved[obj]
            if isinstance(obj, list):
                return [_replace(v) for v in obj]
            if isinstance(obj, dict):
                return {k: _replace(v) for k, v in obj.items()}
            return obj

        return replace(
            parts,
            filter=_replace(parts.filter),
            pipeline=_replace(parts.pipeline),
            values=_replace(parts.values),
            update=_replace(parts.update),
            subqueries=None,
            inline_token=None if (parts.inline_token and parts.inline_token in resolved) else parts.inline_token,
            collection=_replace(parts.collection) if isinstance(parts.collection, dict) else parts.collection,
            inline_rows=_replace(resolved.get(parts.inline_token)) if parts.inline_token else None,
        )

    # ------------------------------------------------------------------ #
    # In-memory helpers (used for FROM-subqueries and result shaping)
    # ------------------------------------------------------------------ #

    @staticmethod
    def _regex_flags(options: Any) -> int:
        """Map Mongo ``$options`` letters (``i``, ``m``, ``s``, ``x``) to ``re`` flags."""
        import re

        table = {"i": re.IGNORECASE, "m": re.MULTILINE, "s": re.DOTALL, "x": re.VERBOSE}
        flags = 0
        for letter in str(options or ""):
            flags |= table.get(letter, 0)
        return flags

    @classmethod
    def _match_operators(cls, actual: Any, ops: dict) -> bool:
        """Check one field value against an operator document such as ``{"$gt": 1}``.

        Every operator must hold. Ordering operators return False (instead of
        raising) when the values cannot be compared, e.g. ``None`` vs ``int``.
        ``$regex`` uses ``re.search`` with flags from a sibling ``$options`` key.
        Unknown operators never match.
        """
        import operator
        import re

        ordering = {"$gt": operator.gt, "$gte": operator.ge, "$lt": operator.lt, "$lte": operator.le}
        for op, expected in ops.items():
            if op in ordering:
                try:
                    if not ordering[op](actual, expected):
                        return False
                except TypeError:
                    return False
            elif op == "$in":
                if actual not in expected:
                    return False
            elif op == "$nin":
                if actual in expected:
                    return False
            elif op == "$eq":
                if actual != expected:
                    return False
            elif op == "$ne":
                if actual == expected:
                    return False
            elif op == "$regex":
                if not isinstance(actual, str):
                    return False
                if re.search(expected, actual, cls._regex_flags(ops.get("$options"))) is None:
                    return False
            elif op != "$options":  # "$options" only modifies "$regex"
                return False
        return True

    def _match_filter(self, doc: dict, flt: Any) -> bool:
        """Evaluate a Mongo-style filter against one in-memory document.

        ``None`` matches everything; a non-dict filter matches nothing. Supports
        ``$and``, ``$or``, an ``$expr`` already reduced to ``{"$literal": x}``,
        plain equality and the operators handled by ``_match_operators``. All
        top-level keys are ANDed together.
        """
        if flt is None:
            return True
        if not isinstance(flt, dict):
            return False
        for key, val in flt.items():
            if key == "$and":
                if not all(self._match_filter(doc, f) for f in val):
                    return False
            elif key == "$or":
                if not any(self._match_filter(doc, f) for f in val):
                    return False
            elif key == "$expr":
                # Already reduced to a literal truthy/falsy value; it must not
                # short-circuit the remaining keys of the filter.
                if not bool(val.get("$literal")):
                    return False
            elif isinstance(val, dict):
                if not self._match_operators(doc.get(key), val):
                    return False
            elif doc.get(key) != val:
                return False
        return True

    @staticmethod
    def _null_first_key(value: Any) -> tuple[bool, Any]:
        """Sort key that orders ``None`` before all other values instead of raising."""
        return (value is not None, value)

    @staticmethod
    def _get_path(doc: Any, path: str) -> Any:
        """Follow a dotted ``path`` through nested dicts and convert the leaf value.

        Yields the converted ``None`` when a segment is missing or a non-dict is
        reached before the end of the path.
        """
        current = doc
        for seg in path.split("."):
            current = current.get(seg) if isinstance(current, dict) else None
        return _convert_value(current)

    @staticmethod
    def _inline_aggregate(op: str, field: Any, rows: list) -> Any:
        """Compute one aggregate (count/sum/avg/min/max) over in-memory row dicts.

        ``sum`` treats missing/``None`` as 0; ``avg``/``min``/``max`` skip them and
        return ``None`` when nothing is left. Unknown operators give ``None``.
        """
        if op == "count":
            return len(rows)
        if op == "sum":
            return sum((r.get(field) or 0) for r in rows)
        vals = [r.get(field) for r in rows if r.get(field) is not None]
        if op == "avg":
            return (sum(vals) / len(vals)) if vals else None
        if op == "min":
            return min(vals) if vals else None
        if op == "max":
            return max(vals) if vals else None
        return None

    # ------------------------------------------------------------------ #
    # Statement executors
    # ------------------------------------------------------------------ #

    def _execute_from_subquery(self, parts: QueryParts) -> CursorState:
        """Evaluate a query over the rows produced by a FROM-subquery, in memory.

        ``parts.inline_rows`` are filtered with ``parts.filter``; then either the
        inline aggregates are computed (a single result row) or sort / skip /
        limit and the projection (default: sorted keys of the first row) apply.
        """
        filtered = [r for r in (parts.inline_rows or []) if self._match_filter(r, parts.filter)]
        if parts.inline_aggregates:
            agg_result: dict[str, Any] = {}
            for alias, op, field in parts.inline_aggregates:
                agg_result[alias] = self._inline_aggregate(op, field, filtered)
            aliases = [alias for alias, _, _ in parts.inline_aggregates]
            result_rows = [tuple(agg_result.get(alias) for alias in aliases)]
            description = [(alias, None, None, None, None, None, None) for alias in aliases]
            return CursorState(rows=result_rows, rowcount=len(result_rows), description=description)
        if parts.sort:
            # Stable sorts applied from the last key to the first give a multi-key
            # ordering; None sorts before other values rather than raising.
            for field, direction in reversed(parts.sort):
                filtered.sort(key=lambda r, f=field: self._null_first_key(r.get(f)), reverse=direction == -1)
        if parts.skip:
            filtered = filtered[parts.skip :]
        if parts.limit:
            filtered = filtered[: parts.limit]
        columns = parts.projection or (sorted(filtered[0].keys()) if filtered else [])
        result_rows = [tuple(_convert_value(r.get(c)) for c in columns) for r in filtered]
        description = [(c, None, None, None, None, None, None) for c in columns] if columns else None
        return CursorState(rows=result_rows, rowcount=len(result_rows), description=description)

    def _execute_find(self, parts: QueryParts) -> CursorState:
        """Run ``find`` with filter / projection / sort / skip / limit and build rows.

        Columns are the aliases of ``projection_paths`` (dotted paths into nested
        documents), else the plain ``projection``, else the sorted field names of
        the first returned document.
        """
        proj = None
        columns: list[str] | None = None
        if parts.projection_paths:
            proj = {path: 1 for path, _ in parts.projection_paths}
            columns = [alias for _, alias in parts.projection_paths]
        elif parts.projection:
            proj = {field: 1 for field in parts.projection}
            columns = parts.projection
        logger.debug(
            "Executing find / find 実行: collection=%s filter=%s projection=%s sort=%s limit=%s",
            parts.collection,
            parts.filter,
            proj,
            parts.sort,
            parts.limit,
        )
        cursor = self._db[parts.collection].find(parts.filter or {}, projection=proj, session=self._session)
        if parts.sort:
            cursor = cursor.sort(parts.sort)
        if parts.skip:
            cursor = cursor.skip(parts.skip)
        if parts.limit:
            cursor = cursor.limit(parts.limit)
        docs = list(cursor)
        if columns is None and docs:
            columns = sorted(docs[0].keys())
        if parts.projection_paths:
            rows = [tuple(self._get_path(doc, path) for path, _ in parts.projection_paths) for doc in docs]
        else:
            cols = columns or []
            rows = [tuple(_convert_value(doc.get(col)) for col in cols) for doc in docs]
        description = [(col, None, None, None, None, None, None) for col in columns] if columns else None
        return CursorState(rows=rows, rowcount=len(rows), description=description)

    def _execute_insert(self, parts: QueryParts) -> CursorState:
        """Insert one document built from ``parts.values`` (values passed through
        ``_convert_value``); ``lastrowid`` is the converted inserted id."""
        logger.debug("Executing insert_one / insert_one 実行: collection=%s values=%s", parts.collection, parts.values)
        doc = {k: _convert_value(v) for k, v in (parts.values or {}).items()}
        result = self._db[parts.collection].insert_one(doc, session=self._session)
        return CursorState(rows=[], rowcount=1, lastrowid=_convert_value(result.inserted_id))

    def _execute_update(self, parts: QueryParts) -> CursorState:
        """Run ``update_many``; ``rowcount`` is the number of modified documents."""
        logger.debug(
            "Executing update_many / update_many 実行: collection=%s filter=%s update=%s",
            parts.collection,
            parts.filter,
            parts.update,
        )
        result = self._db[parts.collection].update_many(parts.filter or {}, parts.update or {}, session=self._session)
        return CursorState(rows=[], rowcount=result.modified_count)

    def _execute_delete(self, parts: QueryParts) -> CursorState:
        """Run ``delete_many``; ``rowcount`` is the number of deleted documents."""
        logger.debug(
            "Executing delete_many / delete_many 実行: collection=%s filter=%s",
            parts.collection,
            parts.filter,
        )
        result = self._db[parts.collection].delete_many(parts.filter or {}, session=self._session)
        return CursorState(rows=[], rowcount=result.deleted_count)

    def _execute_create_index(self, parts: QueryParts) -> CursorState:
        """Create an index. Best-effort: errors are logged, not raised."""
        logger.debug(
            "Executing create_index / インデックス作成: collection=%s name=%s keys=%s unique=%s",
            parts.collection,
            parts.index_name,
            parts.index_keys,
            parts.unique,
        )
        try:
            self._db[parts.collection].create_index(parts.index_keys or [], name=parts.index_name, unique=parts.unique)
        except Exception:
            # Not propagated, but visible: a failed (e.g. unique) index must not go unnoticed.
            logger.warning(
                "create_index failed / インデックス作成失敗: collection=%s name=%s",
                parts.collection,
                parts.index_name,
                exc_info=True,
            )
        return CursorState(rows=[], rowcount=0)

    def _execute_drop_index(self, parts: QueryParts) -> CursorState:
        """Drop an index by name. Best-effort: errors are logged, not raised."""
        logger.debug("Executing drop_index / インデックス削除: collection=%s name=%s", parts.collection, parts.index_name)
        try:
            self._db[parts.collection].drop_index(parts.index_name)
        except Exception:
            logger.warning(
                "drop_index failed / インデックス削除失敗: collection=%s name=%s",
                parts.collection,
                parts.index_name,
                exc_info=True,
            )
        return CursorState(rows=[], rowcount=0)

    def _execute_union_all(self, parts: QueryParts) -> CursorState:
        """Run each ``UNION ALL`` branch and concatenate the rows.

        The column description comes from the first branch. Each ORDER BY key is
        resolved to a column by name (falling back to the key's position when the
        name is not a column) and honours its direction; OFFSET then LIMIT apply.
        """
        rows: list[tuple] = []
        description = None
        for sub in parts.union_parts or []:
            state = self._execute_parts(sub)
            if description is None:
                description = state.description
            rows.extend(state.rows or [])
        if parts.sort:
            names = [col[0] for col in description or []]
            for pos, spec in reversed(list(enumerate(parts.sort))):
                if isinstance(spec, (tuple, list)) and len(spec) == 2:
                    field, direction = spec
                else:
                    field, direction = spec, 1
                idx = names.index(field) if field in names else pos
                rows.sort(
                    key=lambda r, i=idx: self._null_first_key(r[i] if i < len(r) else None),
                    reverse=direction == -1,
                )
        if parts.skip:
            rows = rows[parts.skip :]
        if parts.limit is not None:
            rows = rows[: parts.limit]
        return CursorState(rows=rows, rowcount=len(rows), description=description)

    def _execute_aggregate(self, parts: QueryParts) -> CursorState:
        """Run ``parts.pipeline`` and flatten the resulting documents into rows.

        With ``projection_paths`` each row follows those ``(path, alias)`` pairs.
        Otherwise the columns are the first document's non-``__join*`` fields
        (sorted), followed by ``<joinkey>.<field>`` for the fields of every
        sub-document stored under a ``__join*`` key.
        """
        logger.debug("Executing aggregate / aggregate 実行: collection=%s pipeline=%s", parts.collection, parts.pipeline)
        docs = list(self._db[parts.collection].aggregate(parts.pipeline or [], session=self._session))

        if parts.projection_paths:
            columns = [out for _, out in parts.projection_paths]
            rows = [tuple(self._get_path(doc, path) for path, _ in parts.projection_paths) for doc in docs]
            description = [(col, None, None, None, None, None, None) for col in columns]
            return CursorState(rows=rows, rowcount=len(rows), description=description)

        if not docs:
            return CursorState(rows=[], rowcount=0, description=None)

        columns_left = sorted(k for k in docs[0] if not k.startswith("__join"))
        # Joined sub-documents may be absent or carry different fields from one
        # document to the next, so build a single join layout from all documents
        # and read every row through it; rows then always line up with columns.
        join_fields: dict[str, set] = {}
        for doc in docs:
            for key, value in doc.items():
                if key.startswith("__join") and isinstance(value, dict):
                    join_fields.setdefault(key, set()).update(value)
        join_layout = [(jk, sorted(join_fields[jk])) for jk in sorted(join_fields)]
        columns = columns_left + [f"{jk}.{col}" for jk, cols in join_layout for col in cols]

        rows = []
        for doc in docs:
            row = [_convert_value(doc.get(k)) for k in columns_left]
            for jk, cols in join_layout:
                join_doc = doc.get(jk)
                if not isinstance(join_doc, dict):
                    join_doc = {}
                row.extend(_convert_value(join_doc.get(col)) for col in cols)
            rows.append(tuple(row))
        description = [(col, None, None, None, None, None, None) for col in columns] if columns else None
        return CursorState(rows=rows, rowcount=len(rows), description=description)

    def _execute_create(self, parts: QueryParts) -> CursorState:
        """Create a collection. Best-effort: errors are logged, not raised."""
        logger.debug("Executing create_collection / コレクション作成: %s", parts.collection)
        try:
            self._db.create_collection(parts.collection)
        except Exception:
            logger.warning(
                "create_collection failed / コレクション作成失敗: %s",
                parts.collection,
                exc_info=True,
            )
        return CursorState(rows=[], rowcount=0)

    def _execute_drop(self, parts: QueryParts) -> CursorState:
        """Drop a collection."""
        logger.debug("Executing drop_collection / コレクション削除: %s", parts.collection)
        self._db.drop_collection(parts.collection)
        return CursorState(rows=[], rowcount=0)
489
+
490
+
491
def connect(uri: str, db_name: str, **_unused: Any) -> Connection:
    """Module-level DB-API entry point: open a :class:`Connection`.

    Any extra keyword arguments are accepted for compatibility and ignored.
    """
    connection = Connection(uri, db_name)
    return connection
mongo_dbapi/errors.py ADDED
@@ -0,0 +1,28 @@
1
class MongoDbApiError(Exception):
    """Error raised by mongo_dbapi; its text starts with an error ID like ``[mdb][E5]``.

    Attributes:
        code: The error ID, e.g. ``"[mdb][E1]"``.
        message: Human-readable description; may be empty.
    """

    def __init__(self, code: str, message: str, cause: Exception | None = None):
        # Without a message, str(exc) is just the code (no dangling space).
        full = f"{code} {message}" if message else code
        super().__init__(full)
        self.code = code
        self.message = message
        if cause is not None:
            # Same effect as ``raise ... from cause``.
            self.__cause__ = cause
11
+
12
+
13
# Default text for every error ID; used when an error is raised without an
# explicit message. The "<keyword>" in E2 is a placeholder, so callers that know
# the construct pass their own text (e.g. "Unsupported SQL construct: WINDOW_FUNCTION").
ERROR_MESSAGES: dict[str, str] = dict(
    (
        ("[mdb][E1]", "Invalid connection URI"),
        ("[mdb][E2]", "Unsupported SQL construct: <keyword>"),
        ("[mdb][E3]", "Unsafe operation without WHERE"),
        ("[mdb][E4]", "Parameter count mismatch"),
        ("[mdb][E5]", "Failed to parse SQL"),
        ("[mdb][E6]", "Transactions not supported on this server"),
        ("[mdb][E7]", "Connection failed"),
        ("[mdb][E8]", "Authentication failed"),
    )
)
23
+
24
+
25
def raise_error(code: str, message: str | None = None, cause: Exception | None = None) -> None:
    """Raise :class:`MongoDbApiError` for ``code``; this function never returns.

    A non-empty ``message`` wins; otherwise the default text registered for
    ``code`` in ``ERROR_MESSAGES`` is used (empty for unknown codes). ``cause``
    is handed to the exception unchanged.
    """
    if message:
        text = message
    else:
        text = ERROR_MESSAGES.get(code, "")
    raise MongoDbApiError(code, text, cause)
mongo_dbapi/sqlalchemy_dialect.py ADDED
@@ -0,0 +1,57 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Tuple
4
+
5
+ from sqlalchemy.engine import default, url
6
+ from sqlalchemy import pool
7
+
8
+ import mongo_dbapi
9
+
10
+ # Register dialect entry point style
11
+ from sqlalchemy.dialects import registry # noqa: E402
12
+
13
# Make the dialect resolvable by SQLAlchemy under both spellings of its name.
for _dialect_name in ("mongodb+dbapi", "mongodb.dbapi"):
    registry.register(_dialect_name, "mongo_dbapi.sqlalchemy_dialect", "MongoDBAPIDialect")
del _dialect_name
15
+
16
+
17
class MongoDBAPIDialect(default.DefaultDialect):
    """SQLAlchemy dialect that talks to MongoDB through the ``mongo_dbapi`` DB-API layer."""

    name = "mongodb+dbapi"
    # Driver name reported to SQLAlchemy (assigned once; it used to be set twice).
    driver = "mongo-dbapi"
    paramstyle = "pyformat"
    default_paramstyle = "pyformat"
    supports_native_boolean = True
    supports_sane_rowcount = False
    supports_native_decimal = False
    poolclass = pool.SingletonThreadPool
    supports_statement_cache = False
    # NOTE(review): name normalisation targets case-insensitive backends that
    # store names upper-case; MongoDB names are case-sensitive — confirm intent.
    requires_name_normalize = True

    @classmethod
    def dbapi(cls):
        """Return the DB-API module. Mirrors ``import_dbapi``; presumably kept for
        SQLAlchemy versions that call ``dbapi()`` instead — verify before removing."""
        return mongo_dbapi

    @classmethod
    def import_dbapi(cls):
        """Return the DB-API module (``mongo_dbapi``) used by this dialect."""
        return mongo_dbapi

    def get_table_names(self, connection, schema=None, **kw):
        """Return the collection names of the connected database (``schema`` is ignored)."""
        return connection.connection.list_tables()

    def has_table(self, connection, table_name, schema=None, **kw):
        """Return True if a collection named ``table_name`` exists (``schema`` is ignored)."""
        return table_name in self.get_table_names(connection, schema=schema)

    def get_driver_connection(self, connection):
        """Return the driver-level connection for ``connection``.

        A ``mongo_dbapi`` connection has no ``.connection`` attribute and is
        already the driver connection, so it is returned as-is; an object that
        wraps one in ``.connection`` is unwrapped as before.
        """
        return getattr(connection, "connection", connection)

    def create_connect_args(self, url_obj: url.URL) -> Tuple[tuple, Dict[str, Any]]:
        """Translate a SQLAlchemy URL into ``mongo_dbapi.connect`` keyword arguments.

        Host defaults to ``127.0.0.1``. Username and password are percent-escaped
        so characters like ``@``, ``:`` or ``/`` cannot corrupt the URI, and URL
        query parameters are forwarded as MongoDB URI options. The database name
        is passed separately as ``db_name``.
        """
        from urllib.parse import quote_plus, urlencode

        host = url_obj.host or "127.0.0.1"
        netloc = f"{host}:{url_obj.port}" if url_obj.port else host
        if url_obj.username:
            user = quote_plus(url_obj.username)
            password = quote_plus(url_obj.password or "")
            netloc = f"{user}:{password}@{netloc}"
        uri = f"mongodb://{netloc}"
        if url_obj.query:
            uri += "/?" + urlencode(dict(url_obj.query), doseq=True)
        db_name = url_obj.database or ""
        return (), {"uri": uri, "db_name": db_name}
54
+
55
+
56
def dialect(**kwargs):
    """Factory: build a :class:`MongoDBAPIDialect`, forwarding all keyword arguments."""
    instance = MongoDBAPIDialect(**kwargs)
    return instance