mantisdk-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/store/collection/mongo.py
@@ -0,0 +1,1412 @@
+ # Copyright (c) Microsoft. All rights reserved.
+
+ from __future__ import annotations
+
+ import asyncio
+ import logging
+ import random
+ import re
+ import time
+ from contextlib import asynccontextmanager
+ from datetime import datetime
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Awaitable,
+     Callable,
+     Dict,
+     Generic,
+     List,
+     Mapping,
+     Optional,
+     Sequence,
+     Tuple,
+     Type,
+     TypeVar,
+     cast,
+ )
+
+ import aiologic
+
+ from mantisdk.utils.metrics import MetricsBackend
+
+ if TYPE_CHECKING:
+     from typing import Self
+
+ from pydantic import BaseModel, TypeAdapter
+ from pymongo import AsyncMongoClient, ReadPreference, ReturnDocument, WriteConcern
+ from pymongo.asynchronous.client_session import AsyncClientSession
+ from pymongo.asynchronous.collection import AsyncCollection
+ from pymongo.asynchronous.database import AsyncDatabase
+ from pymongo.errors import (
+     BulkWriteError,
+     CollectionInvalid,
+     ConnectionFailure,
+     DuplicateKeyError,
+     OperationFailure,
+     PyMongoError,
+ )
+ from pymongo.read_concern import ReadConcern
+
+ from mantisdk.types import (
+     Attempt,
+     FilterOptions,
+     PaginatedResult,
+     ResourcesUpdate,
+     Rollout,
+     SortOptions,
+     Span,
+     Worker,
+ )
+
+ from .base import (
+     AtomicLabels,
+     AtomicMode,
+     Collection,
+     DuplicatedPrimaryKeyError,
+     KeyValue,
+     LightningCollections,
+     Queue,
+     ensure_numeric,
+     normalize_filter_options,
+     resolve_sort_options,
+     tracked,
+ )
+
+ T_model = TypeVar("T_model", bound=BaseModel)
+
+ T_generic = TypeVar("T_generic")
+
+ T_mapping = TypeVar("T_mapping", bound=Mapping[str, Any])
+
+ T_callable = TypeVar("T_callable", bound=Callable[..., Any])
+
+ K = TypeVar("K")
+ V = TypeVar("V")
+
+ logger = logging.getLogger(__name__)
+
+
+ def resolve_mongo_error_type(exc: BaseException | None) -> str | None:
+     is_transient = isinstance(exc, PyMongoError) and exc.has_error_label("TransientTransactionError")
+     # Order matters here: DuplicateKeyError subclasses OperationFailure, and
+     # ConnectionFailure subclasses PyMongoError, so the narrower checks must
+     # run before the broader ones or they would never match.
+     if isinstance(exc, DuplicateKeyError):
+         return "DuplicateKeyError-Transient" if is_transient else "DuplicateKeyError"
+     if isinstance(exc, OperationFailure):
+         if is_transient:
+             return f"OperationFailure-{exc.code}-Transient"
+         else:
+             return f"OperationFailure-{exc.code}"
+     if isinstance(exc, ConnectionFailure):
+         return "ConnectionFailure-Transient" if is_transient else "ConnectionFailure"
+     if isinstance(exc, PyMongoError):
+         if is_transient:
+             return f"{exc.__class__.__name__}-Transient"
+         else:
+             return exc.__class__.__name__
+     if is_transient:
+         return "Other-Transient"
+     else:
+         return None
+
+
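As a quick illustration of the classification above (a sketch; the exceptions are constructed directly rather than raised by a live server, and pymongo must be installed):

    from pymongo.errors import DuplicateKeyError, OperationFailure

    # Specific types win over their base classes.
    assert resolve_mongo_error_type(DuplicateKeyError("E11000 duplicate key")) == "DuplicateKeyError"
    assert resolve_mongo_error_type(OperationFailure("failed", code=112)) == "OperationFailure-112"
    # Non-Mongo exceptions fall through to None.
    assert resolve_mongo_error_type(ValueError("not a mongo error")) is None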
+ def _field_ops_to_conditions(field: str, ops: Mapping[str, Any]) -> List[Dict[str, Any]]:
+     """Convert a FilterField (ops) into one or more Mongo conditions."""
+     conditions: List[Dict[str, Any]] = []
+
+     for op_name, raw_value in ops.items():
+         if op_name == "exact":
+             if raw_value is None:
+                 logger.debug(f"Skipping exact filter for field '{field}' with None value")
+                 continue
+             conditions.append({field: raw_value})
+         elif op_name == "within":
+             if raw_value is None:
+                 logger.debug(f"Skipping within filter for field '{field}' with None value")
+                 continue
+             try:
+                 iterable = list(raw_value)
+             except TypeError as exc:
+                 raise ValueError(f"Invalid iterable for within filter for field '{field}': {raw_value!r}") from exc
+             conditions.append({field: {"$in": iterable}})
+         elif op_name == "contains":
+             if raw_value is None:
+                 logger.debug(f"Skipping contains filter for field '{field}' with None value")
+                 continue
+             value = str(raw_value)
+             pattern = f".*{re.escape(value)}.*"
+             conditions.append({field: {"$regex": pattern, "$options": "i"}})
+         else:
+             raise ValueError(f"Unsupported filter operator '{op_name}' for field '{field}'")
+
+     return conditions
+
+
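The three supported operators map onto Mongo query syntax as follows (a sketch of the outputs, matching the branches above):

    _field_ops_to_conditions("status", {"exact": "running"})
    # -> [{"status": "running"}]
    _field_ops_to_conditions("status", {"within": ["queued", "running"]})
    # -> [{"status": {"$in": ["queued", "running"]}}]
    _field_ops_to_conditions("name", {"contains": "agent"})
    # -> [{"name": {"$regex": ".*agent.*", "$options": "i"}}]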
+ def _build_mongo_filter(filter_options: Optional[FilterOptions]) -> Dict[str, Any]:
+     """Translate FilterOptions into a MongoDB filter dict."""
+     normalized, must_filters, aggregate = normalize_filter_options(filter_options)
+
+     regular_conditions: List[Dict[str, Any]] = []
+     must_conditions: List[Dict[str, Any]] = []
+
+     # Normal filters
+     if normalized:
+         for field_name, ops in normalized.items():
+             regular_conditions.extend(_field_ops_to_conditions(field_name, ops))
+
+     # Must filters
+     if must_filters:
+         for field_name, ops in must_filters.items():
+             must_conditions.extend(_field_ops_to_conditions(field_name, ops))
+
+     # No filters at all
+     if not regular_conditions and not must_conditions:
+         return {}
+
+     # Aggregate logic for regular conditions; _must always ANDs in.
+     if aggregate == "and":
+         all_conds = regular_conditions + must_conditions
+         if len(all_conds) == 1:
+             return all_conds[0]
+         return {"$and": all_conds}
+
+     # aggregate == "or"
+     if regular_conditions and must_conditions:
+         # (OR of regular) AND (all must)
+         if len(regular_conditions) == 1:
+             or_part: Dict[str, Any] = regular_conditions[0]
+         else:
+             or_part = {"$or": regular_conditions}
+
+         and_parts: List[Dict[str, Any]] = [or_part] + must_conditions
+         if len(and_parts) == 1:
+             return and_parts[0]
+         return {"$and": and_parts}
+
+     if regular_conditions:
+         if len(regular_conditions) == 1:
+             return regular_conditions[0]
+         return {"$or": regular_conditions}
+
+     # Only must conditions
+     if len(must_conditions) == 1:
+         return must_conditions[0]
+     return {"$and": must_conditions}
+
+
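Putting the two helpers together (a sketch; the exact FilterOptions shape and the "or" aggregate are assumptions about normalize_filter_options, which lives in .base and is not part of this diff):

    filter_options = {
        "status": {"within": ["queued", "running"]},
        "name": {"contains": "agent"},
        "_must": {"partition_id": {"exact": "p1"}},
    }
    # With aggregate "or", _build_mongo_filter(filter_options) would yield:
    # {"$and": [
    #     {"$or": [{"status": {"$in": ["queued", "running"]}},
    #              {"name": {"$regex": ".*agent.*", "$options": "i"}}]},
    #     {"partition_id": "p1"},
    # ]}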
+ async def _ensure_collection(
+     db: AsyncDatabase[Mapping[str, Any]],
+     collection_name: str,
+     primary_keys: Optional[Sequence[str]] = None,
+     extra_indexes: Optional[Sequence[Sequence[str]]] = None,
+ ) -> bool:
+     """Ensure the backing MongoDB collection exists.
+
+     This function is idempotent and safe to call multiple times.
+     """
+     # Create collection if it doesn't exist yet
+     try:
+         await db.create_collection(collection_name)
+     except CollectionInvalid as exc:
+         # Thrown if collection already exists
+         logger.debug(f"Collection '{collection_name}' may have already existed. No need to create it: {exc!r}")
+     except OperationFailure as exc:
+         logger.debug(f"Failed to create collection '{collection_name}'. Probably already exists: {exc!r}")
+         # Some servers use OperationFailure w/ specific codes for "NamespaceExists"
+         if exc.code in (48, 68):  # 48: NamespaceExists, 68: already exists on older versions
+             pass
+         else:
+             raise
+
+     # Optionally create a unique index on primary keys (scoped by partition_id)
+     if primary_keys:
+         # Always include the partition field in the unique index.
+         keys = [("partition_id", 1)] + [(pk, 1) for pk in primary_keys]
+         try:
+             await db[collection_name].create_index(keys, name=f"uniq_partition_{'_'.join(primary_keys)}", unique=True)
+         except OperationFailure as exc:
+             logger.debug(f"Index for collection '{collection_name}' already exists. No need to create it: {exc!r}")
+             # Ignore "index already exists" type errors
+             if exc.code in (68, 85):  # IndexOptionsConflict, etc.
+                 pass
+             else:
+                 raise
+
+     # Optionally create extra indexes
+     if extra_indexes:
+         for index in extra_indexes:
+             try:
+                 await db[collection_name].create_index(index, name=f"idx_{'_'.join(index)}")
+             except OperationFailure as exc:
+                 logger.debug(f"Index for collection '{collection_name}' already exists. No need to create it: {exc!r}")
+                 # Ignore "index already exists" type errors
+                 if exc.code in (68, 85):  # IndexOptionsConflict, etc.
+                     pass
+                 else:
+                     raise
+
+     return True
+
+
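For example, a collection configured with primary_keys=["rollout_id"] and extra_indexes=[["status"]] amounts to the two create_index calls below (index names derived exactly as above; the "rollouts" name is illustrative):

    async def create_rollout_indexes(db) -> None:
        # Unique per-partition primary key.
        await db["rollouts"].create_index(
            [("partition_id", 1), ("rollout_id", 1)],
            name="uniq_partition_rollout_id",
            unique=True,
        )
        # Non-unique secondary index.
        await db["rollouts"].create_index(["status"], name="idx_status")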
+ class MongoClientPool(Generic[T_mapping]):
+     """A pool of MongoDB clients, each bound to a specific event loop.
+
+     The pool lazily creates `AsyncMongoClient` instances per event loop using the provided
+     connection parameters, ensuring we never try to reuse a client across loops.
+     """
+
+     def __init__(self, *, mongo_uri: str, mongo_client_kwargs: Mapping[str, Any] | None = None):
+         self._get_collection_lock = aiologic.Lock()
+         self._get_client_lock = aiologic.Lock()
+         self._mongo_uri = mongo_uri
+         self._mongo_client_kwargs = dict(mongo_client_kwargs or {})
+         self._client_pool: Dict[int, AsyncMongoClient[T_mapping]] = {}
+         self._collection_pool: Dict[Tuple[int, str, str], AsyncCollection[T_mapping]] = {}
+
+     async def __aenter__(self) -> Self:
+         return self
+
+     async def __aexit__(self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: Any) -> None:
+         await self.close()
+
+     async def close(self) -> None:
+         """Close all clients currently tracked by the pool."""
+
+         async with self._get_client_lock, self._get_collection_lock:
+             clients = list(self._client_pool.values())
+             self._client_pool.clear()
+             self._collection_pool.clear()
+
+             for client in clients:
+                 try:
+                     await client.close()
+                 except Exception:
+                     logger.exception("Error closing MongoDB client: %s", client)
+
+     async def get_client(self) -> AsyncMongoClient[T_mapping]:
+         loop = asyncio.get_running_loop()
+         key = id(loop)
+
+         # If there is already a client specifically for this loop, return it.
+         existing = self._client_pool.get(key)
+         if existing is not None:
+             await existing.aconnect()  # This actually does nothing if the client is already connected.
+             return existing
+
+         async with self._get_client_lock:
+             # Another coroutine may have already created the client.
+             if key in self._client_pool:
+                 await self._client_pool[key].aconnect()
+                 return self._client_pool[key]
+
+             # Create a new client for this loop.
+             client = AsyncMongoClient[T_mapping](self._mongo_uri, **self._mongo_client_kwargs)
+             await client.aconnect()
+             self._client_pool[key] = client
+             return client
+
+     async def get_collection(self, database_name: str, collection_name: str) -> AsyncCollection[T_mapping]:
+         loop = asyncio.get_running_loop()
+         key = (id(loop), database_name, collection_name)
+         if key in self._collection_pool:
+             return self._collection_pool[key]
+
+         async with self._get_collection_lock:
+             # Another coroutine may have already created the collection.
+             if key in self._collection_pool:
+                 return self._collection_pool[key]
+
+             # Create a new collection for this loop.
+             client = await self.get_client()
+             collection = client[database_name][collection_name]
+             self._collection_pool.setdefault(key, collection)
+             return collection
+
+
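A minimal usage sketch (the URI and the database/collection names are illustrative; a reachable mongod is assumed):

    import asyncio

    async def main() -> None:
        async with MongoClientPool(mongo_uri="mongodb://localhost:27017") as pool:
            # One AsyncMongoClient is created lazily per running event loop.
            rollouts = await pool.get_collection("lightning", "rollouts")
            print(await rollouts.count_documents({}))

    asyncio.run(main())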
+ class MongoBasedCollection(Collection[T_model]):
+     """Mongo-based implementation of Collection.
+
+     Args:
+         client_pool: The pool of MongoDB clients.
+         database_name: The name of the database.
+         collection_name: The name of the collection.
+         partition_id: The partition ID. Used to partition the collection into multiple collections.
+         primary_keys: The primary keys of the collection.
+         item_type: The type of the items in the collection.
+         extra_indexes: The extra indexes to create on the collection.
+         tracker: The metrics tracker to use.
+     """
+
+     def __init__(
+         self,
+         client_pool: MongoClientPool[Mapping[str, Any]],
+         database_name: str,
+         collection_name: str,
+         partition_id: str,
+         primary_keys: Sequence[str],
+         item_type: Type[T_model],
+         extra_indexes: Sequence[Sequence[str]] = (),
+         tracker: MetricsBackend | None = None,
+     ):
+         super().__init__(tracker=tracker)
+         self._client_pool = client_pool
+         self._database_name = database_name
+         self._collection_name = collection_name
+         self._partition_id = partition_id
+         self._collection_created = False
+         self._extra_indexes = [list(index) for index in extra_indexes]
+         self._session: Optional[AsyncClientSession] = None
+
+         if not primary_keys:
+             raise ValueError("primary_keys must be non-empty")
+         self._primary_keys = list(primary_keys)
+
+         if not issubclass(item_type, BaseModel):  # type: ignore
+             raise ValueError(f"item_type must be a subclass of BaseModel, got {item_type.__name__}")
+         self._item_type = item_type
+
+     @property
+     def collection_name(self) -> str:
+         return self._collection_name
+
+     @property
+     def extra_tracking_labels(self) -> Mapping[str, str]:
+         return {
+             "database": self._database_name,
+         }
+
+     @tracked("ensure_collection")
+     async def ensure_collection(self) -> AsyncCollection[Mapping[str, Any]]:
+         """Ensure the backing MongoDB collection exists (and optionally its indexes).
+
+         This method is idempotent and safe to call multiple times.
+
+         It will also create a unique index across the configured primary key fields.
+         """
+         if not self._collection_created:
+             client = await self._client_pool.get_client()
+             self._collection_created = await _ensure_collection(
+                 client[self._database_name], self._collection_name, self._primary_keys, self._extra_indexes
+             )
+
+         return await self._client_pool.get_collection(self._database_name, self._collection_name)
+
+     def with_session(self, session: AsyncClientSession) -> MongoBasedCollection[T_model]:
+         """Create a new collection with the same configuration but a new session."""
+         collection = MongoBasedCollection(
+             client_pool=self._client_pool,
+             database_name=self._database_name,
+             collection_name=self._collection_name,
+             partition_id=self._partition_id,
+             primary_keys=self._primary_keys,
+             item_type=self._item_type,
+             extra_indexes=self._extra_indexes,
+             tracker=self._tracker,
+         )
+         collection._collection_created = self._collection_created
+         collection._session = session
+         return collection
+
+     def primary_keys(self) -> Sequence[str]:
+         """Return the primary key field names for this collection."""
+         return self._primary_keys
+
+     def item_type(self) -> Type[T_model]:
+         return self._item_type
+
+     @tracked("size")
+     async def size(self) -> int:
+         collection = await self.ensure_collection()
+         return await collection.count_documents({"partition_id": self._partition_id}, session=self._session)
+
+     def _pk_filter(self, item: T_model) -> Dict[str, Any]:
+         """Build a Mongo filter for the primary key(s) of a model instance."""
+         data = item.model_dump()
+         missing = [pk for pk in self._primary_keys if pk not in data]
+         if missing:
+             raise ValueError(f"Missing primary key fields {missing} on item {item!r}")
+         pk_filter: Dict[str, Any] = {"partition_id": self._partition_id}
+         pk_filter.update({pk: data[pk] for pk in self._primary_keys})
+         return pk_filter
+
+     def _render_pk_values(self, values: Sequence[Any]) -> str:
+         return ", ".join(f"{pk}={value!r}" for pk, value in zip(self._primary_keys, values))
+
+     def _ensure_item_type(self, item: T_model) -> None:
+         if not isinstance(item, self._item_type):
+             raise TypeError(f"Expected item of type {self._item_type.__name__}, got {type(item).__name__}")
+
+     def _inject_partition_filter(self, filter: Optional[FilterOptions]) -> Dict[str, Any]:
+         """Ensure every query is scoped to this collection's partition."""
+         combined: Dict[str, Any]
+         if filter is None:
+             combined = {}
+         else:
+             combined = dict(filter)
+
+         partition_must = {"partition_id": {"exact": self._partition_id}}
+         existing_must = combined.get("_must")
+         if existing_must is None:
+             combined["_must"] = partition_must
+             return combined
+
+         if isinstance(existing_must, Mapping):
+             combined["_must"] = [existing_must, partition_must]
+         elif isinstance(existing_must, Sequence) and not isinstance(existing_must, (str, bytes)):
+             combined["_must"] = [*existing_must, partition_must]
+         else:
+             raise TypeError("`_must` filters must be a mapping or sequence of mappings")
+
+         return combined
+
+     def _model_validate_item(self, raw: Mapping[str, Any]) -> T_model:
+         item_type_has_id = "_id" in self._item_type.model_fields
+         # Remove _id from the raw document if the item type does not have it.
+         if not item_type_has_id:
+             raw = {k: v for k, v in raw.items() if k != "_id"}
+         # Convert Mongo document to Pydantic model
+         return self._item_type.model_validate(raw)  # type: ignore[arg-type]
+
+     @tracked("query")
+     async def query(
+         self,
+         filter: Optional[FilterOptions] = None,
+         sort: Optional[SortOptions] = None,
+         limit: int = -1,
+         offset: int = 0,
+     ) -> PaginatedResult[T_model]:
+         """Mongo-based implementation of Collection.query.
+
+         The handling of null values in sorting differs from the memory-based implementation:
+         in MongoDB, null values are treated as less than non-null values.
+         """
+         collection = await self.ensure_collection()
+
+         combined = self._inject_partition_filter(filter)
+         mongo_filter = _build_mongo_filter(cast(FilterOptions, combined))
+
+         total = await collection.count_documents(mongo_filter, session=self._session)
+
+         if limit == 0:
+             return PaginatedResult[T_model](items=[], limit=0, offset=offset, total=total)
+
+         cursor = collection.find(mongo_filter, session=self._session)
+
+         sort_name, sort_order = resolve_sort_options(sort)
+         if sort_name is not None:
+             model_fields = getattr(self._item_type, "model_fields", {})
+             if sort_name not in model_fields:
+                 raise ValueError(
+                     f"Failed to sort items by '{sort_name}': field does not exist on {self._item_type.__name__}"
+                 )
+             direction = 1 if sort_order == "asc" else -1
+             cursor = cursor.sort(sort_name, direction)
+
+         if offset > 0:
+             cursor = cursor.skip(offset)
+         if limit >= 0:
+             cursor = cursor.limit(limit)
+
+         items: List[T_model] = []
+         async for raw in cursor:
+             items.append(self._model_validate_item(raw))
+
+         return PaginatedResult[T_model](items=items, limit=limit, offset=offset, total=total)
+
+     @tracked("get")
+     async def get(
+         self,
+         filter: Optional[FilterOptions] = None,
+         sort: Optional[SortOptions] = None,
+     ) -> Optional[T_model]:
+         collection = await self.ensure_collection()
+
+         combined = self._inject_partition_filter(filter)
+         mongo_filter = _build_mongo_filter(cast(FilterOptions, combined))
+
+         sort_name, sort_order = resolve_sort_options(sort)
+         mongo_sort: Optional[List[Tuple[str, int]]] = None
+         if sort_name is not None:
+             model_fields = getattr(self._item_type, "model_fields", {})
+             if sort_name not in model_fields:
+                 raise ValueError(
+                     f"Failed to sort items by '{sort_name}': field does not exist on {self._item_type.__name__}"
+                 )
+             direction = 1 if sort_order == "asc" else -1
+             mongo_sort = [(sort_name, direction)]
+
+         raw = await collection.find_one(mongo_filter, sort=mongo_sort, session=self._session)
+
+         if raw is None:
+             return None
+
+         return self._model_validate_item(raw)
+
+     @tracked("insert")
+     async def insert(self, items: Sequence[T_model]) -> None:
+         """Insert items into the collection.
+
+         The implementation does NOT check for duplicate primary keys,
+         neither within the same insert call nor across different insert calls.
+         It relies on the database to enforce uniqueness via indexes.
+         """
+         if not items:
+             return
+
+         collection = await self.ensure_collection()
+         docs: List[Mapping[str, Any]] = []
+         for item in items:
+             self._ensure_item_type(item)
+             doc = item.model_dump()
+             doc["partition_id"] = self._partition_id
+             docs.append(doc)
+
+         if not docs:
+             return
+
+         try:
+             async with self.tracking_context("insert.insert_many", self._collection_name):
+                 await collection.insert_many(docs, session=self._session)
+         except DuplicateKeyError as exc:
+             # In case the DB enforces uniqueness via index, normalize to ValueError
+             raise DuplicatedPrimaryKeyError("Duplicated primary key(s) while inserting items") from exc
+         except BulkWriteError as exc:
+             write_errors = exc.details.get("writeErrors", [])
+             if write_errors and write_errors[0].get("code") == 11000:
+                 raise DuplicatedPrimaryKeyError("Duplicated primary key(s) while inserting items") from exc
+             raise
+
+     @tracked("update")
+     async def update(self, items: Sequence[T_model], update_fields: Sequence[str] | None = None) -> List[T_model]:
+         if not items:
+             return []
+
+         updated_items: List[T_model] = []
+         collection = await self.ensure_collection()
+
+         for item in items:
+             self._ensure_item_type(item)
+             pk_filter = self._pk_filter(item)
+             doc = item.model_dump()
+             doc["partition_id"] = self._partition_id
+
+             updated_doc = None
+
+             # Branch 1: Full Replace
+             if update_fields is None:
+                 async with self.tracking_context("update.find_one_and_replace", self._collection_name):
+                     updated_doc = await collection.find_one_and_replace(
+                         filter=pk_filter,
+                         replacement=doc,
+                         session=self._session,
+                         return_document=ReturnDocument.AFTER,  # Returns the new version
+                     )
+
+             # Branch 2: Partial Update
+             else:
+                 update_doc = {field: doc[field] for field in update_fields if field in doc}
+                 async with self.tracking_context("update.find_one_and_update", self._collection_name):
+                     updated_doc = await collection.find_one_and_update(
+                         filter=pk_filter,
+                         update={"$set": update_doc},
+                         session=self._session,
+                         return_document=ReturnDocument.AFTER,  # Returns the new version
+                     )
+
+             # Validation and Reconstruction
+             if updated_doc is None:  # type: ignore
+                 raise ValueError(f"Item with primary key(s) {pk_filter} does not exist")
+
+             # Re-instantiate the model from the raw MongoDB dictionary.
+             new_item = self._model_validate_item(updated_doc)
+             updated_items.append(new_item)
+
+         return updated_items
+
+     @tracked("upsert")
+     async def upsert(self, items: Sequence[T_model], update_fields: Sequence[str] | None = None) -> List[T_model]:
+         if not items:
+             return []
+
+         upserted_items: List[T_model] = []
+         collection = await self.ensure_collection()
+
+         for item in items:
+             self._ensure_item_type(item)
+             pk_filter = self._pk_filter(item)
+
+             insert_doc = item.model_dump()
+             insert_doc["partition_id"] = self._partition_id
+
+             # If update_fields is None, we update ALL fields (standard upsert behavior).
+             # Otherwise, we only update specific fields, but insert the full doc if it's new.
+             target_fields = update_fields if update_fields is not None else list(insert_doc.keys())
+
+             # 1. $set: Fields that should be overwritten if the document exists
+             update_subset = {field: insert_doc[field] for field in target_fields if field in insert_doc}
+
+             # 2. $setOnInsert: Fields that are only set if we are creating a NEW document
+             #    (Everything in the model that isn't in the update_subset)
+             set_on_insert = {k: v for k, v in insert_doc.items() if k not in update_subset}
+
+             update_spec: Dict[str, Dict[str, Any]] = {}
+             if set_on_insert:
+                 update_spec["$setOnInsert"] = set_on_insert
+             if update_subset:
+                 update_spec["$set"] = update_subset
+
+             async with self.tracking_context("upsert.find_one_and_update", self._collection_name):
+                 result_doc = await collection.find_one_and_update(
+                     filter=pk_filter,
+                     update=update_spec,
+                     upsert=True,
+                     session=self._session,
+                     return_document=ReturnDocument.AFTER,
+                 )
+
+             # Because upsert=True, result_doc is guaranteed to be not None
+             new_item = self._model_validate_item(result_doc)
+             upserted_items.append(new_item)
+
+         return upserted_items
+
+     @tracked("delete")
+     async def delete(self, items: Sequence[T_model]) -> None:
+         if not items:
+             return
+
+         collection = await self.ensure_collection()
+         for item in items:
+             self._ensure_item_type(item)
+             pk_filter = self._pk_filter(item)
+             async with self.tracking_context("delete.delete_one", self._collection_name):
+                 result = await collection.delete_one(pk_filter, session=self._session)
+             if result.deleted_count == 0:
+                 raise ValueError(f"Item with primary key(s) {pk_filter} does not exist")
+
+
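A usage sketch of the collection API (the Job model, database/collection names, and filter shape here are illustrative assumptions, not part of the package):

    from pydantic import BaseModel

    class Job(BaseModel):
        job_id: str
        status: str = "queued"

    async def demo(pool: MongoClientPool) -> None:
        jobs = MongoBasedCollection(pool, "lightning", "jobs", "p1", ["job_id"], Job)
        await jobs.insert([Job(job_id="j1")])
        # Partial update touches only the listed fields via $set.
        await jobs.update([Job(job_id="j1", status="running")], update_fields=["status"])
        page = await jobs.query(filter={"status": {"exact": "running"}}, limit=10)
        print(page.total, [job.job_id for job in page.items])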
+ class MongoBasedQueue(Queue[T_generic], Generic[T_generic]):
+     """Mongo-based implementation of Queue backed by a MongoDB collection.
+
+     Items are stored append-only; dequeue marks items as consumed instead of deleting them.
+     """
+
+     def __init__(
+         self,
+         client_pool: MongoClientPool[Mapping[str, Any]],
+         database_name: str,
+         collection_name: str,
+         partition_id: str,
+         item_type: Type[T_generic],
+         tracker: MetricsBackend | None = None,
+     ) -> None:
+         """
+         Args:
+             client_pool: The pool of MongoDB clients.
+             database_name: The name of the database.
+             collection_name: The name of the collection backing the queue.
+             partition_id: Partition identifier; allows multiple logical queues in one collection.
+             item_type: The Python type of queue items (primitive or BaseModel subclass).
+         """
+         super().__init__(tracker=tracker)
+         self._client_pool = client_pool
+         self._database_name = database_name
+         self._collection_name = collection_name
+         self._partition_id = partition_id
+         self._item_type = item_type
+         self._adapter: TypeAdapter[T_generic] = TypeAdapter(item_type)
+         self._collection_created = False
+
+         self._session: Optional[AsyncClientSession] = None
+
+     def item_type(self) -> Type[T_generic]:
+         return self._item_type
+
+     @property
+     def extra_tracking_labels(self) -> Mapping[str, str]:
+         return {
+             "database": self._database_name,
+         }
+
+     @property
+     def collection_name(self) -> str:
+         return self._collection_name
+
+     @tracked("ensure_collection")
+     async def ensure_collection(self) -> AsyncCollection[Mapping[str, Any]]:
+         """Ensure the backing collection exists.
+
+         If it already exists, it returns the existing collection.
+         """
+         if not self._collection_created:
+             client = await self._client_pool.get_client()
+             self._collection_created = await _ensure_collection(
+                 client[self._database_name], self._collection_name, primary_keys=["consumed", "_id"]
+             )
+         return await self._client_pool.get_collection(self._database_name, self._collection_name)
+
+     def with_session(self, session: AsyncClientSession) -> MongoBasedQueue[T_generic]:
+         queue = MongoBasedQueue(
+             client_pool=self._client_pool,
+             database_name=self._database_name,
+             collection_name=self._collection_name,
+             partition_id=self._partition_id,
+             item_type=self._item_type,
+             tracker=self._tracker,
+         )
+         queue._collection_created = self._collection_created
+         queue._session = session
+         return queue
+
+     @tracked("has")
+     async def has(self, item: T_generic) -> bool:
+         collection = await self.ensure_collection()
+         encoded = self._adapter.dump_python(item, mode="python")
+         doc = await collection.find_one(
+             {
+                 "partition_id": self._partition_id,
+                 "consumed": False,
+                 "value": encoded,
+             },
+             session=self._session,
+         )
+         return doc is not None
+
+     @tracked("enqueue")
+     async def enqueue(self, items: Sequence[T_generic]) -> Sequence[T_generic]:
+         if not items:
+             return []
+
+         collection = await self.ensure_collection()
+         docs: List[Mapping[str, Any]] = []
+         for item in items:
+             if not isinstance(item, self._item_type):
+                 raise TypeError(f"Expected item of type {self._item_type.__name__}, got {type(item).__name__}")
+             docs.append(
+                 {
+                     "partition_id": self._partition_id,
+                     "value": self._adapter.dump_python(item, mode="python"),
+                     "consumed": False,
+                     "created_at": datetime.now(),
+                 }
+             )
+
+         async with self.tracking_context("enqueue.insert_many", self.collection_name):
+             await collection.insert_many(docs, session=self._session)
+         return list(items)
+
+     @tracked("dequeue")
+     async def dequeue(self, limit: int = 1) -> Sequence[T_generic]:
+         if limit <= 0:
+             return []
+
+         collection = await self.ensure_collection()
+         results: list[T_generic] = []
+
+         # Atomic claim loop using find_one_and_update
+         for _ in range(limit):
+             async with self.tracking_context("dequeue.find_one_and_update", self.collection_name):
+                 doc = await collection.find_one_and_update(
+                     {
+                         "partition_id": self._partition_id,
+                         "consumed": False,
+                     },
+                     {"$set": {"consumed": True}},
+                     sort=[("_id", 1)],  # FIFO using insertion order
+                     return_document=ReturnDocument.AFTER,
+                     session=self._session,
+                 )
+             if doc is None:  # type: ignore
+                 # No more items to dequeue
+                 break
+
+             raw_value = doc["value"]
+             item = self._adapter.validate_python(raw_value)
+             results.append(item)
+
+         return results
+
+     @tracked("peek")
+     async def peek(self, limit: int = 1) -> Sequence[T_generic]:
+         if limit <= 0:
+             return []
+
+         collection = await self.ensure_collection()
+         async with self.tracking_context("peek.find", self.collection_name):
+             cursor = (
+                 collection.find(
+                     {
+                         "partition_id": self._partition_id,
+                         "consumed": False,
+                     },
+                     session=self._session,
+                 )
+                 .sort("_id", 1)
+                 .limit(limit)
+             )
+
+             items: list[T_generic] = []
+             async for doc in cursor:
+                 raw_value = doc["value"]
+                 items.append(self._adapter.validate_python(raw_value))
+
+         return items
+
+     @tracked("size")
+     async def size(self) -> int:
+         collection = await self.ensure_collection()
+         return await collection.count_documents(
+             {
+                 "partition_id": self._partition_id,
+                 "consumed": False,
+             },
+             session=self._session,
+         )
+
+
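Because each dequeue claims a document by atomically flipping consumed to True in find_one_and_update, two workers polling concurrently cannot receive the same item. A sketch:

    async def worker(queue: MongoBasedQueue[str]) -> None:
        await queue.enqueue(["rollout-1", "rollout-2"])
        batch = await queue.dequeue(limit=2)  # marks both documents consumed=True
        assert await queue.size() == 0        # size() counts only unconsumed items
        print(batch)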
+ class MongoBasedKeyValue(KeyValue[K, V], Generic[K, V]):
+     """Mongo-based implementation of KeyValue."""
+
+     def __init__(
+         self,
+         client_pool: MongoClientPool[Mapping[str, Any]],
+         database_name: str,
+         collection_name: str,
+         partition_id: str,
+         key_type: Type[K],
+         value_type: Type[V],
+         tracker: MetricsBackend | None = None,
+     ) -> None:
+         """
+         Args:
+             client_pool: The pool of MongoDB clients.
+             database_name: The name of the database.
+             collection_name: The name of the collection backing the key-value store.
+             partition_id: Partition identifier; allows multiple logical maps in one collection.
+             key_type: The Python type of keys (primitive or BaseModel).
+             value_type: The Python type of values (primitive or BaseModel).
+             tracker: The metrics tracker to use.
+         """
+         super().__init__(tracker=tracker)
+         self._client_pool = client_pool
+         self._database_name = database_name
+         self._collection_name = collection_name
+         self._partition_id = partition_id
+         self._key_type = key_type
+         self._value_type = value_type
+         self._key_adapter: TypeAdapter[K] = TypeAdapter(key_type)
+         self._value_adapter: TypeAdapter[V] = TypeAdapter(value_type)
+         self._collection_created = False
+
+         self._session: Optional[AsyncClientSession] = None
+
+     @property
+     def extra_tracking_labels(self) -> Mapping[str, str]:
+         return {
+             "database": self._database_name,
+         }
+
+     @property
+     def collection_name(self) -> str:
+         return self._collection_name
+
+     @tracked("ensure_collection")
+     async def ensure_collection(self, *, create_indexes: bool = True) -> AsyncCollection[Mapping[str, Any]]:
+         """Ensure the backing collection exists (and optionally its indexes)."""
+         if not self._collection_created:
+             client = await self._client_pool.get_client()
+             self._collection_created = await _ensure_collection(
+                 client[self._database_name], self._collection_name, primary_keys=["key"]
+             )
+         return await self._client_pool.get_collection(self._database_name, self._collection_name)
+
+     def with_session(self, session: AsyncClientSession) -> MongoBasedKeyValue[K, V]:
+         key_value = MongoBasedKeyValue(
+             client_pool=self._client_pool,
+             database_name=self._database_name,
+             collection_name=self._collection_name,
+             partition_id=self._partition_id,
+             key_type=self._key_type,
+             value_type=self._value_type,
+             tracker=self._tracker,
+         )
+         key_value._collection_created = self._collection_created
+         key_value._session = session
+
+         return key_value
+
+     @tracked("has")
+     async def has(self, key: K) -> bool:
+         collection = await self.ensure_collection()
+         encoded_key = self._key_adapter.dump_python(key, mode="python")
+         doc = await collection.find_one(
+             {
+                 "partition_id": self._partition_id,
+                 "key": encoded_key,
+             },
+             session=self._session,
+         )
+         return doc is not None
+
+     @tracked("get")
+     async def get(self, key: K, default: V | None = None) -> V | None:
+         collection = await self.ensure_collection()
+         encoded_key = self._key_adapter.dump_python(key, mode="python")
+         doc = await collection.find_one(
+             {
+                 "partition_id": self._partition_id,
+                 "key": encoded_key,
+             },
+             session=self._session,
+         )
+         if doc is None:
+             return default
+
+         raw_value = doc["value"]
+         return self._value_adapter.validate_python(raw_value)
+
+     @tracked("set")
+     async def set(self, key: K, value: V) -> None:
+         collection = await self.ensure_collection()
+         encoded_key = self._key_adapter.dump_python(key, mode="python")
+         encoded_value = self._value_adapter.dump_python(value, mode="python")
+         try:
+             async with self.tracking_context("set.replace_one", self.collection_name):
+                 await collection.replace_one(
+                     {
+                         "partition_id": self._partition_id,
+                         "key": encoded_key,
+                     },
+                     {
+                         "partition_id": self._partition_id,
+                         "key": encoded_key,
+                         "value": encoded_value,
+                     },
+                     upsert=True,
+                     session=self._session,
+                 )
+         except DuplicateKeyError as exc:
+             # Very unlikely with replace_one+upsert, but normalize anyway.
+             raise DuplicatedPrimaryKeyError("Duplicate key error while setting key-value item") from exc
+
+     @tracked("inc")
+     async def inc(self, key: K, amount: V) -> V:
+         assert ensure_numeric(amount, description="amount")
+         collection = await self.ensure_collection()
+         encoded_key = self._key_adapter.dump_python(key, mode="python")
+         encoded_amount = self._value_adapter.dump_python(amount, mode="python")
+         try:
+             async with self.tracking_context("inc.find_one_and_update", self.collection_name):
+                 doc = await collection.find_one_and_update(
+                     {
+                         "partition_id": self._partition_id,
+                         "key": encoded_key,
+                     },
+                     {
+                         "$inc": {"value": encoded_amount},
+                     },
+                     upsert=True,
+                     return_document=ReturnDocument.AFTER,
+                     session=self._session,
+                 )
+         except OperationFailure as exc:
+             if exc.code == 14 or "Cannot apply $inc" in str(exc):
+                 raise TypeError(f"value for key {key!r} is not numeric") from exc
+             raise
+         if doc is None:  # type: ignore
+             raise RuntimeError("Failed to increment value; MongoDB did not return a document")
+         raw_value = doc["value"]
+         return self._value_adapter.validate_python(raw_value)
+
+     @tracked("chmax")
+     async def chmax(self, key: K, value: V) -> V:
+         assert ensure_numeric(value, description="value")
+         collection = await self.ensure_collection()
+         encoded_key = self._key_adapter.dump_python(key, mode="python")
+         encoded_value = self._value_adapter.dump_python(value, mode="python")
+         try:
+             async with self.tracking_context("chmax.find_one_and_update", self.collection_name):
+                 doc = await collection.find_one_and_update(
+                     {
+                         "partition_id": self._partition_id,
+                         "key": encoded_key,
+                     },
+                     {
+                         "$max": {"value": encoded_value},
+                     },
+                     upsert=True,
+                     return_document=ReturnDocument.AFTER,
+                     session=self._session,
+                 )
+         except OperationFailure as exc:
+             if exc.code == 14 or "Cannot apply $max" in str(exc):
+                 raise TypeError(f"value for key {key!r} is not numeric") from exc
+             raise
+         if doc is None:  # type: ignore
+             raise RuntimeError("Failed to update value; MongoDB did not return a document")
+         raw_value = doc["value"]
+         return self._value_adapter.validate_python(raw_value)
+
+     @tracked("pop")
+     async def pop(self, key: K, default: V | None = None) -> V | None:
+         collection = await self.ensure_collection()
+         encoded_key = self._key_adapter.dump_python(key, mode="python")
+         doc = await collection.find_one_and_delete(
+             {
+                 "partition_id": self._partition_id,
+                 "key": encoded_key,
+             },
+             session=self._session,
+         )
+         if doc is None:  # type: ignore
+             return default
+
+         raw_value = doc["value"]
+         return self._value_adapter.validate_python(raw_value)
+
+     @tracked("size")
+     async def size(self) -> int:
+         collection = await self.ensure_collection()
+         return await collection.count_documents(
+             {
+                 "partition_id": self._partition_id,
+             },
+             session=self._session,
+         )
+
+
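The $inc/$max-with-upsert pattern makes this usable as a server-side atomic counter, which is how span_sequence_ids is used below. A sketch:

    async def next_sequence_id(kv: MongoBasedKeyValue[str, int], rollout_id: str) -> int:
        # First call upserts the document with value 1; later calls
        # increment atomically on the server and return the new value.
        return await kv.inc(rollout_id, 1)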
+ class MongoLightningCollections(LightningCollections):
+     """Mongo implementation of LightningCollections using MongoDB collections.
+
+     Serves as the storage base for [`MongoLightningStore`][mantisdk.store.mongo.MongoLightningStore].
+     """
+
+     def __init__(
+         self,
+         client_pool: MongoClientPool[Mapping[str, Any]],
+         database_name: str,
+         partition_id: str,
+         rollouts: Optional[MongoBasedCollection[Rollout]] = None,
+         attempts: Optional[MongoBasedCollection[Attempt]] = None,
+         spans: Optional[MongoBasedCollection[Span]] = None,
+         resources: Optional[MongoBasedCollection[ResourcesUpdate]] = None,
+         workers: Optional[MongoBasedCollection[Worker]] = None,
+         rollout_queue: Optional[MongoBasedQueue[str]] = None,
+         span_sequence_ids: Optional[MongoBasedKeyValue[str, int]] = None,
+         tracker: MetricsBackend | None = None,
+     ):
+         super().__init__(tracker=tracker, extra_labels=["database"])
+         self._client_pool = client_pool
+         self._database_name = database_name
+         self._partition_id = partition_id
+         self._collection_ensured = False
+         self._lock = aiologic.Lock()  # used for generic atomic operations like scan debounce seconds
+         self._rollouts = (
+             rollouts
+             if rollouts is not None
+             else MongoBasedCollection(
+                 self._client_pool,
+                 self._database_name,
+                 "rollouts",
+                 self._partition_id,
+                 ["rollout_id"],
+                 Rollout,
+                 [["status"]],
+                 tracker=self._tracker,
+             )
+         )
+         self._attempts = (
+             attempts
+             if attempts is not None
+             else MongoBasedCollection(
+                 self._client_pool,
+                 self._database_name,
+                 "attempts",
+                 self._partition_id,
+                 ["rollout_id", "attempt_id"],
+                 Attempt,
+                 [["status"], ["sequence_id"]],
+                 tracker=self._tracker,
+             )
+         )
+         self._spans = (
+             spans
+             if spans is not None
+             else MongoBasedCollection(
+                 self._client_pool,
+                 self._database_name,
+                 "spans",
+                 self._partition_id,
+                 ["rollout_id", "attempt_id", "span_id"],
+                 Span,
+                 [["sequence_id"]],
+                 tracker=self._tracker,
+             )
+         )
+         self._resources = (
+             resources
+             if resources is not None
+             else MongoBasedCollection(
+                 self._client_pool,
+                 self._database_name,
+                 "resources",
+                 self._partition_id,
+                 ["resources_id"],
+                 ResourcesUpdate,
+                 [["update_time"]],
+                 tracker=self._tracker,
+             )
+         )
+         self._workers = (
+             workers
+             if workers is not None
+             else MongoBasedCollection(
+                 self._client_pool,
+                 self._database_name,
+                 "workers",
+                 self._partition_id,
+                 ["worker_id"],
+                 Worker,
+                 [["status"]],
+                 tracker=self._tracker,
+             )
+         )
+         self._rollout_queue = (
+             rollout_queue
+             if rollout_queue is not None
+             else MongoBasedQueue(
+                 self._client_pool,
+                 self._database_name,
+                 "rollout_queue",
+                 self._partition_id,
+                 str,
+                 tracker=self._tracker,
+             )
+         )
+         self._span_sequence_ids = (
+             span_sequence_ids
+             if span_sequence_ids is not None
+             else MongoBasedKeyValue(
+                 self._client_pool,
+                 self._database_name,
+                 "span_sequence_ids",
+                 self._partition_id,
+                 str,
+                 int,
+                 tracker=self._tracker,
+             )
+         )
+
+     @property
+     def collection_name(self) -> str:
+         return "router"  # Special collection name for tracking transactions
+
+     @property
+     def extra_tracking_labels(self) -> Mapping[str, str]:
+         return {
+             "database": self._database_name,
+         }
+
+     def with_session(self, session: AsyncClientSession) -> Self:
+         instance = self.__class__(
+             client_pool=self._client_pool,
+             database_name=self._database_name,
+             partition_id=self._partition_id,
+             rollouts=self._rollouts.with_session(session),
+             attempts=self._attempts.with_session(session),
+             spans=self._spans.with_session(session),
+             resources=self._resources.with_session(session),
+             workers=self._workers.with_session(session),
+             rollout_queue=self._rollout_queue.with_session(session),
+             span_sequence_ids=self._span_sequence_ids.with_session(session),
+             tracker=self._tracker,
+         )
+         instance._collection_ensured = self._collection_ensured
+         return instance
+
+     @property
+     def rollouts(self) -> MongoBasedCollection[Rollout]:
+         return self._rollouts
+
+     @property
+     def attempts(self) -> MongoBasedCollection[Attempt]:
+         return self._attempts
+
+     @property
+     def spans(self) -> MongoBasedCollection[Span]:
+         return self._spans
+
+     @property
+     def resources(self) -> MongoBasedCollection[ResourcesUpdate]:
+         return self._resources
+
+     @property
+     def workers(self) -> MongoBasedCollection[Worker]:
+         return self._workers
+
+     @property
+     def rollout_queue(self) -> MongoBasedQueue[str]:
+         return self._rollout_queue
+
+     @property
+     def span_sequence_ids(self) -> MongoBasedKeyValue[str, int]:
+         return self._span_sequence_ids
+
+     @tracked("ensure_collections")
+     async def _ensure_collections(self) -> None:
+         """Ensure all collections exist."""
+         if self._collection_ensured:
+             return
+         await self._rollouts.ensure_collection()
+         await self._attempts.ensure_collection()
+         await self._spans.ensure_collection()
+         await self._resources.ensure_collection()
+         await self._workers.ensure_collection()
+         await self._rollout_queue.ensure_collection()
+         await self._span_sequence_ids.ensure_collection()
+         self._collection_ensured = True
+
+     @asynccontextmanager
+     async def _lock_manager(self, labels: Optional[Sequence[AtomicLabels]]):
+         if labels is None or "generic" not in labels:
+             yield
+         else:
+             # Only lock the generic label.
+             try:
+                 async with self.tracking_context("lock", self.collection_name):
+                     await self._lock.async_acquire()
+                 yield
+             finally:
+                 self._lock.async_release()
+
+     @asynccontextmanager
+     async def atomic(
+         self,
+         mode: AtomicMode = "rw",
+         snapshot: bool = False,
+         commit: bool = False,
+         labels: Optional[Sequence[AtomicLabels]] = None,
+         *args: Any,
+         **kwargs: Any,
+     ):
+         """Perform an atomic operation on the collections."""
+         if commit:
+             raise ValueError("Commit should be used with execute() instead.")
+         async with self._lock_manager(labels):
+             async with self.tracking_context("atomic", self.collection_name):
+                 # First step: ensure all collections exist before going into the atomic block
+                 if not self._collection_ensured:
+                     await self._ensure_collections()
+                 # Execute directly without commit
+                 yield self
+
+     @tracked("execute")
+     async def execute(
+         self,
+         callback: Callable[[Self], Awaitable[T_generic]],
+         *,
+         mode: AtomicMode = "rw",
+         snapshot: bool = False,
+         commit: bool = False,
+         labels: Optional[Sequence[AtomicLabels]] = None,
+         **kwargs: Any,
+     ) -> T_generic:
+         """Execute the given callback within an atomic operation, with retries on transient errors."""
+         if not self._collection_ensured:
+             await self._ensure_collections()
+         client = await self._client_pool.get_client()
+
+         # If commit is not turned on, just execute the callback directly.
+         if not commit:
+             async with self._lock_manager(labels):
+                 return await callback(self)
+
+         # If snapshot is enabled, use snapshot read concern.
+         read_concern = ReadConcern("snapshot") if snapshot else ReadConcern("local")
+         # If mode is "r", write_concern is not needed.
+         write_concern = WriteConcern("majority") if mode != "r" else None
+
+         async with client.start_session() as session:
+             collections = self.with_session(session)
+             try:
+                 async with self._lock_manager(labels):
+                     return await self.with_transaction(session, collections, callback, read_concern, write_concern)
+             except (ConnectionFailure, OperationFailure) as exc:
+                 # Non-retryable errors.
+                 raise RuntimeError("Transaction failed with connection or operation error") from exc
+
+     @tracked("with_transaction")
+     async def with_transaction(
+         self,
+         session: AsyncClientSession,
+         collections: Self,
+         callback: Callable[[Self], Awaitable[T_generic]],
+         read_concern: ReadConcern,
+         write_concern: Optional[WriteConcern],
+     ) -> T_generic:
+         # This will start a transaction, run the transaction callback, and commit.
+         # It will also transparently retry on some transient errors.
+         # Expanded implementation of with_transaction from client_session.
+         read_preference = ReadPreference.PRIMARY
+         transaction_retry_time_limit = 120
+         start_time = time.monotonic()
+
+         def _within_time_limit() -> bool:
+             return time.monotonic() - start_time < transaction_retry_time_limit
+
+         async def _jitter_before_retry() -> None:
+             async with self.tracking_context("execute.jitter", self.collection_name):
+                 await asyncio.sleep(random.uniform(0, 0.05))
+
+         while True:
+             await session.start_transaction(read_concern, write_concern, read_preference)
+
+             try:
+                 # The _session is always the same within one transaction,
+                 # so we can use the same collections object.
+                 async with self.tracking_context("execute.callback", self.collection_name):
+                     ret = await callback(collections)
+             # Catch KeyboardInterrupt, CancelledError, etc. and clean up.
+             except BaseException as exc:
+                 if session.in_transaction:
+                     await session.abort_transaction()
+                 if (
+                     isinstance(exc, PyMongoError)
+                     and exc.has_error_label("TransientTransactionError")
+                     and _within_time_limit()
+                 ):
+                     # Retry the entire transaction.
+                     await _jitter_before_retry()
+                     continue
+                 raise
+
+             if not session.in_transaction:
+                 # Assume callback intentionally ended the transaction.
+                 return ret
+
+             # Tracks the commit operation.
+             async with self.tracking_context("execute.commit", self.collection_name):
+                 # Loop until the commit succeeds or we hit the time limit.
+                 while True:
+                     # Tracks the commit attempt.
+                     try:
+                         async with self.tracking_context("execute.commit_once", self.collection_name):
+                             await session.commit_transaction()
+                     except PyMongoError as exc:
+                         if (
+                             exc.has_error_label("UnknownTransactionCommitResult")
+                             and _within_time_limit()
+                             and not (isinstance(exc, OperationFailure) and exc.code == 50)  # max_time_expired_error
+                         ):
+                             # Retry the commit.
+                             await _jitter_before_retry()
+                             continue
+
+                         if exc.has_error_label("TransientTransactionError") and _within_time_limit():
+                             # Retry the entire transaction.
+                             await _jitter_before_retry()
+                             break
+                         raise
+
+                     # Commit succeeded.
+                     return ret
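Finally, a sketch of driving the transactional path from a caller (the callback shape follows execute() above; claim_next_rollout is a hypothetical helper, not part of the package):

    async def claim_next_rollout(cols: MongoLightningCollections) -> str | None:
        async def _txn(c: MongoLightningCollections) -> str | None:
            ids = await c.rollout_queue.dequeue(limit=1)
            return ids[0] if ids else None

        # commit=True runs _txn inside a MongoDB transaction; with_transaction()
        # retries it transparently on TransientTransactionError.
        return await cols.execute(_txn, commit=True)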