mantisdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/store/collection/mongo.py
@@ -0,0 +1,1412 @@
# Copyright (c) Microsoft. All rights reserved.

from __future__ import annotations

import asyncio
import logging
import random
import re
import time
from contextlib import asynccontextmanager
from datetime import datetime
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Dict,
    Generic,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    cast,
)

import aiologic

from mantisdk.utils.metrics import MetricsBackend

if TYPE_CHECKING:
    from typing import Self

from pydantic import BaseModel, TypeAdapter
from pymongo import AsyncMongoClient, ReadPreference, ReturnDocument, WriteConcern
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.errors import (
    BulkWriteError,
    CollectionInvalid,
    ConnectionFailure,
    DuplicateKeyError,
    OperationFailure,
    PyMongoError,
)
from pymongo.read_concern import ReadConcern

from mantisdk.types import (
    Attempt,
    FilterOptions,
    PaginatedResult,
    ResourcesUpdate,
    Rollout,
    SortOptions,
    Span,
    Worker,
)

from .base import (
    AtomicLabels,
    AtomicMode,
    Collection,
    DuplicatedPrimaryKeyError,
    KeyValue,
    LightningCollections,
    Queue,
    ensure_numeric,
    normalize_filter_options,
    resolve_sort_options,
    tracked,
)

T_model = TypeVar("T_model", bound=BaseModel)

T_generic = TypeVar("T_generic")

T_mapping = TypeVar("T_mapping", bound=Mapping[str, Any])

T_callable = TypeVar("T_callable", bound=Callable[..., Any])

K = TypeVar("K")
V = TypeVar("V")

logger = logging.getLogger(__name__)

def resolve_mongo_error_type(exc: BaseException | None) -> str | None:
    is_transient = isinstance(exc, PyMongoError) and exc.has_error_label("TransientTransactionError")
    if isinstance(exc, OperationFailure):
        if is_transient:
            return f"OperationFailure-{exc.code}-Transient"
        else:
            return f"OperationFailure-{exc.code}"
    if isinstance(exc, DuplicateKeyError):
        return "DuplicateKeyError-Transient" if is_transient else "DuplicateKeyError"
    if isinstance(exc, PyMongoError):
        if is_transient:
            return f"{exc.__class__.__name__}-Transient"
        else:
            return exc.__class__.__name__
    if isinstance(exc, ConnectionFailure):
        return "ConnectionFailure-Transient" if is_transient else "ConnectionFailure"
    if is_transient:
        return "Other-Transient"
    else:
        return None

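For reference, an illustrative sketch of the labels this helper yields. The values follow from the branches above; note that `ConnectionFailure` is matched by the `PyMongoError` branch because it subclasses it.

from pymongo.errors import ConnectionFailure, OperationFailure

print(resolve_mongo_error_type(OperationFailure("boom", code=112)))  # "OperationFailure-112"
print(resolve_mongo_error_type(ConnectionFailure("server down")))    # "ConnectionFailure" (via the PyMongoError branch)
print(resolve_mongo_error_type(None))                                # None
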
def _field_ops_to_conditions(field: str, ops: Mapping[str, Any]) -> List[Dict[str, Any]]:
    """Convert a FilterField (ops) into one or more Mongo conditions."""
    conditions: List[Dict[str, Any]] = []

    for op_name, raw_value in ops.items():
        if op_name == "exact":
            if raw_value is None:
                logger.debug(f"Skipping exact filter for field '{field}' with None value")
                continue
            conditions.append({field: raw_value})
        elif op_name == "within":
            if raw_value is None:
                logger.debug(f"Skipping within filter for field '{field}' with None value")
                continue
            try:
                iterable = list(raw_value)
            except TypeError as exc:
                raise ValueError(f"Invalid iterable for within filter for field '{field}': {raw_value!r}") from exc
            conditions.append({field: {"$in": iterable}})
        elif op_name == "contains":
            if raw_value is None:
                logger.debug(f"Skipping contains filter for field '{field}' with None value")
                continue
            value = str(raw_value)
            pattern = f".*{re.escape(value)}.*"
            conditions.append({field: {"$regex": pattern, "$options": "i"}})
        else:
            raise ValueError(f"Unsupported filter operator '{op_name}' for field '{field}'")

    return conditions

def _build_mongo_filter(filter_options: Optional[FilterOptions]) -> Dict[str, Any]:
    """Translate FilterOptions into a MongoDB filter dict."""
    normalized, must_filters, aggregate = normalize_filter_options(filter_options)

    regular_conditions: List[Dict[str, Any]] = []
    must_conditions: List[Dict[str, Any]] = []

    # Normal filters
    if normalized:
        for field_name, ops in normalized.items():
            regular_conditions.extend(_field_ops_to_conditions(field_name, ops))

    # Must filters
    if must_filters:
        for field_name, ops in must_filters.items():
            must_conditions.extend(_field_ops_to_conditions(field_name, ops))

    # No filters at all
    if not regular_conditions and not must_conditions:
        return {}

    # Aggregate logic for regular conditions; _must always ANDs in.
    if aggregate == "and":
        all_conds = regular_conditions + must_conditions
        if len(all_conds) == 1:
            return all_conds[0]
        return {"$and": all_conds}

    # aggregate == "or"
    if regular_conditions and must_conditions:
        # (OR of regular) AND (all must)
        if len(regular_conditions) == 1:
            or_part: Dict[str, Any] = regular_conditions[0]
        else:
            or_part = {"$or": regular_conditions}

        and_parts: List[Dict[str, Any]] = [or_part] + must_conditions
        if len(and_parts) == 1:
            return and_parts[0]
        return {"$and": and_parts}

    if regular_conditions:
        if len(regular_conditions) == 1:
            return regular_conditions[0]
        return {"$or": regular_conditions}

    # Only must conditions
    if len(must_conditions) == 1:
        return must_conditions[0]
    return {"$and": must_conditions}

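To make the translation concrete, a small sketch of what the two helpers emit. The first two calls follow directly from the operator branches above; the last one assumes `normalize_filter_options` (defined in `.base`, not shown in this diff) reports an "and" aggregate for this input.

# Per-field operators -> Mongo conditions (directly from the branches above).
_field_ops_to_conditions("status", {"within": ["queued", "running"]})
# -> [{"status": {"$in": ["queued", "running"]}}]
_field_ops_to_conditions("name", {"contains": "agent"})
# -> [{"name": {"$regex": ".*agent.*", "$options": "i"}}]

# Whole-filter translation; assuming this normalizes to an "and" aggregate,
# the regular and "_must" conditions are combined into a single "$and".
_build_mongo_filter({"status": {"exact": "running"}, "_must": {"partition_id": {"exact": "p1"}}})
# -> {"$and": [{"status": "running"}, {"partition_id": "p1"}]}
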
async def _ensure_collection(
    db: AsyncDatabase[Mapping[str, Any]],
    collection_name: str,
    primary_keys: Optional[Sequence[str]] = None,
    extra_indexes: Optional[Sequence[Sequence[str]]] = None,
) -> bool:
    """Ensure the backing MongoDB collection exists.

    This method is idempotent and safe to call multiple times.
    """
    # Create collection if it doesn't exist yet
    try:
        await db.create_collection(collection_name)
    except CollectionInvalid as exc:
        # Thrown if collection already exists
        logger.debug(f"Collection '{collection_name}' may have already existed. No need to create it: {exc!r}")
    except OperationFailure as exc:
        logger.debug(f"Failed to create collection '{collection_name}'. Probably already exists: {exc!r}")
        # Some servers use OperationFailure w/ specific codes for "NamespaceExists"
        if exc.code in (48, 68):  # 48: NamespaceExists, 68: already exists on older versions
            pass
        else:
            raise

    # Optionally create a unique index on primary keys (scoped by partition_id)
    if primary_keys:
        # Always include the partition field in the unique index.
        keys = [("partition_id", 1)] + [(pk, 1) for pk in primary_keys]
        try:
            await db[collection_name].create_index(keys, name=f"uniq_partition_{'_'.join(primary_keys)}", unique=True)
        except OperationFailure as exc:
            logger.debug(f"Index for collection '{collection_name}' already exists. No need to create it: {exc!r}")
            # Ignore "index already exists" type errors
            if exc.code in (68, 85):  # IndexOptionsConflict, etc.
                pass
            else:
                raise

    # Optionally create extra indexes
    if extra_indexes:
        for index in extra_indexes:
            try:
                await db[collection_name].create_index(index, name=f"idx_{'_'.join(index)}")
            except OperationFailure as exc:
                logger.debug(f"Index for collection '{collection_name}' already exists. No need to create it: {exc!r}")
                # Ignore "index already exists" type errors
                if exc.code in (68, 85):  # IndexOptionsConflict, etc.
                    pass
                else:
                    raise

    return True

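As an illustration of the indexes this helper creates, a hypothetical call mirroring the `rollouts` collection configured further down (primary key `rollout_id`, extra index on `status`); `pool` and the database name are placeholders.

async def example_ensure_rollouts(pool: MongoClientPool[Mapping[str, Any]]) -> None:
    db = (await pool.get_client())["example_db"]  # placeholder database name
    await _ensure_collection(db, "rollouts", primary_keys=["rollout_id"], extra_indexes=[["status"]])
    # Creates, if missing:
    #   a unique index "uniq_partition_rollout_id" on [("partition_id", 1), ("rollout_id", 1)]
    #   a secondary index "idx_status" on ["status"]
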
class MongoClientPool(Generic[T_mapping]):
    """A pool of MongoDB clients, each bound to a specific event loop.

    The pool lazily creates `AsyncMongoClient` instances per event loop using the provided
    connection parameters, ensuring we never try to reuse a client across loops.
    """

    def __init__(self, *, mongo_uri: str, mongo_client_kwargs: Mapping[str, Any] | None = None):
        self._get_collection_lock = aiologic.Lock()
        self._get_client_lock = aiologic.Lock()
        self._mongo_uri = mongo_uri
        self._mongo_client_kwargs = dict(mongo_client_kwargs or {})
        self._client_pool: Dict[int, AsyncMongoClient[T_mapping]] = {}
        self._collection_pool: Dict[Tuple[int, str, str], AsyncCollection[T_mapping]] = {}

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: Any) -> None:
        await self.close()

    async def close(self) -> None:
        """Close all clients currently tracked by the pool."""

        async with self._get_client_lock, self._get_collection_lock:
            clients = list(self._client_pool.values())
            self._client_pool.clear()
            self._collection_pool.clear()

        for client in clients:
            try:
                await client.close()
            except Exception:
                logger.exception("Error closing MongoDB client: %s", client)

    async def get_client(self) -> AsyncMongoClient[T_mapping]:
        loop = asyncio.get_running_loop()
        key = id(loop)

        # If there is already a client specifically for this loop, return it.
        existing = self._client_pool.get(key)
        if existing is not None:
            await existing.aconnect()  # This actually does nothing if the client is already connected.
            return existing

        async with self._get_client_lock:
            # Another coroutine may have already created the client.
            if key in self._client_pool:
                await self._client_pool[key].aconnect()
                return self._client_pool[key]

            # Create a new client for this loop.
            client = AsyncMongoClient[T_mapping](self._mongo_uri, **self._mongo_client_kwargs)
            await client.aconnect()
            self._client_pool[key] = client
            return client

    async def get_collection(self, database_name: str, collection_name: str) -> AsyncCollection[T_mapping]:
        loop = asyncio.get_running_loop()
        key = (id(loop), database_name, collection_name)
        if key in self._collection_pool:
            return self._collection_pool[key]

        async with self._get_collection_lock:
            # Another coroutine may have already created the collection.
            if key in self._collection_pool:
                return self._collection_pool[key]

            # Create a new collection for this loop.
            client = await self.get_client()
            collection = client[database_name][collection_name]
            self._collection_pool.setdefault(key, collection)
            return collection

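A minimal usage sketch for the pool above; the URI and the database/collection names are placeholders.

async def example_pool_usage() -> None:
    # One pool can be shared across event loops; each loop lazily gets its own client.
    pool = MongoClientPool(mongo_uri="mongodb://localhost:27017")  # placeholder URI
    async with pool:
        client = await pool.get_client()       # client bound to the running loop
        same_client = await pool.get_client()  # cached: same object on the same loop
        assert client is same_client
        coll = await pool.get_collection("example_db", "example_collection")
        print(await coll.count_documents({}))
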
class MongoBasedCollection(Collection[T_model]):
    """Mongo-based implementation of Collection.

    Args:
        client_pool: The pool of MongoDB clients.
        database_name: The name of the database.
        collection_name: The name of the collection.
        partition_id: The partition ID. Used to partition the collection into multiple collections.
        primary_keys: The primary keys of the collection.
        item_type: The type of the items in the collection.
        extra_indexes: The extra indexes to create on the collection.
        tracker: The metrics tracker to use.
    """

    def __init__(
        self,
        client_pool: MongoClientPool[Mapping[str, Any]],
        database_name: str,
        collection_name: str,
        partition_id: str,
        primary_keys: Sequence[str],
        item_type: Type[T_model],
        extra_indexes: Sequence[Sequence[str]] = [],
        tracker: MetricsBackend | None = None,
    ):
        super().__init__(tracker=tracker)
        self._client_pool = client_pool
        self._database_name = database_name
        self._collection_name = collection_name
        self._partition_id = partition_id
        self._collection_created = False
        self._extra_indexes = [list(index) for index in extra_indexes]
        self._session: Optional[AsyncClientSession] = None

        if not primary_keys:
            raise ValueError("primary_keys must be non-empty")
        self._primary_keys = list(primary_keys)

        if not issubclass(item_type, BaseModel):  # type: ignore
            raise ValueError(f"item_type must be a subclass of BaseModel, got {item_type.__name__}")
        self._item_type = item_type

    @property
    def collection_name(self) -> str:
        return self._collection_name

    @property
    def extra_tracking_labels(self) -> Mapping[str, str]:
        return {
            "database": self._database_name,
        }

    @tracked("ensure_collection")
    async def ensure_collection(self) -> AsyncCollection[Mapping[str, Any]]:
        """Ensure the backing MongoDB collection exists (and optionally its indexes).

        This method is idempotent and safe to call multiple times.

        It will also create a unique index across the configured primary key fields.
        """
        if not self._collection_created:
            client = await self._client_pool.get_client()
            self._collection_created = await _ensure_collection(
                client[self._database_name], self._collection_name, self._primary_keys, self._extra_indexes
            )

        return await self._client_pool.get_collection(self._database_name, self._collection_name)

    def with_session(self, session: AsyncClientSession) -> MongoBasedCollection[T_model]:
        """Create a new collection with the same configuration but a new session."""
        collection = MongoBasedCollection(
            client_pool=self._client_pool,
            database_name=self._database_name,
            collection_name=self._collection_name,
            partition_id=self._partition_id,
            primary_keys=self._primary_keys,
            item_type=self._item_type,
            extra_indexes=self._extra_indexes,
            tracker=self._tracker,
        )
        collection._collection_created = self._collection_created
        collection._session = session
        return collection

    def primary_keys(self) -> Sequence[str]:
        """Return the primary key field names for this collection."""
        return self._primary_keys

    def item_type(self) -> Type[T_model]:
        return self._item_type

    @tracked("size")
    async def size(self) -> int:
        collection = await self.ensure_collection()
        return await collection.count_documents({"partition_id": self._partition_id}, session=self._session)

    def _pk_filter(self, item: T_model) -> Dict[str, Any]:
        """Build a Mongo filter for the primary key(s) of a model instance."""
        data = item.model_dump()
        missing = [pk for pk in self._primary_keys if pk not in data]
        if missing:
            raise ValueError(f"Missing primary key fields {missing} on item {item!r}")
        pk_filter: Dict[str, Any] = {"partition_id": self._partition_id}
        pk_filter.update({pk: data[pk] for pk in self._primary_keys})
        return pk_filter

    def _render_pk_values(self, values: Sequence[Any]) -> str:
        return ", ".join(f"{pk}={value!r}" for pk, value in zip(self._primary_keys, values))

    def _ensure_item_type(self, item: T_model) -> None:
        if not isinstance(item, self._item_type):
            raise TypeError(f"Expected item of type {self._item_type.__name__}, got {type(item).__name__}")

    def _inject_partition_filter(self, filter: Optional[FilterOptions]) -> Dict[str, Any]:
        """Ensure every query is scoped to this collection's partition."""
        combined: Dict[str, Any]
        if filter is None:
            combined = {}
        else:
            combined = dict(filter)

        partition_must = {"partition_id": {"exact": self._partition_id}}
        existing_must = combined.get("_must")
        if existing_must is None:
            combined["_must"] = partition_must
            return combined

        if isinstance(existing_must, Mapping):
            combined["_must"] = [existing_must, partition_must]
        elif isinstance(existing_must, Sequence) and not isinstance(existing_must, (str, bytes)):
            combined["_must"] = [*existing_must, partition_must]
        else:
            raise TypeError("`_must` filters must be a mapping or sequence of mappings")

        return combined

    def _model_validate_item(self, raw: Mapping[str, Any]) -> T_model:
        item_type_has_id = "_id" in self._item_type.model_fields
        # Remove _id from the raw document if the item type does not have it.
        if not item_type_has_id:
            raw = {k: v for k, v in raw.items() if k != "_id"}
        # Convert Mongo document to Pydantic model
        return self._item_type.model_validate(raw)  # type: ignore[arg-type]

    @tracked("query")
    async def query(
        self,
        filter: Optional[FilterOptions] = None,
        sort: Optional[SortOptions] = None,
        limit: int = -1,
        offset: int = 0,
    ) -> PaginatedResult[T_model]:
        """Mongo-based implementation of Collection.query.

        The handling of null-values in sorting is different from memory-based implementation.
        In MongoDB, null values are treated as less than non-null values.
        """
        collection = await self.ensure_collection()

        combined = self._inject_partition_filter(filter)
        mongo_filter = _build_mongo_filter(cast(FilterOptions, combined))

        total = await collection.count_documents(mongo_filter, session=self._session)

        if limit == 0:
            return PaginatedResult[T_model](items=[], limit=0, offset=offset, total=total)

        cursor = collection.find(mongo_filter, session=self._session)

        sort_name, sort_order = resolve_sort_options(sort)
        if sort_name is not None:
            model_fields = getattr(self._item_type, "model_fields", {})
            if sort_name not in model_fields:
                raise ValueError(
                    f"Failed to sort items by '{sort_name}': field does not exist on {self._item_type.__name__}"
                )
            direction = 1 if sort_order == "asc" else -1
            cursor = cursor.sort(sort_name, direction)

        if offset > 0:
            cursor = cursor.skip(offset)
        if limit >= 0:
            cursor = cursor.limit(limit)

        items: List[T_model] = []
        async for raw in cursor:
            items.append(self._model_validate_item(raw))

        return PaginatedResult[T_model](items=items, limit=limit, offset=offset, total=total)

    @tracked("get")
    async def get(
        self,
        filter: Optional[FilterOptions] = None,
        sort: Optional[SortOptions] = None,
    ) -> Optional[T_model]:
        collection = await self.ensure_collection()

        combined = self._inject_partition_filter(filter)
        mongo_filter = _build_mongo_filter(cast(FilterOptions, combined))

        sort_name, sort_order = resolve_sort_options(sort)
        mongo_sort: Optional[List[Tuple[str, int]]] = None
        if sort_name is not None:
            model_fields = getattr(self._item_type, "model_fields", {})
            if sort_name not in model_fields:
                raise ValueError(
                    f"Failed to sort items by '{sort_name}': field does not exist on {self._item_type.__name__}"
                )
            direction = 1 if sort_order == "asc" else -1
            mongo_sort = [(sort_name, direction)]

        raw = await collection.find_one(mongo_filter, sort=mongo_sort, session=self._session)

        if raw is None:
            return None

        return self._model_validate_item(raw)

    @tracked("insert")
    async def insert(self, items: Sequence[T_model]) -> None:
        """Insert items into the collection.

        The implementation does NOT do checks for duplicate primary keys,
        neither within the same insert call nor across different insert calls.
        It relies on the database to enforce uniqueness via indexes.
        """
        if not items:
            return

        collection = await self.ensure_collection()
        docs: List[Mapping[str, Any]] = []
        for item in items:
            self._ensure_item_type(item)
            doc = item.model_dump()
            doc["partition_id"] = self._partition_id
            docs.append(doc)

        if not docs:
            return

        try:
            async with self.tracking_context("insert.insert_many", self._collection_name):
                await collection.insert_many(docs, session=self._session)
        except DuplicateKeyError as exc:
            # In case the DB enforces uniqueness via index, normalize to ValueError
            raise DuplicatedPrimaryKeyError("Duplicated primary key(s) while inserting items") from exc
        except BulkWriteError as exc:
            write_errors = exc.details.get("writeErrors", [])
            if write_errors and write_errors[0].get("code") == 11000:
                raise DuplicatedPrimaryKeyError("Duplicated primary key(s) while inserting items") from exc
            raise

    @tracked("update")
    async def update(self, items: Sequence[T_model], update_fields: Sequence[str] | None = None) -> List[T_model]:
        if not items:
            return []

        updated_items: List[T_model] = []
        collection = await self.ensure_collection()

        for item in items:
            self._ensure_item_type(item)
            pk_filter = self._pk_filter(item)
            doc = item.model_dump()
            doc["partition_id"] = self._partition_id

            updated_doc = None

            # Branch 1: Full Replace
            if update_fields is None:
                async with self.tracking_context("update.find_one_and_replace", self._collection_name):
                    updated_doc = await collection.find_one_and_replace(
                        filter=pk_filter,
                        replacement=doc,
                        session=self._session,
                        return_document=ReturnDocument.AFTER,  # Returns the new version
                    )

            # Branch 2: Partial Update
            else:
                update_doc = {field: doc[field] for field in update_fields if field in doc}
                async with self.tracking_context("update.find_one_and_update", self._collection_name):
                    updated_doc = await collection.find_one_and_update(
                        filter=pk_filter,
                        update={"$set": update_doc},
                        session=self._session,
                        return_document=ReturnDocument.AFTER,  # Returns the new version
                    )

            # Validation and Reconstruction
            if updated_doc is None:  # type: ignore
                raise ValueError(f"Item with primary key(s) {pk_filter} does not exist")

            # Re-instantiate the model from the raw MongoDB dictionary.
            new_item = self._model_validate_item(updated_doc)
            updated_items.append(new_item)

        return updated_items

    @tracked("upsert")
    async def upsert(self, items: Sequence[T_model], update_fields: Sequence[str] | None = None) -> List[T_model]:
        if not items:
            return []

        upserted_items: List[T_model] = []
        collection = await self.ensure_collection()

        for item in items:
            self._ensure_item_type(item)
            pk_filter = self._pk_filter(item)

            insert_doc = item.model_dump()
            insert_doc["partition_id"] = self._partition_id

            # If update_fields is None, we update ALL fields (standard upsert behavior).
            # Otherwise, we only update specific fields, but insert the full doc if it's new.
            target_fields = update_fields if update_fields is not None else list(insert_doc.keys())

            # 1. $set: Fields that should be overwritten if the document exists
            update_subset = {field: insert_doc[field] for field in target_fields if field in insert_doc}

            # 2. $setOnInsert: Fields that are only set if we are creating a NEW document
            #    (Everything in the model that isn't in the update_subset)
            set_on_insert = {k: v for k, v in insert_doc.items() if k not in update_subset}

            update_spec: Dict[str, Dict[str, Any]] = {}
            if set_on_insert:
                update_spec["$setOnInsert"] = set_on_insert
            if update_subset:
                update_spec["$set"] = update_subset

            async with self.tracking_context("upsert.find_one_and_update", self._collection_name):
                result_doc = await collection.find_one_and_update(
                    filter=pk_filter,
                    update=update_spec,
                    upsert=True,
                    session=self._session,
                    return_document=ReturnDocument.AFTER,
                )

            # Because upsert=True, result_doc is guaranteed to be not None
            new_item = self._model_validate_item(result_doc)
            upserted_items.append(new_item)

        return upserted_items

    @tracked("delete")
    async def delete(self, items: Sequence[T_model]) -> None:
        if not items:
            return

        collection = await self.ensure_collection()
        for item in items:
            self._ensure_item_type(item)
            pk_filter = self._pk_filter(item)
            async with self.tracking_context("delete.delete_one", self._collection_name):
                result = await collection.delete_one(pk_filter, session=self._session)
            if result.deleted_count == 0:
                raise ValueError(f"Item with primary key(s) {pk_filter} does not exist")

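An illustrative sketch of the collection in use; `TaskRecord`, the names, and the plain-dict filter shape (mirroring the `{"exact": ...}` partition filter injected above) are assumptions for the example.

class TaskRecord(BaseModel):
    task_id: str
    status: str = "queued"

async def example_collection_usage(pool: MongoClientPool[Mapping[str, Any]]) -> None:
    tasks = MongoBasedCollection(
        client_pool=pool,
        database_name="example_db",
        collection_name="tasks",
        partition_id="run-001",
        primary_keys=["task_id"],
        item_type=TaskRecord,
    )
    await tasks.insert([TaskRecord(task_id="t1"), TaskRecord(task_id="t2")])
    # Partial update: only "status" is $set, keyed by the unique (partition_id, task_id) index.
    await tasks.update([TaskRecord(task_id="t1", status="running")], update_fields=["status"])
    page = await tasks.query(filter={"status": {"exact": "running"}}, limit=10)
    print(page.total, [t.task_id for t in page.items])
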
class MongoBasedQueue(Queue[T_generic], Generic[T_generic]):
    """Mongo-based implementation of Queue backed by a MongoDB collection.

    Items are stored append-only; dequeue marks items as consumed instead of deleting them.
    """

    def __init__(
        self,
        client_pool: MongoClientPool[Mapping[str, Any]],
        database_name: str,
        collection_name: str,
        partition_id: str,
        item_type: Type[T_generic],
        tracker: MetricsBackend | None = None,
    ) -> None:
        """
        Args:
            client_pool: The pool of MongoDB clients.
            database_name: The name of the database.
            collection_name: The name of the collection backing the queue.
            partition_id: Partition identifier; allows multiple logical queues in one collection.
            item_type: The Python type of queue items (primitive or BaseModel subclass).
        """
        super().__init__(tracker=tracker)
        self._client_pool = client_pool
        self._database_name = database_name
        self._collection_name = collection_name
        self._partition_id = partition_id
        self._item_type = item_type
        self._adapter: TypeAdapter[T_generic] = TypeAdapter(item_type)
        self._collection_created = False

        self._session: Optional[AsyncClientSession] = None

    def item_type(self) -> Type[T_generic]:
        return self._item_type

    @property
    def extra_tracking_labels(self) -> Mapping[str, str]:
        return {
            "database": self._database_name,
        }

    @property
    def collection_name(self) -> str:
        return self._collection_name

    @tracked("ensure_collection")
    async def ensure_collection(self) -> AsyncCollection[Mapping[str, Any]]:
        """Ensure the backing collection exists.

        If it already exists, it returns the existing collection.
        """
        if not self._collection_created:
            client = await self._client_pool.get_client()
            self._collection_created = await _ensure_collection(
                client[self._database_name], self._collection_name, primary_keys=["consumed", "_id"]
            )
        return await self._client_pool.get_collection(self._database_name, self._collection_name)

    def with_session(self, session: AsyncClientSession) -> MongoBasedQueue[T_generic]:
        queue = MongoBasedQueue(
            client_pool=self._client_pool,
            database_name=self._database_name,
            collection_name=self._collection_name,
            partition_id=self._partition_id,
            item_type=self._item_type,
            tracker=self._tracker,
        )
        queue._collection_created = self._collection_created
        queue._session = session
        return queue

    @tracked("has")
    async def has(self, item: T_generic) -> bool:
        collection = await self.ensure_collection()
        encoded = self._adapter.dump_python(item, mode="python")
        doc = await collection.find_one(
            {
                "partition_id": self._partition_id,
                "consumed": False,
                "value": encoded,
            },
            session=self._session,
        )
        return doc is not None

    @tracked("enqueue")
    async def enqueue(self, items: Sequence[T_generic]) -> Sequence[T_generic]:
        if not items:
            return []

        collection = await self.ensure_collection()
        docs: List[Mapping[str, Any]] = []
        for item in items:
            if not isinstance(item, self._item_type):
                raise TypeError(f"Expected item of type {self._item_type.__name__}, got {type(item).__name__}")
            docs.append(
                {
                    "partition_id": self._partition_id,
                    "value": self._adapter.dump_python(item, mode="python"),
                    "consumed": False,
                    "created_at": datetime.now(),
                }
            )

        async with self.tracking_context("enqueue.insert_many", self.collection_name):
            await collection.insert_many(docs, session=self._session)
        return list(items)

    @tracked("dequeue")
    async def dequeue(self, limit: int = 1) -> Sequence[T_generic]:
        if limit <= 0:
            return []

        collection = await self.ensure_collection()
        results: list[T_generic] = []

        # Atomic claim loop using find_one_and_update
        for _ in range(limit):
            async with self.tracking_context("dequeue.find_one_and_update", self.collection_name):
                doc = await collection.find_one_and_update(
                    {
                        "partition_id": self._partition_id,
                        "consumed": False,
                    },
                    {"$set": {"consumed": True}},
                    sort=[("_id", 1)],  # FIFO using insertion order
                    return_document=True,
                    session=self._session,
                )
            if doc is None:  # type: ignore
                # No more items to dequeue
                break

            raw_value = doc["value"]
            item = self._adapter.validate_python(raw_value)
            results.append(item)

        return results

    @tracked("peek")
    async def peek(self, limit: int = 1) -> Sequence[T_generic]:
        if limit <= 0:
            return []

        collection = await self.ensure_collection()
        async with self.tracking_context("peek.find", self.collection_name):
            cursor = (
                collection.find(
                    {
                        "partition_id": self._partition_id,
                        "consumed": False,
                    },
                    session=self._session,
                )
                .sort("_id", 1)
                .limit(limit)
            )

        items: list[T_generic] = []
        async for doc in cursor:
            raw_value = doc["value"]
            items.append(self._adapter.validate_python(raw_value))

        return items

    @tracked("size")
    async def size(self) -> int:
        collection = await self.ensure_collection()
        return await collection.count_documents(
            {
                "partition_id": self._partition_id,
                "consumed": False,
            },
            session=self._session,
        )

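A short sketch of the queue semantics above (append-only documents, FIFO claim by `_id`, a `consumed` flag instead of deletion); the names are placeholders and the pool is the hypothetical one from the earlier sketch.

async def example_queue_usage(pool: MongoClientPool[Mapping[str, Any]]) -> None:
    queue = MongoBasedQueue(
        client_pool=pool,
        database_name="example_db",
        collection_name="rollout_queue",
        partition_id="run-001",
        item_type=str,
    )
    await queue.enqueue(["rollout-1", "rollout-2", "rollout-3"])
    print(await queue.peek(limit=2))     # ["rollout-1", "rollout-2"], still unconsumed
    print(await queue.dequeue(limit=2))  # ["rollout-1", "rollout-2"], now marked consumed
    print(await queue.size())            # 1 -- only unconsumed documents are counted
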
class MongoBasedKeyValue(KeyValue[K, V], Generic[K, V]):
    """Mongo-based implementation of KeyValue."""

    def __init__(
        self,
        client_pool: MongoClientPool[Mapping[str, Any]],
        database_name: str,
        collection_name: str,
        partition_id: str,
        key_type: Type[K],
        value_type: Type[V],
        tracker: MetricsBackend | None = None,
    ) -> None:
        """
        Args:
            client_pool: The pool of MongoDB clients.
            database_name: The name of the database.
            collection_name: The name of the collection backing the key-value store.
            partition_id: Partition identifier; allows multiple logical maps in one collection.
            key_type: The Python type of keys (primitive or BaseModel).
            value_type: The Python type of values (primitive or BaseModel).
            tracker: The metrics tracker to use.
        """
        super().__init__(tracker=tracker)
        self._client_pool = client_pool
        self._database_name = database_name
        self._collection_name = collection_name
        self._partition_id = partition_id
        self._key_type = key_type
        self._value_type = value_type
        self._key_adapter: TypeAdapter[K] = TypeAdapter(key_type)
        self._value_adapter: TypeAdapter[V] = TypeAdapter(value_type)
        self._collection_created = False

        self._session: Optional[AsyncClientSession] = None

    @property
    def extra_tracking_labels(self) -> Mapping[str, str]:
        return {
            "database": self._database_name,
        }

    @property
    def collection_name(self) -> str:
        return self._collection_name

    @tracked("ensure_collection")
    async def ensure_collection(self, *, create_indexes: bool = True) -> AsyncCollection[Mapping[str, Any]]:
        """Ensure the backing collection exists (and optionally its indexes)."""
        if not self._collection_created:
            client = await self._client_pool.get_client()
            self._collection_created = await _ensure_collection(
                client[self._database_name], self._collection_name, primary_keys=["key"]
            )
        return await self._client_pool.get_collection(self._database_name, self._collection_name)

    def with_session(self, session: AsyncClientSession) -> MongoBasedKeyValue[K, V]:
        key_value = MongoBasedKeyValue(
            client_pool=self._client_pool,
            database_name=self._database_name,
            collection_name=self._collection_name,
            partition_id=self._partition_id,
            key_type=self._key_type,
            value_type=self._value_type,
            tracker=self._tracker,
        )
        key_value._collection_created = self._collection_created
        key_value._session = session

        return key_value

    @tracked("has")
    async def has(self, key: K) -> bool:
        collection = await self.ensure_collection()
        encoded_key = self._key_adapter.dump_python(key, mode="python")
        doc = await collection.find_one(
            {
                "partition_id": self._partition_id,
                "key": encoded_key,
            },
            session=self._session,
        )
        return doc is not None

    @tracked("get")
    async def get(self, key: K, default: V | None = None) -> V | None:
        collection = await self.ensure_collection()
        encoded_key = self._key_adapter.dump_python(key, mode="python")
        doc = await collection.find_one(
            {
                "partition_id": self._partition_id,
                "key": encoded_key,
            },
            session=self._session,
        )
        if doc is None:
            return default

        raw_value = doc["value"]
        return self._value_adapter.validate_python(raw_value)

    @tracked("set")
    async def set(self, key: K, value: V) -> None:
        collection = await self.ensure_collection()
        encoded_key = self._key_adapter.dump_python(key, mode="python")
        encoded_value = self._value_adapter.dump_python(value, mode="python")
        try:
            async with self.tracking_context("set.replace_one", self.collection_name):
                await collection.replace_one(
                    {
                        "partition_id": self._partition_id,
                        "key": encoded_key,
                    },
                    {
                        "partition_id": self._partition_id,
                        "key": encoded_key,
                        "value": encoded_value,
                    },
                    upsert=True,
                    session=self._session,
                )
        except DuplicateKeyError as exc:
            # Very unlikely with replace_one+upsert, but normalize anyway.
            raise DuplicatedPrimaryKeyError("Duplicate key error while setting key-value item") from exc

    @tracked("inc")
    async def inc(self, key: K, amount: V) -> V:
        assert ensure_numeric(amount, description="amount")
        collection = await self.ensure_collection()
        encoded_key = self._key_adapter.dump_python(key, mode="python")
        encoded_amount = self._value_adapter.dump_python(amount, mode="python")
        try:
            async with self.tracking_context("inc.find_one_and_update", self.collection_name):
                doc = await collection.find_one_and_update(
                    {
                        "partition_id": self._partition_id,
                        "key": encoded_key,
                    },
                    {
                        "$inc": {"value": encoded_amount},
                    },
                    upsert=True,
                    return_document=ReturnDocument.AFTER,
                    session=self._session,
                )
        except OperationFailure as exc:
            if exc.code == 14 or "Cannot apply $inc" in str(exc):
                raise TypeError(f"value for key {key!r} is not numeric") from exc
            raise
        if doc is None:  # type: ignore
            raise RuntimeError("Failed to increment value; MongoDB did not return a document")
        raw_value = doc["value"]
        return self._value_adapter.validate_python(raw_value)

    @tracked("chmax")
    async def chmax(self, key: K, value: V) -> V:
        assert ensure_numeric(value, description="value")
        collection = await self.ensure_collection()
        encoded_key = self._key_adapter.dump_python(key, mode="python")
        encoded_value = self._value_adapter.dump_python(value, mode="python")
        try:
            async with self.tracking_context("chmax.find_one_and_update", self.collection_name):
                doc = await collection.find_one_and_update(
                    {
                        "partition_id": self._partition_id,
                        "key": encoded_key,
                    },
                    {
                        "$max": {"value": encoded_value},
                    },
                    upsert=True,
                    return_document=ReturnDocument.AFTER,
                    session=self._session,
                )
        except OperationFailure as exc:
            if exc.code == 14 or "Cannot apply $max" in str(exc):
                raise TypeError(f"value for key {key!r} is not numeric") from exc
            raise
        if doc is None:  # type: ignore
            raise RuntimeError("Failed to update value; MongoDB did not return a document")
        raw_value = doc["value"]
        return self._value_adapter.validate_python(raw_value)

    @tracked("pop")
    async def pop(self, key: K, default: V | None = None) -> V | None:
        collection = await self.ensure_collection()
        encoded_key = self._key_adapter.dump_python(key, mode="python")
        doc = await collection.find_one_and_delete(
            {
                "partition_id": self._partition_id,
                "key": encoded_key,
            },
            session=self._session,
        )
        if doc is None:  # type: ignore
            return default

        raw_value = doc["value"]
        return self._value_adapter.validate_python(raw_value)

    @tracked("size")
    async def size(self) -> int:
        collection = await self.ensure_collection()
        return await collection.count_documents(
            {
                "partition_id": self._partition_id,
            },
            session=self._session,
        )

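Finally, a sketch of the key-value store as a counter, in the spirit of the `span_sequence_ids` store wired up below; `inc` and `chmax` upsert the document and rely on MongoDB's `$inc`/`$max`. Names are placeholders.

async def example_kv_usage(pool: MongoClientPool[Mapping[str, Any]]) -> None:
    seq = MongoBasedKeyValue(
        client_pool=pool,
        database_name="example_db",
        collection_name="span_sequence_ids",
        partition_id="run-001",
        key_type=str,
        value_type=int,
    )
    print(await seq.inc("rollout-1", 1))     # 1 -- document upserted, then incremented
    print(await seq.inc("rollout-1", 1))     # 2
    print(await seq.chmax("rollout-1", 10))  # 10 -- $max raises the stored value
    print(await seq.get("rollout-1"))        # 10
    print(await seq.pop("missing", -1))      # -1 -- default when the key is absent
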
class MongoLightningCollections(LightningCollections):
    """Mongo implementation of LightningCollections using MongoDB collections.

    Serves as the storage base for [`MongoLightningStore`][mantisdk.store.mongo.MongoLightningStore].
    """

    def __init__(
        self,
        client_pool: MongoClientPool[Mapping[str, Any]],
        database_name: str,
        partition_id: str,
        rollouts: Optional[MongoBasedCollection[Rollout]] = None,
        attempts: Optional[MongoBasedCollection[Attempt]] = None,
        spans: Optional[MongoBasedCollection[Span]] = None,
        resources: Optional[MongoBasedCollection[ResourcesUpdate]] = None,
        workers: Optional[MongoBasedCollection[Worker]] = None,
        rollout_queue: Optional[MongoBasedQueue[str]] = None,
        span_sequence_ids: Optional[MongoBasedKeyValue[str, int]] = None,
        tracker: MetricsBackend | None = None,
    ):
        super().__init__(tracker=tracker, extra_labels=["database"])
        self._client_pool = client_pool
        self._database_name = database_name
        self._partition_id = partition_id
        self._collection_ensured = False
        self._lock = aiologic.Lock()  # used for generic atomic operations like scan debounce seconds
        self._rollouts = (
            rollouts
            if rollouts is not None
            else MongoBasedCollection(
                self._client_pool,
                self._database_name,
                "rollouts",
                self._partition_id,
                ["rollout_id"],
                Rollout,
                [["status"]],
                tracker=self._tracker,
            )
        )
        self._attempts = (
            attempts
            if attempts is not None
            else MongoBasedCollection(
                self._client_pool,
                self._database_name,
                "attempts",
                self._partition_id,
                ["rollout_id", "attempt_id"],
                Attempt,
                [["status"], ["sequence_id"]],
                tracker=self._tracker,
            )
        )
        self._spans = (
            spans
            if spans is not None
            else MongoBasedCollection(
                self._client_pool,
                self._database_name,
                "spans",
                self._partition_id,
                ["rollout_id", "attempt_id", "span_id"],
                Span,
                [["sequence_id"]],
                tracker=self._tracker,
            )
        )
        self._resources = (
            resources
            if resources is not None
            else MongoBasedCollection(
                self._client_pool,
                self._database_name,
                "resources",
                self._partition_id,
                ["resources_id"],
                ResourcesUpdate,
                ["update_time"],
                tracker=self._tracker,
            )
        )
        self._workers = (
            workers
            if workers is not None
            else MongoBasedCollection(
                self._client_pool,
                self._database_name,
                "workers",
                self._partition_id,
                ["worker_id"],
                Worker,
                ["status"],
                tracker=self._tracker,
            )
        )
        self._rollout_queue = (
            rollout_queue
            if rollout_queue is not None
            else MongoBasedQueue(
                self._client_pool,
                self._database_name,
                "rollout_queue",
                self._partition_id,
                str,
                tracker=self._tracker,
            )
        )
        self._span_sequence_ids = (
            span_sequence_ids
            if span_sequence_ids is not None
            else MongoBasedKeyValue(
                self._client_pool,
                self._database_name,
                "span_sequence_ids",
                self._partition_id,
                str,
                int,
                tracker=self._tracker,
            )
        )

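Every storage primitive above can be injected through the constructor, but callers typically pass only the pool, database, and partition and let `__init__` build the defaults. A hedged construction sketch (the `MongoClientPool` setup itself is assumed to happen elsewhere; names are placeholders), overriding just the rollout queue to show the injection path:

```python
def build_collections(pool: MongoClientPool) -> MongoLightningCollections:
    custom_queue = MongoBasedQueue(pool, "lightning", "rollout_queue", "partition-0", str)
    return MongoLightningCollections(
        client_pool=pool,
        database_name="lightning",
        partition_id="partition-0",
        rollout_queue=custom_queue,   # everything else falls back to the defaults above
    )
```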
    @property
    def collection_name(self) -> str:
        return "router"  # Special collection name for tracking transactions

    @property
    def extra_tracking_labels(self) -> Mapping[str, str]:
        return {
            "database": self._database_name,
        }

    def with_session(self, session: AsyncClientSession) -> Self:
        instance = self.__class__(
            client_pool=self._client_pool,
            database_name=self._database_name,
            partition_id=self._partition_id,
            rollouts=self._rollouts.with_session(session),
            attempts=self._attempts.with_session(session),
            spans=self._spans.with_session(session),
            resources=self._resources.with_session(session),
            workers=self._workers.with_session(session),
            rollout_queue=self._rollout_queue.with_session(session),
            span_sequence_ids=self._span_sequence_ids.with_session(session),
            tracker=self._tracker,
        )
        instance._collection_ensured = self._collection_ensured
        return instance

    @property
    def rollouts(self) -> MongoBasedCollection[Rollout]:
        return self._rollouts

    @property
    def attempts(self) -> MongoBasedCollection[Attempt]:
        return self._attempts

    @property
    def spans(self) -> MongoBasedCollection[Span]:
        return self._spans

    @property
    def resources(self) -> MongoBasedCollection[ResourcesUpdate]:
        return self._resources

    @property
    def workers(self) -> MongoBasedCollection[Worker]:
        return self._workers

    @property
    def rollout_queue(self) -> MongoBasedQueue[str]:
        return self._rollout_queue

    @property
    def span_sequence_ids(self) -> MongoBasedKeyValue[str, int]:
        return self._span_sequence_ids

@tracked("ensure_collections")
|
|
1255
|
+
async def _ensure_collections(self) -> None:
|
|
1256
|
+
"""Ensure all collections exist."""
|
|
1257
|
+
if self._collection_ensured:
|
|
1258
|
+
return
|
|
1259
|
+
await self._rollouts.ensure_collection()
|
|
1260
|
+
await self._attempts.ensure_collection()
|
|
1261
|
+
await self._spans.ensure_collection()
|
|
1262
|
+
await self._resources.ensure_collection()
|
|
1263
|
+
await self._workers.ensure_collection()
|
|
1264
|
+
await self._rollout_queue.ensure_collection()
|
|
1265
|
+
await self._span_sequence_ids.ensure_collection()
|
|
1266
|
+
self._collection_ensured = True
|
|
1267
|
+
|
|
1268
|
+
    @asynccontextmanager
    async def _lock_manager(self, labels: Optional[Sequence[AtomicLabels]]):
        if labels is None or "generic" not in labels:
            yield

        else:
            # Only lock the generic label.
            try:
                async with self.tracking_context("lock", self.collection_name):
                    await self._lock.async_acquire()
                yield
            finally:
                self._lock.async_release()

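`_lock_manager` serializes only callers that pass the "generic" label; all other callers flow straight through without touching the lock. The same conditional-lock shape in a self-contained form, using `asyncio.Lock` in place of `aiologic` purely for illustration:

```python
import asyncio
from contextlib import asynccontextmanager

_lock = asyncio.Lock()

@asynccontextmanager
async def maybe_locked(labels: list[str] | None):
    if labels is None or "generic" not in labels:
        yield                 # no serialization requested
    else:
        async with _lock:     # serialize callers carrying the "generic" label
            yield
```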
    @asynccontextmanager
    async def atomic(
        self,
        mode: AtomicMode = "rw",
        snapshot: bool = False,
        commit: bool = False,
        labels: Optional[Sequence[AtomicLabels]] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Perform an atomic operation on the collections."""
        if commit:
            raise ValueError("Commit should be used with execute() instead.")
        async with self._lock_manager(labels):
            async with self.tracking_context("atomic", self.collection_name):
                # First step: ensure all collections exist before going into the atomic block
                if not self._collection_ensured:
                    await self._ensure_collections()
                # Execute directly without commit
                yield self

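`atomic()` is the non-transactional scope: it ensures the collections exist, optionally takes the in-process lock, and yields the collections object itself; asking for `commit=True` here is rejected because that path belongs to `execute()`. A hedged usage sketch built only from methods shown in this module:

```python
async def bump_sequence_high_watermark(collections: MongoLightningCollections, rollout_id: str, seq: int) -> None:
    # Serialize with other "generic" callers; no MongoDB transaction is started here.
    async with collections.atomic(labels=["generic"]) as c:
        await c.span_sequence_ids.chmax(rollout_id, seq)
```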
@tracked("execute")
|
|
1304
|
+
async def execute(
|
|
1305
|
+
self,
|
|
1306
|
+
callback: Callable[[Self], Awaitable[T_generic]],
|
|
1307
|
+
*,
|
|
1308
|
+
mode: AtomicMode = "rw",
|
|
1309
|
+
snapshot: bool = False,
|
|
1310
|
+
commit: bool = False,
|
|
1311
|
+
labels: Optional[Sequence[AtomicLabels]] = None,
|
|
1312
|
+
**kwargs: Any,
|
|
1313
|
+
) -> T_generic:
|
|
1314
|
+
"""Execute the given callback within an atomic operation, and with retries on transient errors."""
|
|
1315
|
+
if not self._collection_ensured:
|
|
1316
|
+
await self._ensure_collections()
|
|
1317
|
+
client = await self._client_pool.get_client()
|
|
1318
|
+
|
|
1319
|
+
# If commit is not turned on, just execute the callback directly.
|
|
1320
|
+
if not commit:
|
|
1321
|
+
async with self._lock_manager(labels):
|
|
1322
|
+
return await callback(self)
|
|
1323
|
+
|
|
1324
|
+
# If snapshot is enabled, use snapshot read concern.
|
|
1325
|
+
read_concern = ReadConcern("snapshot") if snapshot else ReadConcern("local")
|
|
1326
|
+
# If mode is "r", write_concern is not needed.
|
|
1327
|
+
write_concern = WriteConcern("majority") if mode != "r" else None
|
|
1328
|
+
|
|
1329
|
+
async with client.start_session() as session:
|
|
1330
|
+
collections = self.with_session(session)
|
|
1331
|
+
try:
|
|
1332
|
+
async with self._lock_manager(labels):
|
|
1333
|
+
return await self.with_transaction(session, collections, callback, read_concern, write_concern)
|
|
1334
|
+
except (ConnectionFailure, OperationFailure) as exc:
|
|
1335
|
+
# Un-retryable errors.
|
|
1336
|
+
raise RuntimeError("Transaction failed with connection or operation error") from exc
|
|
1337
|
+
|
|
1338
|
+
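With `commit=True`, `execute()` binds the collections to a fresh session and hands the callback to `with_transaction`, so everything the callback does commits or aborts as one unit. A hedged sketch of such a callback:

```python
async def record_and_count(c: MongoLightningCollections) -> int:
    # Both calls share the transaction's session: they commit or abort together.
    await c.span_sequence_ids.chmax("rollout-123", 7)
    return await c.span_sequence_ids.size()

# size_after = await collections.execute(record_and_count, commit=True)
```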
@tracked("with_transaction")
|
|
1339
|
+
async def with_transaction(
|
|
1340
|
+
self,
|
|
1341
|
+
session: AsyncClientSession,
|
|
1342
|
+
collections: Self,
|
|
1343
|
+
callback: Callable[[Self], Awaitable[T_generic]],
|
|
1344
|
+
read_concern: ReadConcern,
|
|
1345
|
+
write_concern: Optional[WriteConcern],
|
|
1346
|
+
) -> T_generic:
|
|
1347
|
+
# This will start a transaction, run transaction callback, and commit.
|
|
1348
|
+
# It will also transparently retry on some transient errors.
|
|
1349
|
+
# Expanded implementation of with_transaction from client_session
|
|
1350
|
+
read_preference = ReadPreference.PRIMARY
|
|
1351
|
+
transaction_retry_time_limit = 120
|
|
1352
|
+
start_time = time.monotonic()
|
|
1353
|
+
|
|
1354
|
+
def _within_time_limit() -> bool:
|
|
1355
|
+
return time.monotonic() - start_time < transaction_retry_time_limit
|
|
1356
|
+
|
|
1357
|
+
async def _jitter_before_retry() -> None:
|
|
1358
|
+
async with self.tracking_context("execute.jitter", self.collection_name):
|
|
1359
|
+
await asyncio.sleep(random.uniform(0, 0.05))
|
|
1360
|
+
|
|
1361
|
+
while True:
|
|
1362
|
+
await session.start_transaction(read_concern, write_concern, read_preference)
|
|
1363
|
+
|
|
1364
|
+
try:
|
|
1365
|
+
# The _session is always the same within one transaction,
|
|
1366
|
+
# so we can use the same collections object.
|
|
1367
|
+
async with self.tracking_context("execute.callback", self.collection_name):
|
|
1368
|
+
ret = await callback(collections)
|
|
1369
|
+
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
|
1370
|
+
except BaseException as exc:
|
|
1371
|
+
if session.in_transaction:
|
|
1372
|
+
await session.abort_transaction()
|
|
1373
|
+
if (
|
|
1374
|
+
isinstance(exc, PyMongoError)
|
|
1375
|
+
and exc.has_error_label("TransientTransactionError")
|
|
1376
|
+
and _within_time_limit()
|
|
1377
|
+
):
|
|
1378
|
+
# Retry the entire transaction.
|
|
1379
|
+
await _jitter_before_retry()
|
|
1380
|
+
continue
|
|
1381
|
+
raise
|
|
1382
|
+
|
|
1383
|
+
if not session.in_transaction:
|
|
1384
|
+
# Assume callback intentionally ended the transaction.
|
|
1385
|
+
return ret
|
|
1386
|
+
|
|
1387
|
+
# Tracks the commit operation.
|
|
1388
|
+
async with self.tracking_context("execute.commit", self.collection_name):
|
|
1389
|
+
# Loop until the commit succeeds or we hit the time limit.
|
|
1390
|
+
while True:
|
|
1391
|
+
# Tracks the commit attempt.
|
|
1392
|
+
try:
|
|
1393
|
+
async with self.tracking_context("execute.commit_once", self.collection_name):
|
|
1394
|
+
await session.commit_transaction()
|
|
1395
|
+
except PyMongoError as exc:
|
|
1396
|
+
if (
|
|
1397
|
+
exc.has_error_label("UnknownTransactionCommitResult")
|
|
1398
|
+
and _within_time_limit()
|
|
1399
|
+
and not (isinstance(exc, OperationFailure) and exc.code == 50) # max_time_expired_error
|
|
1400
|
+
):
|
|
1401
|
+
# Retry the commit.
|
|
1402
|
+
await _jitter_before_retry()
|
|
1403
|
+
continue
|
|
1404
|
+
|
|
1405
|
+
if exc.has_error_label("TransientTransactionError") and _within_time_limit():
|
|
1406
|
+
# Retry the entire transaction.
|
|
1407
|
+
await _jitter_before_retry()
|
|
1408
|
+
break
|
|
1409
|
+
raise
|
|
1410
|
+
|
|
1411
|
+
# Commit succeeded.
|
|
1412
|
+
return ret
|
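`with_transaction` above is an expanded, instrumented version of the retry loop PyMongo itself uses for `ClientSession.with_transaction`, driven by the `TransientTransactionError` and `UnknownTransactionCommitResult` error labels. The core of that pattern, reduced to a synchronous sketch without the time limit and jitter (the connection string is a placeholder and a transaction-capable deployment is assumed):

```python
from pymongo import MongoClient
from pymongo.errors import PyMongoError

client = MongoClient("mongodb://localhost:27017")  # placeholder; must be a replica set or mongos

def run_transaction(work) -> None:
    with client.start_session() as session:
        while True:
            session.start_transaction()
            try:
                work(session)
            except PyMongoError as exc:
                if session.in_transaction:
                    session.abort_transaction()
                if exc.has_error_label("TransientTransactionError"):
                    continue            # retry the whole transaction
                raise
            while True:
                try:
                    session.commit_transaction()
                    return              # commit succeeded
                except PyMongoError as exc:
                    if exc.has_error_label("UnknownTransactionCommitResult"):
                        continue        # retry only the commit
                    raise
```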