cfgit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cfg/adapters/mongo.py ADDED
@@ -0,0 +1,570 @@
1
+ # Copyright 2026 Mohammad Ausaf. Licensed under the Apache License, Version 2.0.
2
+ """Mongo StorageAdapter for cfgit."""
3
+ from __future__ import annotations
4
+
5
+ from copy import deepcopy
6
+ from datetime import datetime, timezone
7
+ from typing import Any
8
+ from urllib.parse import unquote, urlsplit
9
+
10
+ from cfg.adapters.base import (
11
+ AmbiguousConfig,
12
+ ApplyResult,
13
+ AtomicityUnavailable,
14
+ AtomicityReport,
15
+ HistoryEnvMismatch,
16
+ NoSuchConfig,
17
+ ReconcileReport,
18
+ StaleHead,
19
+ StaleLive,
20
+ history_env_mismatch_message,
21
+ )
22
+ from cfg.core.config import ProjectConfig
23
+ from cfg.core.hashing import hash_doc
24
+
25
+ try: # pragma: no cover - exercised only when mongo extra is installed
26
+ from pymongo import ASCENDING, MongoClient
27
+ from pymongo.client_session import ClientSession
28
+ from pymongo.errors import OperationFailure
29
+ except ModuleNotFoundError as exc: # pragma: no cover
30
+ raise ModuleNotFoundError("install cfgit[mongo] to use MongoAdapter") from exc
31
+
32
+
33
+ class MongoAdapter:
34
def __init__(self, *, project: ProjectConfig, env_name: str):
    """Bind the adapter to one environment of *project* and create its Mongo clients.

    Args:
        project: Project configuration; its ``envs``, ``history`` and
            ``branches`` sections are read here.
        env_name: Key into ``project.envs`` selecting the environment.

    Raises:
        ValueError: The selected environment has no ``uri``.
    """
    env_cfg = project.envs[env_name]
    if not env_cfg.uri:
        raise ValueError(f"missing Mongo URI for env {env_name}")

    self.project = project
    self.env_name = env_name

    # Runtime/history specific settings fall back to the shared uri/db.
    self.runtime_uri = env_cfg.runtime_uri or env_cfg.uri
    self.history_uri = env_cfg.history_uri or env_cfg.uri
    self.runtime_db_name = env_cfg.runtime_db or env_cfg.db
    self.history_db_name = env_cfg.history_db or env_cfg.db

    # ``client`` and ``history_client`` are the same object; the runtime
    # client is shared too whenever both URIs are equal.
    self.client = MongoClient(self.history_uri)
    self.history_client = self.client
    if self.runtime_uri == self.history_uri:
        self.runtime_client = self.client
    else:
        self.runtime_client = MongoClient(self.runtime_uri)

    self.db = self.runtime_client[self.runtime_db_name]
    self.history_db = self.history_client[self.history_db_name]

    history_cfg = project.history
    self.history = self.history_db[history_cfg.history_collection]
    self.heads = self.history_db[history_cfg.heads_collection]
    self.refs = self.history_db[project.branches.refs_collection]
52
+
53
def get_record(self, collection: str, record_id: str) -> dict | None:
    """Return the single live runtime document for *record_id*, or ``None``.

    Delegates to ``_get_record`` outside any session, so the lookup and the
    ambiguity check live in one place (mirroring how ``put_record`` and
    ``seed_record`` delegate to their private counterparts).

    Raises:
        AmbiguousConfig: More than one live document matches.
    """
    return self._get_record(collection, record_id, session=None)
58
+
59
def put_record(self, collection: str, record_id: str, doc: dict) -> None:
    """Update the live record outside a transaction (see ``_put_record``)."""
    no_session = None
    self._put_record(collection, record_id, doc, session=no_session)
61
+
62
def seed_record(self, collection: str, record_id: str, doc: dict) -> None:
    """Insert a new live record outside a transaction (see ``_seed_record``)."""
    no_session = None
    self._seed_record(collection, record_id, doc, session=no_session)
64
+
65
def list_record_ids(self, collection: str) -> list[str]:
    """Return the sorted, stringified distinct ids of live records in *collection*.

    ``None`` ids are dropped; the collection's ``live_when`` filter is applied.
    """
    cfg = self.project.collection(collection)
    found = self.db[collection].distinct(cfg.id_field, cfg.live_when)
    ids = [str(value) for value in found if value is not None]
    ids.sort()
    return ids
69
+
70
def get_head(self, collection: str, record_id: str) -> dict | None:
    """Return the history row the head pointer refers to, including its ``doc``.

    Returns ``None`` when there is no head pointer for the record, or when
    the history row at the pointer's ``head_seq`` cannot be found.
    """
    pointer = self.heads.find_one(self._head_query(collection, record_id))
    if not pointer:
        return None

    head_filter = {
        "env": self.env_name,
        "collection": collection,
        "record_id": record_id,
        "seq": pointer["head_seq"],
    }
    entry = self.history.find_one(head_filter)
    if not entry:
        return None
    return _history_row(entry, with_doc=True)
83
+
84
def query_history(
    self,
    *,
    collection: str | None = None,
    record_id: str | None = None,
    ref: str | None = None,
    as_of_recorded: datetime | None = None,
    as_of_valid: datetime | None = None,
    tag: str | None = None,
    git_sha: str | None = None,
    limit: int | None = None,
    order: str = "desc",
    with_doc: bool = False,
) -> list[dict]:
    """Query this environment's history rows.

    Every filter is optional and they are combined with AND.

    Args:
        collection: Restrict to one collection.
        record_id: Restrict to one record id.
        ref: ``"@<n>"`` selects sequence number ``n``; anything else is an
            oid prefix (a leading ``sha256:`` and then ``#`` are stripped).
        as_of_recorded: Keep rows with ``recorded_at`` at or before this time.
        as_of_valid: Keep rows whose ``valid_from`` is at or before this time
            and whose ``valid_to`` is ``None`` or later than it.
        tag: Keep rows whose ``tags`` contain this value.
        git_sha: Keep rows whose ``git_shas`` contain this value.
        limit: Passed to the cursor when not ``None``.
        order: ``"desc"`` sorts ``seq`` descending; any other value ascending.
        with_doc: Include the stored ``doc`` field in each row.

    Returns:
        Rows sorted by collection, record id, then ``seq``, with ``_id``
        removed.

    Raises:
        ValueError: ``ref`` starts with ``@`` but is not followed by an integer.
        HistoryEnvMismatch: Nothing matched for a specific collection/record
            in this env, but matching history exists under other envs.
    """
    # Function-scope import keeps this block self-contained.
    import re

    query: dict[str, Any] = {"env": self.env_name}
    if collection is not None:
        query["collection"] = collection
    if record_id is not None:
        query["record_id"] = record_id
    if tag is not None:
        query["tags"] = tag
    if git_sha is not None:
        query["git_shas"] = git_sha
    if as_of_recorded is not None:
        query["recorded_at"] = {"$lte": as_of_recorded}
    if as_of_valid is not None:
        query["valid_from"] = {"$lte": as_of_valid}
        query["$or"] = [{"valid_to": None}, {"valid_to": {"$gt": as_of_valid}}]
    if ref is not None:
        if ref.startswith("@"):
            query["seq"] = int(ref[1:])
        else:
            oid = ref.removeprefix("sha256:").removeprefix("#")
            # The prefix comes from the caller: escape it so characters such
            # as "." or "*" are matched literally rather than as regex syntax.
            query["oid"] = {"$regex": f"^{re.escape(oid)}"}

    projection = None if with_doc else {"doc": 0}
    direction = -1 if order == "desc" else 1
    cursor = self.history.find(query, projection).sort(
        [("collection", ASCENDING), ("record_id", ASCENDING), ("seq", direction)]
    )
    if limit is not None:
        cursor = cursor.limit(limit)
    rows = [_history_row(row, with_doc=with_doc) for row in cursor]
    # An empty result for one specific record may mean its history was
    # recorded under another env name; surface that instead of returning [].
    if not rows and limit != 0 and collection is not None and record_id is not None:
        self._raise_env_mismatch_if_history_exists(collection, record_id, query)
    return rows
130
+
131
def list_tags(self) -> list[dict]:
    """Return ``{"tag", "count"}`` dicts for every tag used in this env, sorted by tag."""
    only_tagged = {"$match": {"env": self.env_name, "tags": {"$exists": True, "$ne": []}}}
    stages = [
        only_tagged,
        {"$unwind": "$tags"},
        {"$group": {"_id": "$tags", "count": {"$sum": 1}}},
        {"$sort": {"_id": 1}},
    ]
    summary = []
    for bucket in self.history.aggregate(stages):
        summary.append({"tag": bucket["_id"], "count": bucket["count"]})
    return summary
139
+
140
def put_ref(self, doc: dict) -> None:
    """Upsert a ref document keyed by (env, type, id).

    *doc* is deep-copied and stamped with this adapter's env before being
    stored; the caller's dict is not modified. ``type`` and ``id`` are
    stringified for the lookup key only.
    """
    payload = deepcopy(doc)
    payload["env"] = self.env_name
    selector = {
        "env": self.env_name,
        "type": str(payload["type"]),
        "id": str(payload["id"]),
    }
    self.refs.replace_one(selector, payload, upsert=True)
150
+
151
def get_ref(self, ref_type: str, ref_id: str) -> dict | None:
    """Return the ref stored under (env, *ref_type*, *ref_id*) without ``_id``, or ``None``."""
    selector = {"env": self.env_name, "type": ref_type, "id": ref_id}
    found = self.refs.find_one(selector)
    if not found:
        return None
    return _ref_row(found)
154
+
155
def list_refs(self, ref_type: str, **filters) -> list[dict]:
    """List refs of *ref_type* in this env, oldest first (ties broken by id).

    Extra keyword arguments become equality filters; those whose value is
    ``None`` are ignored.
    """
    criteria: dict[str, Any] = {"env": self.env_name, "type": ref_type}
    # NOTE(review): a filter named "env" or "type" overwrites the scoping
    # keys above — confirm callers never pass those.
    for field, wanted in filters.items():
        if wanted is not None:
            criteria[field] = wanted
    ordered = self.refs.find(criteria).sort([("created_at", ASCENDING), ("id", ASCENDING)])
    return [_ref_row(item) for item in ordered]
160
+
161
def delete_ref(self, ref_type: str, ref_id: str) -> None:
    """Delete the ref stored under (env, *ref_type*, *ref_id*), if any."""
    selector = {"env": self.env_name, "type": ref_type, "id": ref_id}
    self.refs.delete_one(selector)
163
+
164
def apply(
    self,
    *,
    collection: str,
    record_id: str,
    new_doc: dict | None,
    entry: dict,
    expected_head_oid: str | None,
    expected_live_oid: str | None = None,
    make_head: bool = True,
    seed_missing: bool = False,
) -> ApplyResult:
    """Record *entry* and apply *new_doc* in one transaction, retrying transient failures.

    The atomicity scope is checked first. ``_apply_once`` is then attempted
    up to three times; only an ``OperationFailure`` recognised as a transient
    transaction error triggers another attempt.

    Raises:
        AtomicityUnavailable: ``check_atomicity_scope`` reports non-atomic.
        OperationFailure: A non-transient failure, or a transient one on the
            final attempt.
    """
    report = self.check_atomicity_scope()
    if not report.atomic:
        raise AtomicityUnavailable(report.reason)

    max_attempts = 3
    call_kwargs = {
        "collection": collection,
        "record_id": record_id,
        "new_doc": new_doc,
        "entry": entry,
        "expected_head_oid": expected_head_oid,
        "expected_live_oid": expected_live_oid,
        "make_head": make_head,
        "seed_missing": seed_missing,
    }
    for attempt in range(max_attempts):
        try:
            return self._apply_once(**call_kwargs)
        except OperationFailure as exc:
            out_of_attempts = attempt >= max_attempts - 1
            if out_of_attempts or not _is_transient_transaction_error(exc):
                raise
    # The loop always returns or raises; this only guards against edits.
    raise RuntimeError("unreachable")
195
+
196
def _apply_once(
    self,
    *,
    collection: str,
    record_id: str,
    new_doc: dict | None,
    entry: dict,
    expected_head_oid: str | None,
    expected_live_oid: str | None,
    make_head: bool,
    seed_missing: bool,
) -> ApplyResult:
    """Run one transactional attempt of ``apply``.

    Inside a single session/transaction on ``self.client`` this:

    1. checks the head pointer's oid against *expected_head_oid*;
    2. optionally checks the live record's hash against *expected_live_oid*;
    3. inserts a copy of *entry* stamped with env/collection/record_id/seq;
    4. sets ``valid_to`` on the previous head row (if there was a head);
    5. writes *new_doc* to the runtime collection (seed or update);
    6. moves the head pointer when *make_head* is true.

    Raises:
        StaleHead: Head pointer differs from *expected_head_oid*, or the
            guarded pointer update matched nothing.
        NoSuchConfig: A live-oid check was requested but no live record exists.
        StaleLive: Live hash mismatch, or a record exists although
            *seed_missing* was requested.
    """
    coll_cfg = self.project.collection(collection)
    with self.client.start_session() as session, session.start_transaction():
        pointer = self.heads.find_one(self._head_query(collection, record_id), session=session)
        observed_head = pointer.get("head_oid") if pointer else None
        if observed_head != expected_head_oid:
            raise StaleHead(observed_head)

        if expected_live_oid is not None:
            live_doc = self._get_record(collection, record_id, session=session)
            if live_doc is None:
                raise NoSuchConfig(f"{collection}:{record_id}")
            observed_live = hash_doc(live_doc, coll_cfg)
            if observed_live != expected_live_oid:
                raise StaleLive(observed_live)

        if pointer:
            next_seq = int(pointer.get("head_seq", 0)) + 1
        else:
            next_seq = 1

        # Work on a copy so the caller's entry dict is left untouched.
        row = {
            **entry,
            "env": self.env_name,
            "collection": collection,
            "record_id": record_id,
            "seq": next_seq,
        }
        self.history.insert_one(row, session=session)

        if observed_head:
            # Close the validity interval of the row that was head until now.
            previous_head = {
                "env": self.env_name,
                "collection": collection,
                "record_id": record_id,
                "seq": pointer["head_seq"],
                "valid_to": None,
            }
            self.history.update_one(
                previous_head,
                {"$set": {"valid_to": row["valid_from"]}},
                session=session,
            )

        if new_doc is not None and seed_missing:
            if self._get_record(collection, record_id, session=session) is not None:
                raise StaleLive(f"{collection}:{record_id} reappeared before restore")
            self._seed_record(collection, record_id, new_doc, session=session)
        elif new_doc is not None:
            self._put_record(collection, record_id, new_doc, session=session)

        if make_head:
            selector = self._head_query(collection, record_id)
            if pointer:
                # Guard the update on the pointer values read above.
                selector["head_oid"] = expected_head_oid
                selector["head_seq"] = pointer["head_seq"]
            changes = {
                "$set": {
                    "head_oid": row["oid"],
                    "head_seq": next_seq,
                    "updated_at": row["recorded_at"],
                },
                "$setOnInsert": {
                    "env": self.env_name,
                    "collection": collection,
                    "record_id": record_id,
                },
            }
            outcome = self.heads.update_one(
                selector,
                changes,
                upsert=not pointer,
                session=session,
            )
            if pointer and outcome.matched_count != 1:
                raise StaleHead(expected_head_oid)

        return ApplyResult(
            collection=collection,
            record_id=record_id,
            seq=next_seq,
            oid=row["oid"],
            head_oid=row["oid"],
        )
292
+
293
def add_tag(self, *, collection: str, record_id: str, seq: int, tag: str) -> None:
    """Add *tag* to the ``tags`` set of the history row at *seq* (no duplicates)."""
    target = {
        "env": self.env_name,
        "collection": collection,
        "record_id": record_id,
        "seq": seq,
    }
    self.history.update_one(target, {"$addToSet": {"tags": tag}})
298
+
299
def remove_tag(self, *, collection: str, record_id: str, seq: int, tag: str) -> None:
    """Remove *tag* from the ``tags`` array of the history row at *seq*."""
    target = {
        "env": self.env_name,
        "collection": collection,
        "record_id": record_id,
        "seq": seq,
    }
    self.history.update_one(target, {"$pull": {"tags": tag}})
304
+
305
def list_pending(self) -> list[dict]:
    """Return every history document in this env whose ``pending`` flag is true."""
    # NOTE(review): rows are returned as stored (including ``_id``), unlike
    # the other readers, which strip it — confirm this is intended.
    pending_filter = {"env": self.env_name, "pending": True}
    cursor = self.history.find(pending_filter)
    return list(cursor)
307
+
308
def reconcile(self) -> ReconcileReport:
    """Return an empty reconcile report; this adapter performs no reconciliation work."""
    nothing_forward: list = []
    nothing_back: list = []
    return ReconcileReport(rolled_forward=nothing_forward, rolled_back=nothing_back)
310
+
311
def ensure_schema(self) -> None:
    """Create the indexes used on the history, heads and (if enabled) refs collections."""
    record_key = [("env", ASCENDING), ("collection", ASCENDING), ("record_id", ASCENDING)]

    # History: lookup by oid, unique sequence number per record, and
    # time/tag based filters.
    self.history.create_index([*record_key, ("oid", ASCENDING)])
    self.history.create_index([*record_key, ("seq", ASCENDING)], unique=True)
    for time_field in ("recorded_at", "valid_from", "valid_to"):
        self.history.create_index([("env", ASCENDING), (time_field, ASCENDING)])
    self.history.create_index([("tags", ASCENDING)])

    # Heads: exactly one pointer per record.
    self.heads.create_index(list(record_key), unique=True)

    if not self.project.branches.enabled:
        return

    ref_prefix = [("env", ASCENDING), ("type", ASCENDING)]
    self.refs.create_index([*ref_prefix, ("id", ASCENDING)], unique=True)
    self.refs.create_index([*ref_prefix, ("branch", ASCENDING), ("created_at", ASCENDING)])
    self.refs.create_index([*ref_prefix, ("status", ASCENDING), ("updated_at", ASCENDING)])
334
+
335
def check_runtime_invariant(self, collection: str | None = None) -> list[str]:
    """Report ids that have more than one live runtime document.

    Args:
        collection: Check only this collection; when falsy, every collection
            configured on the project is checked.

    Returns:
        One ``"<collection>:<id> (<n> live records)"`` string per duplicated id.
    """
    if collection:
        targets = [collection]
    else:
        targets = [cfg.name for cfg in self.project.collections]

    problems: list[str] = []
    for target in targets:
        cfg = self.project.collection(target)
        stages = [
            {"$match": cfg.live_when},
            {"$group": {"_id": f"${cfg.id_field}", "n": {"$sum": 1}}},
            {"$match": {"_id": {"$ne": None}, "n": {"$gt": 1}}},
            {"$sort": {"_id": 1}},
        ]
        problems.extend(
            f"{target}:{dup['_id']} ({dup['n']} live records)"
            for dup in self.db[target].aggregate(stages)
        )
    return problems
349
+
350
def check_atomicity_scope(self) -> AtomicityReport:
    """Report whether live and history writes can share one transaction.

    Atomic means: runtime and history use the very same client object, and
    that deployment reports transaction support.
    """
    runtime_cluster = self._cluster_name(self.runtime_client)
    history_cluster = self._cluster_name(self.history_client)
    shared_client = self.runtime_client is self.history_client
    runtime_ok = self._supports_transactions(self.runtime_client)
    history_ok = self._supports_transactions(self.history_client)

    atomic = shared_client and runtime_ok and history_ok
    if atomic:
        reason = "ok"
    elif shared_client:
        reason = "Mongo deployment is not a replica set or sharded cluster"
    else:
        reason = (
            "runtime and cfgit history use separate Mongo clients; v1 requires one URI/client "
            "so live writes and history writes share a transaction"
        )
    return AtomicityReport(
        atomic=atomic,
        runtime_cluster=runtime_cluster,
        history_cluster=history_cluster,
        reason=reason,
    )
372
+
373
def backend_name(self) -> str:
    """Return the identifier of this storage backend."""
    name = "mongo"
    return name
375
+
376
def supports_transactions(self) -> bool:
    """Return whether the history client's deployment supports transactions."""
    client = self.history_client
    return self._supports_transactions(client)
378
+
379
def authenticated_principal(self) -> str | None:
    """Best-effort name of the principal used for the runtime connection.

    The server's ``connectionStatus`` reply is tried first; if that fails or
    names no user, the username (and ``authSource`` option, if present) is
    taken from the runtime URI.

    Returns:
        ``"user@db"``, a bare ``"user"``, or ``None`` when nothing is known.
    """
    try:
        status = self.runtime_client.admin.command("connectionStatus")
        auth_info = status.get("authInfo") or {}
        users = auth_info.get("authenticatedUsers") or []
        if users:
            first = users[0]
            name = str(first.get("user") or "").strip()
            db = str(first.get("db") or "").strip()
            if name:
                return f"{name}@{db}" if db else name
    except Exception:
        # Deliberately best effort: any failure falls through to the URI.
        pass

    parts = urlsplit(self.runtime_uri)
    username = unquote(parts.username or "").strip()
    # First "authSource" option (case-insensitive key) in the query string.
    options = (option.partition("=") for option in parts.query.split("&"))
    auth_source = next(
        (unquote(value) for key, _, value in options if key.lower() == "authsource"),
        "",
    )
    if username and auth_source:
        return f"{username}@{auth_source}"
    return username or None
405
+
406
def now(self) -> datetime:
    """Return the current time as a timezone-aware UTC ``datetime``."""
    utc = timezone.utc
    return datetime.now(tz=utc)
408
+
409
+ def _runtime_query(self, collection: str, record_id: str) -> dict[str, Any]:
410
+ coll = self.project.collection(collection)
411
+ return {coll.id_field: record_id, **coll.live_when}
412
+
413
+ def _head_query(self, collection: str, record_id: str) -> dict[str, Any]:
414
+ return {"env": self.env_name, "collection": collection, "record_id": record_id}
415
+
416
def _raise_env_mismatch_if_history_exists(
    self,
    collection: str,
    record_id: str,
    current_query: dict[str, Any],
) -> None:
    """Raise if *current_query* only matches history recorded under other envs.

    The query is re-run without its ``env`` condition to collect the env
    names that do have matching history rows. Head pointers are consulted
    as well, but only when the query filters on nothing except collection
    and record id. Returns silently when the current env is among the envs
    found, or when no other env is found.

    Raises:
        HistoryEnvMismatch: Matching history/heads exist only under other envs.
    """

    def as_names(values: Any) -> set[str]:
        return {str(value) for value in values or []}

    scope = {field: cond for field, cond in current_query.items() if field != "env"}

    history_envs = self.history.distinct("env", scope)
    if self.env_name in as_names(history_envs):
        return

    head_envs = []
    if set(scope) == {"collection", "record_id"}:
        head_envs = self.heads.distinct("env", scope)
        if self.env_name in as_names(head_envs):
            return

    other_envs = _other_env_names(history_envs, head_envs, current=self.env_name)
    if not other_envs:
        return
    message = history_env_mismatch_message(
        collection=collection,
        record_id=record_id,
        current_env=self.env_name,
        other_envs=other_envs,
    )
    raise HistoryEnvMismatch(message)
441
+
442
+ def _get_record(self, collection: str, record_id: str, *, session: ClientSession | None) -> dict | None:
443
+ docs = list(self.db[collection].find(self._runtime_query(collection, record_id), session=session).limit(2))
444
+ if len(docs) > 1:
445
+ raise AmbiguousConfig(f"{collection}:{record_id}")
446
+ return docs[0] if docs else None
447
+
448
def _put_record(
    self,
    collection: str,
    record_id: str,
    doc: dict,
    *,
    session: ClientSession | None,
) -> None:
    """Make the existing live record match *doc* via ``$set``/``$unset``.

    Secret fields absent (or ``None``) in *doc* keep the value currently
    stored. ``_id`` and the collection's ``ignore_fields`` are never set or
    unset. Top-level keys present in the stored record but missing from the
    new document are unset.

    Raises:
        NoSuchConfig: No live record exists, or the update matched nothing.
    """
    cfg = self.project.collection(collection)
    existing = self._get_record(collection, record_id, session=session)
    if existing is None:
        raise NoSuchConfig(f"{collection}:{record_id}")

    target = self._runtime_doc(collection, record_id, doc)

    # Carry over stored secrets that the incoming document does not provide.
    for secret_path in cfg.secret_fields:
        if _get_path(target, secret_path) is not None:
            continue
        kept = _get_path(existing, secret_path)
        if kept is not None:
            _set_path(target, secret_path, kept)

    untouchable = {"_id", *cfg.ignore_fields}
    changes: dict[str, Any] = {}

    to_set = {field: value for field, value in target.items() if field not in untouchable}
    if to_set:
        changes["$set"] = to_set

    to_unset = {
        field: ""
        for field in existing
        if field not in untouchable and field not in target
    }
    if to_unset:
        changes["$unset"] = to_unset

    if not changes:
        return

    outcome = self.db[collection].update_one(
        self._runtime_query(collection, record_id),
        changes,
        session=session,
    )
    if outcome.matched_count == 0:
        raise NoSuchConfig(f"{collection}:{record_id}")
490
+
491
+ def _seed_record(
492
+ self,
493
+ collection: str,
494
+ record_id: str,
495
+ doc: dict,
496
+ *,
497
+ session: ClientSession | None,
498
+ ) -> None:
499
+ if self._get_record(collection, record_id, session=session) is not None:
500
+ raise AmbiguousConfig(f"{collection}:{record_id}")
501
+ self.db[collection].insert_one(
502
+ self._runtime_doc(collection, record_id, doc),
503
+ session=session,
504
+ )
505
+
506
+ def _runtime_doc(self, collection: str, record_id: str, doc: dict) -> dict[str, Any]:
507
+ coll = self.project.collection(collection)
508
+ effective = deepcopy(doc)
509
+ effective[coll.id_field] = record_id
510
+ for key, configured_value in coll.live_when.items():
511
+ effective[key] = configured_value
512
+ return effective
513
+
514
+ def _cluster_name(self, client: MongoClient) -> str:
515
+ hello = client.admin.command("hello")
516
+ hosts = ",".join(sorted(str(host) for host in hello.get("hosts", [])))
517
+ return str(hello.get("setName") or hello.get("msg") or hello.get("me") or hosts or "standalone")
518
+
519
+ def _supports_transactions(self, client: MongoClient) -> bool:
520
+ hello = client.admin.command("hello")
521
+ return bool(hello.get("setName") or hello.get("msg") == "isdbgrid")
522
+
523
+
524
+ def _history_row(row: dict[str, Any], *, with_doc: bool) -> dict[str, Any]:
525
+ return {key: value for key, value in row.items() if key != "_id" and (with_doc or key != "doc")}
526
+
527
+
528
+ def _ref_row(row: dict[str, Any]) -> dict[str, Any]:
529
+ return {key: value for key, value in row.items() if key != "_id"}
530
+
531
+
532
+ def _other_env_names(*env_lists: Any, current: str) -> list[str]:
533
+ seen: set[str] = set()
534
+ for envs in env_lists:
535
+ for env in envs or []:
536
+ if env is None:
537
+ continue
538
+ name = str(env)
539
+ if name and name != current:
540
+ seen.add(name)
541
+ return sorted(seen)
542
+
543
+
544
+ def _get_path(doc: dict[str, Any], dotted: str) -> Any:
545
+ cur: Any = doc
546
+ for part in dotted.split("."):
547
+ if not isinstance(cur, dict) or part not in cur:
548
+ return None
549
+ cur = cur[part]
550
+ return cur
551
+
552
+
553
+ def _set_path(doc: dict[str, Any], dotted: str, value: Any) -> None:
554
+ cur: Any = doc
555
+ parts = dotted.split(".")
556
+ for part in parts[:-1]:
557
+ nxt = cur.get(part)
558
+ if not isinstance(nxt, dict):
559
+ nxt = {}
560
+ cur[part] = nxt
561
+ cur = nxt
562
+ cur[parts[-1]] = value
563
+
564
+
565
+ def _is_transient_transaction_error(exc: OperationFailure) -> bool:
566
+ has_error_label = getattr(exc, "has_error_label", None)
567
+ if callable(has_error_label) and has_error_label("TransientTransactionError"):
568
+ return True
569
+ details = getattr(exc, "details", None) or {}
570
+ return "TransientTransactionError" in details.get("errorLabels", [])