mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of mantisdk might be problematic.
Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/store/collection/memory.py (new file, per the file list above)
@@ -0,0 +1,970 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import uuid
+import weakref
+from collections import deque
+from contextlib import AsyncExitStack, asynccontextmanager
+from typing import (
+    Any,
+    Deque,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
+
+import aiologic
+from pydantic import BaseModel
+
+from mantisdk.types import (
+    Attempt,
+    FilterField,
+    FilterOptions,
+    PaginatedResult,
+    ResourcesUpdate,
+    Rollout,
+    SortOptions,
+    Span,
+    Worker,
+)
+from mantisdk.utils.metrics import MetricsBackend
+
+from .base import (
+    AtomicLabels,
+    AtomicMode,
+    Collection,
+    DuplicatedPrimaryKeyError,
+    FilterMap,
+    KeyValue,
+    LightningCollections,
+    Queue,
+    ensure_numeric,
+    normalize_filter_options,
+    resolve_sort_options,
+    tracked,
+)
+
+T = TypeVar("T")  # Recommended to be a BaseModel, not a dict
+K = TypeVar("K")
+V = TypeVar("V")
+
+logger = logging.getLogger(__name__)
+
+# Nested structure type:
+# dict[pk1] -> dict[pk2] -> ... -> item
+ListBasedCollectionItemType = Union[
+    Dict[Any, "ListBasedCollectionItemType[T]"],  # intermediate node
+    Dict[Any, T],  # leaf node dictionary
+]
+
+MutationMode = Literal["insert", "update", "upsert", "delete"]
+
+
+def _item_matches_filters(
+    item: object,
+    filters: Optional[FilterMap],
+    filter_logic: Literal["and", "or"],
+    must_filters: Optional[FilterMap] = None,
+) -> bool:
+    """Check whether an item matches the provided filter definition.
+
+    Filter format:
+
+    ```json
+    {
+        "_aggregate": "or",
+        "field_name": {
+            "exact": <value>,
+            "within": <iterable_of_allowed_values>,
+            "contains": <substring_or_element>,
+        },
+        ...
+    }
+    ```
+
+    Operators within the same field are stored in a unified pool and combined using
+    a universal logical operator.
+    """
+    if must_filters and not _item_matches_filters(item, must_filters, "and"):
+        return False
+
+    if not filters:
+        return True
+
+    all_conditions_match: List[bool] = []
+
+    for field_name, ops in filters.items():
+        item_value = getattr(item, field_name, None)
+
+        for op_name, expected in ops.items():
+            # Ignore no-op filters
+            if expected is None:
+                continue
+
+            if op_name == "exact":
+                all_conditions_match.append(item_value == expected)
+
+            elif op_name == "within":
+                try:
+                    all_conditions_match.append(item_value in expected)  # type: ignore[arg-type]
+                except TypeError:
+                    all_conditions_match.append(False)
+
+            elif op_name == "contains":
+                if item_value is None:
+                    all_conditions_match.append(False)
+                elif isinstance(item_value, str) and isinstance(expected, str):
+                    all_conditions_match.append(expected in item_value)
+                else:
+                    # Fallback: treat as generic iterable containment.
+                    try:
+                        all_conditions_match.append(expected in item_value)  # type: ignore[arg-type]
+                    except TypeError:
+                        all_conditions_match.append(False)
+            else:
+                raise ValueError(f"Unsupported filter operator '{op_name}' for field '{field_name}'")
+
+    return all(all_conditions_match) if filter_logic == "and" else any(all_conditions_match)
+
+
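To make these operator semantics concrete, here is a minimal standalone sketch of the same matching rules. `Task` and `matches` are illustrative names invented for this example; they are not part of mantisdk:

```python
from typing import Any, Dict

from pydantic import BaseModel


class Task(BaseModel):  # hypothetical item type, for illustration only
    status: str
    owner: str


def matches(item: Any, filters: Dict[str, Dict[str, Any]], logic: str = "and") -> bool:
    """Re-statement of the exact/within/contains semantics documented above."""
    checks = []
    for field, ops in filters.items():
        value = getattr(item, field, None)
        for op, expected in ops.items():
            if expected is None:  # no-op filter
                continue
            if op == "exact":
                checks.append(value == expected)
            elif op == "within":
                checks.append(value in expected)
            elif op == "contains":
                checks.append(isinstance(value, str) and expected in value)
    return all(checks) if logic == "and" else any(checks)


task = Task(status="running", owner="alice")
assert matches(task, {"status": {"exact": "running"}})
assert matches(task, {"status": {"within": ["queued", "running"]}})
assert matches(task, {"owner": {"contains": "ali"}})
# With "and", every condition must hold; with "or", any one suffices.
assert not matches(task, {"status": {"exact": "failed"}, "owner": {"exact": "alice"}})
assert matches(task, {"status": {"exact": "failed"}, "owner": {"exact": "alice"}}, logic="or")
```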
+def _get_sort_value(item: object, sort_by: str) -> Any:
+    """Get a sort key for the given item/field.
+
+    - If the field name ends with '_time', values are treated as comparable timestamps.
+    - For other fields we try to infer a safe default from the Pydantic model annotation.
+    """
+    value = getattr(item, sort_by, None)
+
+    if sort_by.endswith("_time"):
+        # For *_time fields, push missing values to the end.
+        return float("inf") if value is None else value
+
+    if value is None:
+        # Introspect model field type to choose a reasonable default for None.
+        model_fields = getattr(item.__class__, "model_fields", {})
+        if sort_by not in model_fields:
+            raise ValueError(
+                f"Failed to sort items by '{sort_by}': field does not exist on {item.__class__.__name__}"
+            )
+
+        field_type_str = str(model_fields[sort_by].annotation)
+        if "str" in field_type_str or "Literal" in field_type_str:
+            return ""
+        if "int" in field_type_str:
+            return 0
+        if "float" in field_type_str:
+            return 0.0
+        raise ValueError(f"Failed to sort items by '{sort_by}': unsupported field type {field_type_str!r}")
+
+    return value
+
+
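The None handling above decides where items with missing fields land in a sorted query. A small sketch of the two rules, using a simplified stand-in model (not the real `mantisdk.types.Rollout`):

```python
from typing import Optional

from pydantic import BaseModel


class FakeRollout(BaseModel):  # simplified stand-in for illustration
    rollout_id: str
    start_time: Optional[float] = None
    status: Optional[str] = None


r = FakeRollout(rollout_id="r1")

# Rule 1: '*_time' fields treat None as +inf, so unstarted items sort last (ascending).
key = float("inf") if r.start_time is None else r.start_time
assert key == float("inf")

# Rule 2: other fields fall back to a type-appropriate default inferred from the annotation.
annotation = str(FakeRollout.model_fields["status"].annotation)  # e.g. "typing.Optional[str]"
assert "str" in annotation  # so None sorts as "" for this field
```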
+class ListBasedCollection(Collection[T]):
+    """In-memory implementation of Collection using a nested dict for O(1) primary-key lookup.
+
+    The internal structure is:
+
+        {
+            pk1_value: {
+                pk2_value: {
+                    ...
+                    pkN_value: item
+                }
+            }
+        }
+
+    where the nesting depth equals the number of primary keys.
+
+    Sorting behavior:
+
+    1. If no sort_by is provided, the items are returned in the order of insertion.
+    2. If sort_by is provided, the items are sorted by the value of the sort_by field.
+    3. If the sort_by field is a timestamp, null values are treated as infinity.
+    4. If the sort_by field is not a timestamp, null values are treated as an empty string
+       if the field is str-like, 0 if the field is int-like, and 0.0 if the field is float-like.
+    """
+
+    def __init__(
+        self,
+        items: List[T],
+        item_type: Type[T],
+        primary_keys: Sequence[str],
+        id: Optional[str] = None,
+        tracker: Optional[MetricsBackend] = None,
+    ):
+        super().__init__(tracker=tracker)
+        if not primary_keys:
+            raise ValueError("primary_keys must be non-empty")
+
+        self._id = id if id is not None else str(uuid.uuid4())
+        self._items: Dict[Any, Any] = {}
+        self._size: int = 0
+        if issubclass(item_type, dict):
+            raise TypeError(f"Expected item_type not to be a dict, got {item_type.__name__}")
+        self._item_type: Type[T] = item_type
+        self._primary_keys: Tuple[str, ...] = tuple(primary_keys)
+
+        # Pre-populate the collection with the given items.
+        for item in items or []:
+            self._mutate_single(item, mode="insert")
+
+    @property
+    def collection_name(self) -> str:
+        return self._id
+
+    def primary_keys(self) -> Sequence[str]:
+        """Return the primary key field names for this collection."""
+        return self._primary_keys
+
+    def item_type(self) -> Type[T]:
+        """Return the Pydantic model type of items stored in this collection."""
+        return self._item_type
+
+    async def size(self) -> int:
+        """Return the number of items stored in the collection."""
+        return self._size
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__}[{self.item_type().__name__}] ({self._size})>"
+
+    # -------------------------------------------------------------------------
+    # Internal helpers
+    # -------------------------------------------------------------------------
+
+    def _ensure_item_type(self, item: T) -> None:
+        """Validate that the item matches the declared item_type."""
+        if not isinstance(item, self._item_type):
+            raise TypeError(f"Expected item of type {self._item_type.__name__}, got {type(item).__name__}")
+
+    def _extract_primary_key_values(self, item: T) -> Tuple[Any, ...]:
+        """Extract the primary key values from an item.
+
+        Raises:
+            ValueError: If any primary key is missing on the item.
+        """
+        values: List[Any] = []
+        for key in self._primary_keys:
+            if not hasattr(item, key):
+                raise ValueError(f"Item {item} does not have primary key field '{key}'")
+            values.append(getattr(item, key))
+        return tuple(values)
+
+    def _render_key_values(self, key_values: Sequence[Any]) -> str:
+        return ", ".join(f"{name}={value!r}" for name, value in zip(self._primary_keys, key_values))
+
+    def _locate_node(
+        self,
+        key_values: Sequence[Any],
+        create_missing: bool,
+    ) -> Tuple[MutableMapping[Any, Any], Any]:
+        """Locate the parent mapping and final key for an item path.
+
+        Args:
+            key_values: The sequence of primary key values.
+            create_missing: Whether to create intermediate dictionaries as needed.
+
+        Returns:
+            (parent_mapping, final_key)
+
+        Raises:
+            KeyError: If the path does not exist and create_missing is False.
+            ValueError: If the internal structure is corrupted (non-dict where dict is expected).
+        """
+        if not key_values:
+            raise ValueError("key_values must be non-empty")
+
+        current: MutableMapping[Any, Any] = self._items
+        for idx, value in enumerate(key_values):
+            is_last = idx == len(key_values) - 1
+            if is_last:
+                # At the final level, current[value] is the item (or will be).
+                return current, value  # type: ignore
+
+            # Intermediate level: current[value] must be a dict.
+            if value not in current:
+                if not create_missing:
+                    raise KeyError(f"Path does not exist for given primary keys: {self._render_key_values(key_values)}")
+                current[value] = {}
+            next_node = current[value]  # type: ignore
+            if not isinstance(next_node, dict):
+                raise ValueError(f"Internal structure corrupted: expected dict, got {type(next_node)!r}")  # type: ignore
+            current = next_node  # type: ignore
+
+        # We should always return inside the loop.
+        raise RuntimeError("Unreachable")
+
+    def _mutate_single(self, item: T, mode: MutationMode, update_fields: Sequence[str] | None = None) -> Optional[T]:
+        """Core mutation logic shared by insert, update, upsert, and delete."""
+        self._ensure_item_type(item)
+        key_values = self._extract_primary_key_values(item)
+
+        if mode in ("insert", "upsert"):
+            parent, final_key = self._locate_node(key_values, create_missing=True)
+            exists = final_key in parent
+
+            if mode == "insert":
+                if exists:
+                    raise DuplicatedPrimaryKeyError(
+                        f"Item already exists with primary key(s): {self._render_key_values(key_values)}"
+                    )
+                parent[final_key] = item
+                self._size += 1
+            else:  # upsert
+                if not exists:
+                    self._size += 1
+                    parent[final_key] = item
+                elif update_fields is None:
+                    # update_or_insert: update all fields
+                    parent[final_key] = item
+                else:
+                    if not issubclass(self._item_type, BaseModel):
+                        raise TypeError(
+                            f"When using update_fields, the item type must be a Pydantic BaseModel, got {self._item_type.__name__}"
+                        )
+
+                    # Try to fetch the existing item
+                    existing = parent[final_key]
+                    if not isinstance(existing, self._item_type):
+                        raise ValueError(
+                            f"Internal structure corrupted: expected {self._item_type.__name__}, got {type(existing)!r}"
+                        )
+
+                    if not isinstance(item, self._item_type):
+                        raise TypeError(
+                            f"When using update_fields, the item type must be a Pydantic BaseModel, got {type(item).__name__}"
+                        )
+
+                    parent[final_key] = parent[final_key].model_copy(
+                        update={field: getattr(item, field) for field in update_fields}
+                    )
+
+            return parent[final_key]
+
+        elif mode in ("update", "delete"):
+            # For update/delete we must not create missing paths.
+            try:
+                parent, final_key = self._locate_node(key_values, create_missing=False)
+            except KeyError:
+                raise ValueError(
+                    f"Item does not exist with primary key(s): {self._render_key_values(key_values)}"
+                ) from None
+
+            if final_key not in parent:
+                raise ValueError(f"Item does not exist with primary key(s): {self._render_key_values(key_values)}")
+
+            if mode == "update":
+                if update_fields is None:
+                    # replace the entire item
+                    parent[final_key] = item
+                else:
+                    if not issubclass(self._item_type, BaseModel):
+                        raise TypeError(
+                            f"When using update_fields, the item type must be a Pydantic BaseModel, got {self._item_type.__name__}"
+                        )
+                    if not isinstance(item, self._item_type):
+                        raise TypeError(
+                            f"When using update_fields, the item type must be a Pydantic BaseModel, got {type(item).__name__}"
+                        )
+                    parent[final_key] = parent[final_key].model_copy(
+                        update={field: getattr(item, field) for field in update_fields}
+                    )
+                return parent[final_key]
+            else:  # delete
+                del parent[final_key]
+                self._size -= 1
+        else:
+            raise ValueError(f"Unknown mutation mode: {mode}")
+
+    def _iter_items(
+        self,
+        root: Optional[Mapping[Any, Any]] = None,
+        filters: Optional[FilterMap] = None,
+        must_filters: Optional[FilterMap] = None,
+        filter_logic: Literal["and", "or"] = "and",
+    ) -> Iterable[T]:
+        """Iterate over all items in the nested dictionary structure, optionally applying filters."""
+        if root is None:
+            root = self._items
+        if not root:
+            return
+        stack: List[Mapping[Any, Any]] = [root]
+        while stack:
+            node = stack.pop()
+            for value in node.values():
+                # Leaf nodes contain items; intermediate nodes are dicts.
+                if isinstance(value, self._item_type):
+                    if _item_matches_filters(value, filters, filter_logic, must_filters):
+                        yield value
+                elif isinstance(value, dict):
+                    stack.append(value)  # type: ignore
+                else:
+                    raise ValueError(
+                        f"Internal structure corrupted: expected dict or {self._item_type.__name__}, "
+                        f"got {type(value)!r}"
+                    )
+
+    def _iter_matching_items(
+        self,
+        filters: Optional[FilterMap],
+        must_filters: Optional[FilterMap],
+        filter_logic: Literal["and", "or"],
+    ) -> Iterable[T]:
+        """Efficiently iterate over items matching filters, using a primary-key prefix when possible."""
+        # Fast path: when optional filters can't form a prefix, fall back to scanning.
+        if filter_logic != "and" and must_filters is None:
+            return self._iter_items(filters=filters, must_filters=must_filters, filter_logic=filter_logic)
+
+        # Try to derive a primary-key prefix from exact filters.
+        pk_values_prefix: List[Any] = []
+        prefix_sources: List[FilterMap] = []
+        if must_filters:
+            prefix_sources.append(must_filters)
+        if filter_logic == "and" and filters:
+            prefix_sources.append(filters)
+
+        for pk in self._primary_keys:
+            # combined_ops are: [{"exact": value}, {"within": [...]}, ...]
+            combined_ops: List[FilterField] = []
+            for source in prefix_sources:
+                field_ops = source.get(pk)  # type: ignore[union-attr]
+                if field_ops:
+                    combined_ops.append(field_ops)
+            if not combined_ops:
+                break
+            # Only allow a pure {"exact": value} constraint.
+            exact_value: Any | None = None
+            allow_prefix = True
+            for ops in combined_ops:
+                if set(ops.keys()) != {"exact"}:
+                    allow_prefix = False
+                    break
+                candidate = ops.get("exact")
+                if candidate is None:
+                    allow_prefix = False
+                    break
+                if exact_value is not None and candidate != exact_value:
+                    # Contradictory exact filters mean no items can match.
+                    logger.warning(f"Contradictory exact filters for field '{pk}': {exact_value} != {candidate}")
+                    return ()
+                exact_value = candidate
+
+            if not allow_prefix:
+                break
+
+            value = exact_value
+            if value is None:
+                break
+            pk_values_prefix.append(value)
+
+        if not pk_values_prefix:
+            return self._iter_items(filters=filters, must_filters=must_filters, filter_logic=filter_logic)
+
+        try:
+            if len(pk_values_prefix) == len(self._primary_keys):
+                # All primary keys specified -> at most a single item.
+                parent, final_key = self._locate_node(pk_values_prefix, create_missing=False)
+                single_item = parent.get(final_key)
+                if isinstance(single_item, self._item_type) and _item_matches_filters(
+                    single_item,
+                    filters,
+                    filter_logic,
+                    must_filters,
+                ):
+                    return (single_item,)
+                return ()
+            else:
+                # Prefix of primary keys specified -> iterate only the subtree below that prefix.
+                parent, final_key = self._locate_node(pk_values_prefix, create_missing=False)
+                subtree = parent.get(final_key)
+                if isinstance(subtree, dict):
+                    return self._iter_items(
+                        subtree,  # type: ignore
+                        filters=filters,
+                        must_filters=must_filters,
+                        filter_logic=filter_logic,
+                    )
+                return ()
+        except KeyError:
+            # No items exist for this primary-key prefix.
+            return ()
+
+    @tracked("query")
+    async def query(
+        self,
+        filter: Optional[FilterOptions] = None,
+        sort: Optional[SortOptions] = None,
+        limit: int = -1,
+        offset: int = 0,
+    ) -> PaginatedResult[T]:
+        """Query the collection with filters, sort order, and pagination.
+
+        Args:
+            filter: Mapping of field name to operator dict, along with the optional `_aggregate` logic.
+            sort: Options describing which field to sort by and in which order.
+            limit: Max number of items to return. Use -1 for "no limit".
+            offset: Number of items to skip from the start of the *matching* items.
+        """
+        filters, must_filters, filter_logic = normalize_filter_options(filter)
+        sort_by, sort_order = resolve_sort_options(sort)
+        items_iter: Iterable[T] = self._iter_matching_items(filters, must_filters, filter_logic)
+
+        # No sorting: stream through items and apply pagination on the fly.
+        if not sort_by:
+            matched_items: List[T] = []
+            total_matched = 0
+
+            for item in items_iter:
+                # Count every match for 'total'
+                total_matched += 1
+
+                # Apply offset/limit window
+                if total_matched <= offset:
+                    continue
+                if limit != -1 and len(matched_items) >= limit:
+                    # Still need to finish iteration to get an accurate total_matched.
+                    continue
+
+                matched_items.append(item)
+
+            return PaginatedResult(
+                items=matched_items,
+                limit=limit,
+                offset=offset,
+                total=total_matched,
+            )
+
+        # With sorting: we must materialize all matching items to sort them.
+        all_matches: List[T] = list(items_iter)
+
+        total_matched = len(all_matches)
+        reverse = sort_order == "desc"
+        all_matches.sort(key=lambda x: _get_sort_value(x, sort_by), reverse=reverse)
+
+        if limit == -1:
+            paginated_items = all_matches[offset:]
+        else:
+            paginated_items = all_matches[offset : offset + limit]
+
+        return PaginatedResult(
+            items=paginated_items,
+            limit=limit,
+            offset=offset,
+            total=total_matched,
+        )
+
+    @tracked("get")
+    async def get(
+        self,
+        filter: Optional[FilterOptions] = None,
+        sort: Optional[SortOptions] = None,
+    ) -> Optional[T]:
+        """Return the first (or best-sorted) item that matches the given filters, or None."""
+        filters, must_filters, filter_logic = normalize_filter_options(filter)
+        sort_by, sort_order = resolve_sort_options(sort)
+        items_iter: Iterable[T] = self._iter_matching_items(filters, must_filters, filter_logic)
+
+        if not sort_by:
+            # Just return the first matching item, if any.
+            for item in items_iter:
+                return item
+            return None
+
+        # Single-pass min/max according to sort_order.
+        best_item: Optional[T] = None
+        best_key: Any = None
+
+        for item in items_iter:
+            key = _get_sort_value(item, sort_by)
+            if best_item is None:
+                best_item = item
+                best_key = key
+                continue
+
+            if sort_order == "asc":
+                if key < best_key:
+                    best_item, best_key = item, key
+            else:
+                if key > best_key:
+                    best_item, best_key = item, key
+
+        return best_item
+
+    @tracked("insert")
+    async def insert(self, items: Sequence[T]) -> None:
+        """Insert the given items.
+
+        Raises:
+            DuplicatedPrimaryKeyError: If any item with the same primary keys already exists.
+        """
+        seen_keys: set[Tuple[Any, ...]] = set()
+        prepared: List[T] = []
+        for item in items:
+            self._ensure_item_type(item)
+            key_values = self._extract_primary_key_values(item)
+            if key_values in seen_keys:
+                raise DuplicatedPrimaryKeyError(
+                    f"Insert payload contains duplicated primary key(s): {self._render_key_values(key_values)}"
+                )
+            seen_keys.add(key_values)
+            prepared.append(item)
+
+        for item in prepared:
+            self._mutate_single(item, mode="insert")
+
+    @tracked("update")
+    async def update(self, items: Sequence[T], update_fields: Sequence[str] | None = None) -> Sequence[T]:
+        """Update the given items.
+
+        Raises:
+            ValueError: If any item with the given primary keys does not exist.
+        """
+        updated_items: List[T] = []
+        for item in items:
+            updated = self._mutate_single(item, mode="update", update_fields=update_fields)
+            if updated is None:
+                raise RuntimeError(f"_mutate_single returned None for item {item}. This should never happen.")
+            updated_items.append(updated)
+        return updated_items
+
+    @tracked("upsert")
+    async def upsert(self, items: Sequence[T], update_fields: Sequence[str] | None = None) -> Sequence[T]:
+        """Upsert the given items (insert if missing, otherwise update)."""
+        upserted_items: List[T] = []
+        for item in items:
+            upserted = self._mutate_single(item, mode="upsert", update_fields=update_fields)
+            if upserted is None:
+                raise RuntimeError(f"_mutate_single returned None for item {item}. This should never happen.")
+            upserted_items.append(upserted)
+        return upserted_items
+
+    @tracked("delete")
+    async def delete(self, items: Sequence[T]) -> None:
+        """Delete the given items.
+
+        Raises:
+            ValueError: If any item with the given primary keys does not exist.
+        """
+        # Deletions are applied one by one; _mutate_single validates existence
+        # (and updates the size) for each item before removing it, so a missing
+        # item raises ValueError without touching the items that follow it.
+        for item in items:
+            self._mutate_single(item, mode="delete")
+
+
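Putting the collection together end to end, a usage sketch might look as follows. The `Job` model is invented for illustration, the import path is inferred from the file list above, and the dict literals passed as `filter`/`sort` are assumptions about the `FilterOptions`/`SortOptions` shapes defined in `mantisdk.types`:

```python
import asyncio
from typing import Optional

from pydantic import BaseModel

from mantisdk.store.collection.memory import ListBasedCollection  # path inferred from the file list


class Job(BaseModel):  # hypothetical item type
    job_id: str
    status: str
    start_time: Optional[float] = None


async def main() -> None:
    jobs = ListBasedCollection(items=[], item_type=Job, primary_keys=["job_id"])
    await jobs.insert(
        [
            Job(job_id="a", status="running", start_time=2.0),
            Job(job_id="b", status="queued"),
            Job(job_id="c", status="running", start_time=1.0),
        ]
    )
    # A filter on a non-key field scans the tree; an exact filter on
    # "job_id" would instead take the O(1) nested-dict fast path.
    page = await jobs.query(
        filter={"status": {"exact": "running"}},  # assumed FilterOptions shape
        sort={"sort_by": "start_time", "sort_order": "asc"},  # assumed SortOptions shape
        limit=10,
    )
    print([job.job_id for job in page.items], page.total)  # ['c', 'a'] 2


asyncio.run(main())
```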
+class DequeBasedQueue(Queue[T]):
+    """Queue implementation backed by collections.deque.
+
+    Provides O(1) amortized enqueue (append) and dequeue (popleft).
+    """
+
+    def __init__(
+        self,
+        item_type: Type[T],
+        items: Optional[Sequence[T]] = None,
+        id: Optional[str] = None,
+        tracker: Optional[MetricsBackend] = None,
+    ):
+        super().__init__(tracker=tracker)
+        self._items: Deque[T] = deque()
+        self._item_type: Type[T] = item_type
+        self._id = id if id is not None else str(uuid.uuid4())
+        if items:
+            self._items.extend(items)
+
+    def item_type(self) -> Type[T]:
+        return self._item_type
+
+    @property
+    def collection_name(self) -> str:
+        return self._id
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__}[{self.item_type().__name__}] ({len(self._items)})>"
+
+    @tracked("has")
+    async def has(self, item: T) -> bool:
+        if not isinstance(item, self._item_type):
+            raise TypeError(f"Expected item of type {self._item_type.__name__}, got {type(item).__name__}")
+        return item in self._items
+
+    @tracked("enqueue")
+    async def enqueue(self, items: Sequence[T]) -> Sequence[T]:
+        for item in items:
+            if not isinstance(item, self._item_type):
+                raise TypeError(f"Expected item of type {self._item_type.__name__}, got {type(item).__name__}")
+            self._items.append(item)
+        return items
+
+    @tracked("dequeue")
+    async def dequeue(self, limit: int = 1) -> Sequence[T]:
+        if limit <= 0:
+            return []
+        out: List[T] = []
+        for _ in range(min(limit, len(self._items))):
+            out.append(self._items.popleft())
+        return out
+
+    @tracked("peek")
+    async def peek(self, limit: int = 1) -> Sequence[T]:
+        if limit <= 0:
+            return []
+        result: List[T] = []
+        count = min(limit, len(self._items))
+        for idx, item in enumerate(self._items):
+            if idx >= count:
+                break
+            result.append(item)
+        return result
+
+    @tracked("size")
+    async def size(self) -> int:
+        return len(self._items)
+
+
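A short usage sketch for the queue (import path inferred from the file list; all methods are coroutines):

```python
import asyncio

from mantisdk.store.collection.memory import DequeBasedQueue  # path inferred from the file list


async def main() -> None:
    queue = DequeBasedQueue(item_type=str, id="demo")
    await queue.enqueue(["r1", "r2", "r3"])
    assert await queue.peek(limit=2) == ["r1", "r2"]     # non-destructive look-ahead
    assert await queue.dequeue(limit=2) == ["r1", "r2"]  # FIFO via deque.popleft()
    assert await queue.size() == 1                       # "r3" remains


asyncio.run(main())
```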
+class DictBasedKeyValue(KeyValue[K, V]):
+    """KeyValue implementation backed by a plain dictionary."""
+
+    def __init__(
+        self, data: Optional[Mapping[K, V]] = None, id: Optional[str] = None, tracker: Optional[MetricsBackend] = None
+    ):
+        super().__init__(tracker=tracker)
+        self._values: Dict[K, V] = dict(data) if data else {}
+        self._id = id if id is not None else str(uuid.uuid4())
+
+    @property
+    def collection_name(self) -> str:
+        return self._id
+
+    @tracked("has")
+    async def has(self, key: K) -> bool:
+        return key in self._values
+
+    @tracked("get")
+    async def get(self, key: K, default: V | None = None) -> V | None:
+        return self._values.get(key, default)
+
+    @tracked("set")
+    async def set(self, key: K, value: V) -> None:
+        self._values[key] = value
+
+    @tracked("inc")
+    async def inc(self, key: K, amount: V) -> V:
+        assert ensure_numeric(amount, description="amount")
+        if key in self._values:
+            current_value = self._values[key]
+            assert ensure_numeric(current_value, description=f"value for key {key!r}")
+            new_value = cast(V, current_value + amount)
+            self._values[key] = new_value
+        else:
+            new_value = amount
+            self._values[key] = new_value
+        return new_value
+
+    @tracked("chmax")
+    async def chmax(self, key: K, value: V) -> V:
+        assert ensure_numeric(value, description="value")
+        if key in self._values:
+            current_value = self._values[key]
+            assert ensure_numeric(current_value, description=f"value for key {key!r}")
+            if value > current_value:
+                self._values[key] = value
+                return value
+            return current_value
+        else:
+            self._values[key] = value
+            return value
+
+    @tracked("pop")
+    async def pop(self, key: K, default: V | None = None) -> V | None:
+        return self._values.pop(key, default)
+
+    @tracked("size")
+    async def size(self) -> int:
+        return len(self._values)
+
+
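The numeric helpers behave like counter primitives; a sketch (import path inferred from the file list):

```python
import asyncio

from mantisdk.store.collection.memory import DictBasedKeyValue  # path inferred from the file list


async def main() -> None:
    counters = DictBasedKeyValue[str, int](data={})
    assert await counters.inc("spans", 3) == 3    # missing key: starts from the increment
    assert await counters.inc("spans", 2) == 5
    assert await counters.chmax("spans", 4) == 5  # 4 < 5, value unchanged
    assert await counters.chmax("spans", 9) == 9  # 9 > 5, value raised
    assert await counters.pop("spans") == 9
    assert await counters.size() == 0


asyncio.run(main())
```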
+class InMemoryLightningCollections(LightningCollections):
+    """In-memory implementation of LightningCollections using Python data structures.
+
+    Serves as the storage base for [`InMemoryLightningStore`][mantisdk.InMemoryLightningStore].
+    """
+
+    def __init__(self, lock_type: Literal["thread", "asyncio"], tracker: MetricsBackend | None = None):
+        super().__init__(tracker=tracker)
+        self._lock: Mapping[AtomicLabels, _LoopAwareAsyncLock | _ThreadSafeAsyncLock] = {
+            "rollouts": _LoopAwareAsyncLock() if lock_type == "asyncio" else _ThreadSafeAsyncLock(),
+            "attempts": _LoopAwareAsyncLock() if lock_type == "asyncio" else _ThreadSafeAsyncLock(),
+            "spans": _LoopAwareAsyncLock() if lock_type == "asyncio" else _ThreadSafeAsyncLock(),
+            "resources": _LoopAwareAsyncLock() if lock_type == "asyncio" else _ThreadSafeAsyncLock(),
+            "workers": _LoopAwareAsyncLock() if lock_type == "asyncio" else _ThreadSafeAsyncLock(),
+            "rollout_queue": _LoopAwareAsyncLock() if lock_type == "asyncio" else _ThreadSafeAsyncLock(),
+            "span_sequence_ids": _LoopAwareAsyncLock() if lock_type == "asyncio" else _ThreadSafeAsyncLock(),
+            "generic": _LoopAwareAsyncLock() if lock_type == "asyncio" else _ThreadSafeAsyncLock(),
+        }
+        self._rollouts = ListBasedCollection(
+            items=[], item_type=Rollout, primary_keys=["rollout_id"], id="rollouts", tracker=tracker
+        )
+        self._attempts = ListBasedCollection(
+            items=[], item_type=Attempt, primary_keys=["rollout_id", "attempt_id"], id="attempts", tracker=tracker
+        )
+        self._spans = ListBasedCollection(
+            items=[], item_type=Span, primary_keys=["rollout_id", "attempt_id", "span_id"], id="spans", tracker=tracker
+        )
+        self._resources = ListBasedCollection(
+            items=[], item_type=ResourcesUpdate, primary_keys=["resources_id"], id="resources", tracker=tracker
+        )
+        self._workers = ListBasedCollection(
+            items=[], item_type=Worker, primary_keys=["worker_id"], id="workers", tracker=tracker
+        )
+        self._rollout_queue = DequeBasedQueue(items=[], item_type=str, id="rollout_queue", tracker=tracker)
+        self._span_sequence_ids = DictBasedKeyValue[str, int](
+            data={}, id="span_sequence_ids", tracker=tracker
+        )  # rollout_id -> sequence_id
+
+    @property
+    def collection_name(self) -> str:
+        return "router"
+
+    @property
+    def rollouts(self) -> ListBasedCollection[Rollout]:
+        return self._rollouts
+
+    @property
+    def attempts(self) -> ListBasedCollection[Attempt]:
+        return self._attempts
+
+    @property
+    def spans(self) -> ListBasedCollection[Span]:
+        return self._spans
+
+    @property
+    def resources(self) -> ListBasedCollection[ResourcesUpdate]:
+        return self._resources
+
+    @property
+    def workers(self) -> ListBasedCollection[Worker]:
+        return self._workers
+
+    @property
+    def rollout_queue(self) -> DequeBasedQueue[str]:
+        return self._rollout_queue
+
+    @property
+    def span_sequence_ids(self) -> DictBasedKeyValue[str, int]:
+        return self._span_sequence_ids
+
+    @asynccontextmanager
+    async def atomic(
+        self,
+        *,
+        mode: AtomicMode = "rw",
+        snapshot: bool = False,
+        labels: Optional[Sequence[AtomicLabels]] = None,
+        **kwargs: Any,
+    ):
+        """Take the in-memory locks around the whole block; the collections themselves need no transaction handling.
+
+        Locking is skipped when mode is "r" and snapshot is False.
+
+        This collection implementation does NOT support rollback / commit.
+        """
+        if mode == "r" and not snapshot:
+            yield self
+            return
+        if not labels:
+            # If no labels are provided, use all locks.
+            labels = list(self._lock.keys())
+
+        # IMPORTANT: Sort the labels to ensure a consistent locking order.
+        # This is necessary to avoid deadlocks when multiple threads/coroutines
+        # try to acquire the same locks in different orders.
+        labels = sorted(labels)
+
+        async with self.tracking_context(operation="atomic", collection=self.collection_name):
+            managers = [(label, self._lock[label]) for label in labels]
+            async with AsyncExitStack() as stack:
+                for label, manager in managers:
+                    async with self.tracking_context(operation="lock", collection=label):
+                        await stack.enter_async_context(manager)
+                yield self
+
+    @tracked("evict_spans_for_rollout")
+    async def evict_spans_for_rollout(self, rollout_id: str) -> None:
+        """Evict all spans for a given rollout ID.
+
+        Uses the private API for efficiency.
+        """
+        self._spans._items.pop(rollout_id, [])  # pyright: ignore[reportPrivateUsage]
+
+
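How `atomic()` is meant to be used, as a sketch (import path inferred from the file list; the lock classes it dispatches to are defined next):

```python
import asyncio

from mantisdk.store.collection.memory import InMemoryLightningCollections  # path inferred from the file list


async def main() -> None:
    store = InMemoryLightningCollections(lock_type="asyncio")

    # Plain read without a snapshot: atomic() yields immediately and takes no locks.
    async with store.atomic(mode="r") as collections:
        print(await collections.rollouts.size())  # 0

    # Read-write: the requested locks are acquired in sorted label order
    # to avoid deadlocks; there is no rollback if the block raises.
    async with store.atomic(labels=["rollouts", "attempts"]) as collections:
        ...


asyncio.run(main())
```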
+class _LoopAwareAsyncLock:
+    """Async lock that transparently rebinds to the current event loop.
+
+    The lock intentionally remains *thread-unsafe*: callers must only use it from
+    one thread at a time. If multiple threads interact with the store, each
+    thread gets its own event-loop-specific lock.
+    """
+
+    def __init__(self) -> None:
+        self._locks: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, asyncio.Lock] = weakref.WeakKeyDictionary()
+
+    # When serializing and deserializing, we don't need to serialize the locks,
+    # because another process will have its own set of event loops and its own locks.
+    def __getstate__(self) -> dict[str, Any]:
+        return {}
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        self._locks = weakref.WeakKeyDictionary()
+
+    def _get_lock_for_current_loop(self) -> asyncio.Lock:
+        loop = asyncio.get_running_loop()
+        lock = self._locks.get(loop)
+        if lock is None:
+            lock = asyncio.Lock()
+            self._locks[loop] = lock
+        return lock
+
+    async def __aenter__(self) -> asyncio.Lock:
+        lock = self._get_lock_for_current_loop()
+        await lock.acquire()
+        return lock
+
+    async def __aexit__(self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: Any) -> None:
+        loop = asyncio.get_running_loop()
+        lock = self._locks.get(loop)
+        if lock is None or not lock.locked():
+            raise RuntimeError("Lock released without being acquired")
+        lock.release()
+
+
+class _ThreadSafeAsyncLock:
+    """A thread-safe lock powered by aiologic that can be used in both async and sync contexts.
+
+    aiologic advertises its Lock as a thread-safe asyncio lock.
+    """
+
+    def __init__(self):
+        self._lock = aiologic.Lock()
+
+    async def __aenter__(self):
+        await self._lock.async_acquire()
+        return self
+
+    async def __aexit__(self, *args: Any, **kwargs: Any):
+        # Releasing is non-blocking, so we can call it directly (no await needed).
+        self._lock.async_release()
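The difference between the two lock flavors in practice, as a sketch (`_LoopAwareAsyncLock` is private and imported here only for illustration; path inferred from the file list): the loop-aware lock hands each event loop, and therefore each thread, an independent `asyncio.Lock`, whereas `_ThreadSafeAsyncLock` actually serializes across threads.

```python
import asyncio
import threading

from mantisdk.store.collection.memory import _LoopAwareAsyncLock  # private; illustration only

lock = _LoopAwareAsyncLock()


async def use_lock() -> None:
    # Binds (or creates) an asyncio.Lock keyed by this thread's running event loop.
    async with lock:
        await asyncio.sleep(0)


# Each thread runs its own event loop, so each gets its own underlying asyncio.Lock;
# the two threads are NOT mutually excluded. Use _ThreadSafeAsyncLock for that.
threads = [threading.Thread(target=lambda: asyncio.run(use_lock())) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```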