mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190) hide show
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,1823 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ """Collection-based LightningStore implementation.
4
+
5
+ To developers, please check whether the implementation is correct by checking the following:
6
+
7
+ 1. Whether all `_unlocked_*` methods are guarded by some `atomic()` or `execute()` context.
8
+ 2. Whether all `atomic()` or `execute()` contexts are labeled (labels="...") correctly.
9
+ 3. `_unlocked_update_attempt_and_rollout` should be accompanied by `_post_update_rollout`, `_unlocked_sync_worker_with_attempt`.
10
+ 4. `_post_add_spans` should be called after the spans are inserted into the store.
11
+ 5. `_unlocked_update_rollout_only` should be accompanied by `_post_update_rollout`.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ import functools
18
+ import logging
19
+ import time
20
+ import warnings
21
+ from collections import defaultdict
22
+ from contextvars import ContextVar
23
+ from types import CoroutineType
24
+ from typing import (
25
+ TYPE_CHECKING,
26
+ Any,
27
+ Callable,
28
+ Dict,
29
+ Generic,
30
+ List,
31
+ Literal,
32
+ Optional,
33
+ ParamSpec,
34
+ Sequence,
35
+ Tuple,
36
+ TypeVar,
37
+ Union,
38
+ cast,
39
+ )
40
+
41
+ if TYPE_CHECKING:
42
+ from .listener import StorageListener
43
+
44
+ from opentelemetry.sdk.trace import ReadableSpan
45
+ from pydantic import BaseModel
46
+ from typing_extensions import Concatenate
47
+
48
+ from mantisdk.types import (
49
+ Attempt,
50
+ AttemptedRollout,
51
+ AttemptStatus,
52
+ EnqueueRolloutRequest,
53
+ FilterField,
54
+ NamedResources,
55
+ PaginatedResult,
56
+ ResourcesUpdate,
57
+ Rollout,
58
+ RolloutConfig,
59
+ RolloutStatus,
60
+ SortOptions,
61
+ Span,
62
+ TaskInput,
63
+ Worker,
64
+ WorkerStatus,
65
+ )
66
+ from mantisdk.utils.id import generate_id
67
+ from mantisdk.utils.metrics import MetricsBackend
68
+
69
+ from .base import (
70
+ UNSET,
71
+ LightningStore,
72
+ LightningStoreCapabilities,
73
+ LightningStoreStatistics,
74
+ Unset,
75
+ is_finished,
76
+ is_queuing,
77
+ )
78
+ from .collection import FilterOptions, LightningCollections
79
+ from .collection.base import AtomicLabels, DuplicatedPrimaryKeyError
80
+ from .utils import LATENCY_BUCKETS, rollout_status_from_attempt, scan_unhealthy_rollouts
81
+
82
+ T_callable = TypeVar("T_callable", bound=Callable[..., Any])
83
+ T_model = TypeVar("T_model", bound=BaseModel)
84
+ T_collections = TypeVar("T_collections", bound=LightningCollections)
85
+
86
+ P = ParamSpec("P")
87
+ R = TypeVar("R")
88
+ C = TypeVar("C") # The collections type
89
+
90
+ SelfT = TypeVar("SelfT", bound="CollectionBasedLightningStore[Any]")
91
+
92
+ logger = logging.getLogger(__name__)
93
+
94
+ # ContextVars for tracking the current store method without expensive stack introspection.
95
+ # These are set by the @tracked decorator and read by tracking_context in collection/base.py.
96
+ _UNKNOWN_STORE_METHOD = "unknown"
97
+ _current_public_store_method: ContextVar[str] = ContextVar("public_store_method", default=_UNKNOWN_STORE_METHOD)
98
+ _current_private_store_method: ContextVar[str] = ContextVar("private_store_method", default=_UNKNOWN_STORE_METHOD)
99
+
100
+
101
+ def _with_collections_execute(labels: Sequence[AtomicLabels]):
102
+ """Hands over the function execution to the collections.execute method.
103
+ Used to enable committing and automatic retries.
104
+
105
+ The wrapped function should accept an extra locked collection as its first argument.
106
+ """
107
+
108
+ def decorator(
109
+ func: Callable[Concatenate[SelfT, T_collections, P], CoroutineType[Any, Any, R]],
110
+ ) -> Callable[Concatenate[SelfT, P], CoroutineType[Any, Any, R]]:
111
+
112
+ @functools.wraps(func)
113
+ async def wrapper(self: SelfT, *args: P.args, **kwargs: P.kwargs) -> R:
114
+ async def callback(collections: T_collections) -> R:
115
+ return await func(self, collections, *args, **kwargs)
116
+
117
+ return await self.collections.execute(
118
+ callback,
119
+ mode="rw", # Read-write all enabled.
120
+ snapshot=self._read_snapshot, # pyright: ignore[reportPrivateUsage]
121
+ commit=True, # Enable committing.
122
+ labels=labels,
123
+ )
124
+
125
+ return wrapper
126
+
127
+ return decorator
128
+
129
+
130
+ def tracked(name: str):
131
+ """Decorator to track the execution of the decorated method with Prometheus."""
132
+
133
+ def decorator(func: T_callable) -> T_callable:
134
+
135
+ @functools.wraps(func)
136
+ async def wrapper(self: CollectionBasedLightningStore[T_collections], *args: Any, **kwargs: Any) -> Any:
137
+ # Get the current public method from ContextVar (set by outer tracked methods)
138
+ public_meth_in_stack = _current_public_store_method.get()
139
+
140
+ # Set ContextVars for nested calls to read. Use tokens for proper cleanup.
141
+ pub_token = None
142
+ priv_token = None
143
+ if name in COLLECTION_STORE_PUBLIC_METHODS:
144
+ pub_token = _current_public_store_method.set(name)
145
+ public_meth_in_stack = name # We are in a public method already.
146
+ if name in COLLECTION_STORE_ALL_METHODS:
147
+ priv_token = _current_private_store_method.set(name)
148
+
149
+ try:
150
+ if self._tracker is None: # pyright: ignore[reportPrivateUsage]
151
+ # Skip the tracking because tracking is not configured
152
+ return await func(self, *args, **kwargs)
153
+
154
+ start_time = time.perf_counter()
155
+ status: str = "OK"
156
+ try:
157
+ return await func(self, *args, **kwargs)
158
+ except BaseException as exc:
159
+ status = exc.__class__.__name__
160
+ raise
161
+ finally:
162
+ elapsed = time.perf_counter() - start_time
163
+ await self._tracker.inc_counter( # pyright: ignore[reportPrivateUsage]
164
+ "msk.store.total",
165
+ labels={"method": name, "store_pubmeth": public_meth_in_stack, "status": status},
166
+ )
167
+ await self._tracker.observe_histogram( # pyright: ignore[reportPrivateUsage]
168
+ "msk.store.latency",
169
+ value=elapsed,
170
+ labels={"method": name, "store_pubmeth": public_meth_in_stack, "status": status},
171
+ )
172
+ finally:
173
+ # Reset ContextVars to their previous values
174
+ if pub_token is not None:
175
+ _current_public_store_method.reset(pub_token)
176
+ if priv_token is not None:
177
+ _current_private_store_method.reset(priv_token)
178
+
179
+ return cast(T_callable, wrapper)
180
+
181
+ return decorator
182
+
183
+
184
+ def healthcheck_before(func: T_callable) -> T_callable:
185
+ """
186
+ Decorator to run the watchdog healthcheck **before** executing the decorated method.
187
+ Only runs if the store has a watchdog configured.
188
+ Prevents recursive healthcheck execution using a flag on the store instance.
189
+ """
190
+
191
+ @functools.wraps(func)
192
+ async def wrapper(self: CollectionBasedLightningStore[T_collections], *args: Any, **kwargs: Any) -> Any:
193
+ # Check if healthcheck is already running to prevent recursion
194
+ if getattr(self, "_healthcheck_running", False):
195
+ # Skip healthcheck if already running
196
+ return await func(self, *args, **kwargs)
197
+
198
+ # Set flag to prevent recursive healthcheck calls
199
+ # This flag is not asyncio/thread-safe, but it doesn't matter
200
+ self._healthcheck_running = True # type: ignore
201
+ try:
202
+ # The following methods should live inside one lock.
203
+ await self._scan_for_unhealthy_rollouts() # pyright: ignore[reportPrivateUsage]
204
+ finally:
205
+ # Always clear the flag, even if healthcheck fails
206
+ self._healthcheck_running = False # type: ignore
207
+
208
+ # Execute the original method
209
+ # This should be outside the lock.
210
+ return await func(self, *args, **kwargs)
211
+
212
+ return cast(T_callable, wrapper)
213
+
214
+
215
+ def _generate_resources_id() -> str:
216
+ return "rs-" + generate_id(12)
217
+
218
+
219
+ def _generate_rollout_id() -> str:
220
+ return "ro-" + generate_id(12)
221
+
222
+
223
+ def _generate_attempt_id() -> str:
224
+ """We don't need that long because attempts are limited to rollouts."""
225
+ return "at-" + generate_id(8)
226
+
227
+
228
+ class CollectionBasedLightningStore(LightningStore, Generic[T_collections]):
229
+ """It's the standard implementation of LightningStore that uses collections to store data.
230
+
231
+ If the store implementation is to use the store's default behavior, it's recommended to
232
+ inherit from this class and override the methods if needed.
233
+ Bring your own collection implementation by using a different `collections` argument.
234
+
235
+ The methods in this class should generally not call each other,
236
+ especially those that are locked.
237
+
238
+ Args:
239
+ collections: The collections to use for storage.
240
+ read_snapshot: Make sure read operations are atomic. If set to true,
241
+ all read operations like `query_rollouts` will have better consistency.
242
+ It may use an isolated snapshot that supports repeatable reads.
243
+ tracker: Enable metrics tracking.
244
+ scan_debounce_seconds: The debounce time for the scan for unhealthy rollouts.
245
+ Set to 0 to disable debouncing. The debounce is a non-perfect traffic control.
246
+ It's isolated for each store instance if there are multiple worker replicas.
247
+ """
248
+
249
+ def __init__(
250
+ self,
251
+ collections: T_collections,
252
+ *,
253
+ read_snapshot: bool = False,
254
+ tracker: MetricsBackend | None = None,
255
+ scan_debounce_seconds: float = 10.0,
256
+ listeners: Optional[Sequence["StorageListener"]] = None,
257
+ ) -> None:
258
+ super().__init__(listeners=listeners)
259
+ # rollouts and spans' storage
260
+ self.collections = collections
261
+ self._read_snapshot = read_snapshot
262
+ self._tracker = tracker
263
+ self._launch_time = time.time()
264
+
265
+ # Control scan debounce to avoid overloading the store.
266
+ self._scan_debounce_seconds = scan_debounce_seconds
267
+ last_scan_time = self._launch_time
268
+ if self._scan_debounce_seconds > 0:
269
+ # Allow the first scan immediately after instantiation
270
+ last_scan_time -= self._scan_debounce_seconds
271
+ self._last_scan_entrance_time = last_scan_time
272
+
273
+ if self._tracker is not None:
274
+ self._tracker.register_histogram(
275
+ "msk.store.latency",
276
+ ["method", "store_pubmeth", "status"],
277
+ buckets=LATENCY_BUCKETS,
278
+ group_level=1,
279
+ )
280
+ self._tracker.register_counter(
281
+ "msk.store.total",
282
+ ["method", "store_pubmeth", "status"],
283
+ group_level=1,
284
+ )
285
+ self._tracker.register_counter(
286
+ "msk.rollouts.total",
287
+ ["status", "mode"],
288
+ group_level=1,
289
+ )
290
+ self._tracker.register_histogram(
291
+ "msk.rollouts.duration",
292
+ ["status", "mode"],
293
+ buckets=LATENCY_BUCKETS,
294
+ group_level=1,
295
+ )
296
+
297
+ async def statistics(self) -> LightningStoreStatistics:
298
+ """Return the statistics of the store."""
299
+ current_time = time.time()
300
+ return {
301
+ "name": self.__class__.__name__,
302
+ "total_rollouts": await self.collections.rollouts.size(),
303
+ "total_attempts": await self.collections.attempts.size(),
304
+ "total_spans": await self.collections.spans.size(),
305
+ "total_resources": await self.collections.resources.size(),
306
+ "total_workers": await self.collections.workers.size(),
307
+ "uptime": current_time - self._launch_time,
308
+ }
309
+
310
+ async def _notify(self, method_name: str, *args: Any, **kwargs: Any) -> None:
311
+ """Notify all listeners of a storage event.
312
+
313
+ Catches and logs any exceptions from listeners to ensure storage
314
+ operations are never blocked by tracking failures.
315
+ """
316
+ for listener in self.listeners:
317
+ try:
318
+ method = getattr(listener, method_name, None)
319
+ if method is not None:
320
+ await method(*args, **kwargs)
321
+ except Exception as e:
322
+ logger.warning(f"Listener {listener.__class__.__name__}.{method_name} failed: {e}")
323
+
324
+ @tracked("_get_latest_resources")
325
+ async def _get_latest_resources(self) -> Optional[ResourcesUpdate]:
326
+ """Get the latest resources from the collections. Returns `None` if no resources are found."""
327
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["resources"]) as collections:
328
+ return await collections.resources.get(sort={"name": "update_time", "order": "desc"})
329
+
330
+ @tracked("_update_or_insert_worker")
331
+ async def _update_or_insert_worker(self, worker: Worker, update_fields: Sequence[str] | None = None) -> Worker:
332
+ """Create a worker if it doesn't exist. Update its `update_fields` if it already exists."""
333
+ async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["workers"]) as collections:
334
+ updated_workers = await collections.workers.upsert([worker], update_fields=update_fields)
335
+ return updated_workers[0]
336
+
337
+ @tracked("_unlocked_sync_worker_with_attempt")
338
+ async def _unlocked_sync_worker_with_attempt(
339
+ self, collections: T_collections, attempt: Attempt, dequeue: bool
340
+ ) -> None:
341
+ """Update the worker's status. This can be done in a separate session."""
342
+ worker_id = attempt.worker_id
343
+ if not worker_id:
344
+ return
345
+
346
+ worker = Worker(worker_id=worker_id)
347
+ update_fields: List[str] = []
348
+ now = time.time()
349
+
350
+ # This is called from dequeue_rollout
351
+ if dequeue:
352
+ worker.last_dequeue_time = now
353
+ update_fields.append("last_dequeue_time")
354
+
355
+ # NOTE: We don't check the status change anymore, in sake of performance.
356
+ # Instead, we always update the last_idle_time regardless of whether the attempt status has changed.
357
+ if attempt.status in ("succeeded", "failed"):
358
+ worker.last_idle_time = now
359
+ worker.status = "idle"
360
+ worker.current_rollout_id = None
361
+ worker.current_attempt_id = None
362
+ update_fields.extend(["last_idle_time", "status", "current_rollout_id", "current_attempt_id"])
363
+ elif attempt.status in ("timeout", "unresponsive"):
364
+ worker.last_idle_time = now
365
+ worker.status = "unknown"
366
+ worker.current_rollout_id = None
367
+ worker.current_attempt_id = None
368
+ update_fields.extend(["last_idle_time", "status", "current_rollout_id", "current_attempt_id"])
369
+ else:
370
+ worker.last_busy_time = now
371
+ worker.status = "busy"
372
+ worker.current_rollout_id = attempt.rollout_id
373
+ worker.current_attempt_id = attempt.attempt_id
374
+ update_fields.extend(["last_busy_time", "status", "current_rollout_id", "current_attempt_id"])
375
+
376
+ # Validate the schema to make sure it's valid.
377
+ Worker.model_validate(worker.model_dump())
378
+ await collections.workers.upsert([worker], update_fields=update_fields)
379
+
380
+ @property
381
+ def capabilities(self) -> LightningStoreCapabilities:
382
+ """Return the capabilities of the store.
383
+
384
+ This store supports no capability. The capability depends on the underlying collections.
385
+ """
386
+ return LightningStoreCapabilities()
387
+
388
+ @tracked("_sync_workers_with_attempts")
389
+ async def _sync_workers_with_attempts(self, attempts: Sequence[Attempt], dequeue: bool = False) -> None:
390
+ """Update the worker's status. Locked bulk version of `_unlocked_sync_workers_with_attempts`.
391
+
392
+ Use `dequeue = True` if `last_dequeue_time` should be updated.
393
+ """
394
+ async with self.collections.atomic(mode="w", snapshot=self._read_snapshot, labels=["workers"]) as collections:
395
+ for attempt in attempts:
396
+ await self._unlocked_sync_worker_with_attempt(collections, attempt, dequeue)
397
+
398
+ @tracked("_dequeue_mark_worker_idle")
399
+ async def _dequeue_mark_worker_idle(self, worker_id: str) -> None:
400
+ """Dequeue fails and mark the worker as idle."""
401
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["workers"]) as collections:
402
+ worker = await collections.workers.get({"worker_id": {"exact": worker_id}})
403
+ now = time.time()
404
+ if not worker or worker.status != "idle":
405
+ # should mark the worker as idle
406
+ worker = Worker(worker_id=worker_id, status="idle", last_idle_time=now, last_dequeue_time=now)
407
+ await self._update_or_insert_worker(worker, update_fields=["status", "last_idle_time", "last_dequeue_time"])
408
+ else:
409
+ # only update last_dequeue_time
410
+ worker = Worker(worker_id=worker_id, last_dequeue_time=now)
411
+ await self._update_or_insert_worker(worker, update_fields=["last_dequeue_time"])
412
+
413
+ @tracked("start_rollout")
414
+ @healthcheck_before
415
+ async def start_rollout(
416
+ self,
417
+ input: TaskInput,
418
+ mode: Literal["train", "val", "test"] | None = None,
419
+ resources_id: str | None = None,
420
+ config: RolloutConfig | None = None,
421
+ metadata: Dict[str, Any] | None = None,
422
+ worker_id: str | None = None,
423
+ ) -> AttemptedRollout:
424
+ """Notify the store that I'm about to run a rollout.
425
+
426
+ See [`LightningStore.start_rollout()`][mantisdk.LightningStore.start_rollout] for semantics.
427
+ """
428
+ rollout_id = _generate_rollout_id()
429
+ current_time = time.time()
430
+
431
+ rollout_config = config.model_copy(deep=True) if config is not None else RolloutConfig()
432
+ rollout_metadata = dict(metadata) if metadata is not None else {}
433
+
434
+ if resources_id is None:
435
+ latest_resources = await self._get_latest_resources()
436
+ resources_id = latest_resources.resources_id if latest_resources is not None else None
437
+
438
+ rollout = Rollout(
439
+ rollout_id=rollout_id,
440
+ input=input,
441
+ mode=mode,
442
+ resources_id=resources_id,
443
+ start_time=current_time,
444
+ status="preparing",
445
+ config=rollout_config,
446
+ metadata=rollout_metadata,
447
+ )
448
+
449
+ # Create the initial attempt
450
+ attempt_id = _generate_attempt_id()
451
+ attempt = Attempt(
452
+ rollout_id=rollout.rollout_id,
453
+ attempt_id=attempt_id,
454
+ sequence_id=1,
455
+ start_time=current_time,
456
+ status="preparing",
457
+ worker_id=worker_id,
458
+ )
459
+
460
+ async def _insert_rollout_and_attempt(collections: T_collections) -> None:
461
+ await collections.attempts.insert([attempt])
462
+ await collections.rollouts.insert([rollout])
463
+
464
+ await self.collections.execute(
465
+ _insert_rollout_and_attempt,
466
+ mode="rw",
467
+ snapshot=self._read_snapshot,
468
+ commit=True,
469
+ labels=["rollouts", "attempts"],
470
+ )
471
+ # Notify the subclass that the rollout status has changed.
472
+ all_fields = list(rollout.__class__.model_fields.keys())
473
+ await self._post_update_rollout([(rollout, all_fields)])
474
+
475
+ if worker_id is not None:
476
+ await self._sync_workers_with_attempts([attempt])
477
+
478
+ # Notify listeners
479
+ await self._notify("on_rollout_created", rollout)
480
+ await self._notify("on_attempt_created", attempt)
481
+
482
+ # Return a rollout with attempt attached.
483
+ return AttemptedRollout(**rollout.model_dump(), attempt=attempt)
484
+
485
+ @tracked("_enqueue_many_rollouts")
486
+ @_with_collections_execute(labels=["rollouts", "rollout_queue"])
487
+ async def _enqueue_many_rollouts(self, collections: T_collections, rollouts: Sequence[Rollout]) -> None:
488
+ """Enqueue many rollouts into the rollout queue. Locked bulk version."""
489
+ rollout_ids = [rollout.rollout_id for rollout in rollouts]
490
+ await collections.rollout_queue.enqueue(rollout_ids)
491
+ await collections.rollouts.insert(rollouts)
492
+
493
+ @tracked("_prepare_single_rollout")
494
+ async def _prepare_single_rollout(
495
+ self,
496
+ input: TaskInput,
497
+ mode: Literal["train", "val", "test"] | None = None,
498
+ resources_id: str | None = None,
499
+ config: RolloutConfig | None = None,
500
+ metadata: Dict[str, Any] | None = None,
501
+ ) -> Rollout:
502
+ """Prepare a single rollout object without enqueuing it.
503
+
504
+ Expects resources_id to have been resolved.
505
+ """
506
+ rollout_id = _generate_rollout_id()
507
+ current_time = time.time()
508
+
509
+ rollout_config = config.model_copy(deep=True) if config is not None else RolloutConfig()
510
+ rollout_metadata = dict(metadata) if metadata is not None else {}
511
+
512
+ return Rollout(
513
+ rollout_id=rollout_id,
514
+ input=input,
515
+ mode=mode,
516
+ resources_id=resources_id,
517
+ start_time=current_time,
518
+ status="queuing", # should be queuing
519
+ config=rollout_config,
520
+ metadata=rollout_metadata,
521
+ )
522
+
523
+ @tracked("enqueue_rollout")
524
+ @healthcheck_before
525
+ async def enqueue_rollout(
526
+ self,
527
+ input: TaskInput,
528
+ mode: Literal["train", "val", "test"] | None = None,
529
+ resources_id: str | None = None,
530
+ config: RolloutConfig | None = None,
531
+ metadata: Dict[str, Any] | None = None,
532
+ ) -> Rollout:
533
+ """Adds a new task to the queue with specific metadata and returns the rollout.
534
+
535
+ See [`LightningStore.enqueue_rollout()`][mantisdk.LightningStore.enqueue_rollout] for semantics.
536
+ """
537
+ if resources_id is None:
538
+ latest_resources = await self._get_latest_resources()
539
+ resources_id = latest_resources.resources_id if latest_resources is not None else None
540
+
541
+ rollout = await self._prepare_single_rollout(
542
+ input=input,
543
+ resources_id=resources_id,
544
+ mode=mode,
545
+ config=config,
546
+ metadata=metadata,
547
+ )
548
+
549
+ await self._enqueue_many_rollouts([rollout])
550
+ # Notify the subclass that the rollout status has changed.
551
+ all_fields = list(Rollout.model_fields.keys())
552
+ # Skip queueing because the rollout is already in the queue.
553
+ await self._post_update_rollout([(rollout, all_fields)], skip_enqueue=True)
554
+
555
+ # Notify listeners
556
+ await self._notify("on_rollout_created", rollout)
557
+
558
+ # Return the rollout with no attempt attached.
559
+ return rollout
560
+
561
+ @tracked("enqueue_many_rollouts")
562
+ @healthcheck_before
563
+ async def enqueue_many_rollouts(self, rollouts: Sequence[EnqueueRolloutRequest]) -> Sequence[Rollout]:
564
+ """Adds many rollouts in a batch."""
565
+ prepared_rollouts: List[Rollout] = []
566
+ latest_resources = await self._get_latest_resources()
567
+
568
+ for request in rollouts:
569
+ resources_id = request.resources_id
570
+ if resources_id is None:
571
+ resources_id = latest_resources.resources_id if latest_resources is not None else None
572
+
573
+ rollout = await self._prepare_single_rollout(
574
+ input=request.input,
575
+ resources_id=resources_id,
576
+ mode=request.mode,
577
+ config=request.config,
578
+ metadata=request.metadata,
579
+ )
580
+ prepared_rollouts.append(rollout)
581
+
582
+ await self._enqueue_many_rollouts(prepared_rollouts)
583
+ all_fields = list(Rollout.model_fields.keys())
584
+ rollout_updates = [(rollout, all_fields) for rollout in prepared_rollouts]
585
+ await self._post_update_rollout(rollout_updates, skip_enqueue=True)
586
+
587
+ return prepared_rollouts
588
+
589
+ @tracked("_unlocked_query_rollouts_by_rollout_ids")
590
+ async def _unlocked_query_rollouts_by_rollout_ids(
591
+ self, collections: T_collections, rollout_ids: Sequence[str]
592
+ ) -> List[Rollout]:
593
+ """Query rollouts by rollout IDs."""
594
+ if len(rollout_ids) == 0:
595
+ return []
596
+ elif len(rollout_ids) == 1:
597
+ # Performance optimization: use exact filter for single rollout.
598
+ rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_ids[0]}})
599
+ return [rollout] if rollout is not None else []
600
+ else:
601
+ result = await collections.rollouts.query({"rollout_id": {"within": rollout_ids}})
602
+ # Preserve the order of rollout_ids.
603
+ result_dict = {rollout.rollout_id: rollout for rollout in result}
604
+ return [result_dict[rollout_id] for rollout_id in rollout_ids if rollout_id in result_dict]
605
+
606
+ @tracked("_post_dequeue_rollouts")
607
+ @_with_collections_execute(labels=["rollouts", "attempts"])
608
+ async def _post_dequeue_rollouts(
609
+ self, collections: T_collections, rollout_ids: Sequence[str], worker_id: Optional[str]
610
+ ) -> Sequence[Tuple[AttemptedRollout, Sequence[str]]]:
611
+ """Post-dequeue logic for the rollout. Returns the rollout and the update fields (for post-update logic)."""
612
+ rollouts = await self._unlocked_query_rollouts_by_rollout_ids(collections, rollout_ids)
613
+
614
+ if not rollouts:
615
+ logger.warning(f"No rollout found for rollout IDs: {rollout_ids}, skipping dequeuing")
616
+ return []
617
+
618
+ dequeue_results: List[Tuple[AttemptedRollout, Sequence[str]]] = []
619
+ for rollout in rollouts:
620
+ # Check if rollout is still in a queuing state
621
+ # (it might have been updated to a different status while in queue)
622
+ if is_queuing(rollout):
623
+ # Create a new attempt (could be first attempt or retry)
624
+ attempt_id = _generate_attempt_id()
625
+ current_time = time.time()
626
+
627
+ # Get existing attempts to determine sequence number
628
+ existing_attempts = await self._unlocked_query_attempts_for_rollout(collections, rollout.rollout_id)
629
+ sequence_id = len(existing_attempts) + 1
630
+
631
+ attempt = Attempt(
632
+ rollout_id=rollout.rollout_id,
633
+ attempt_id=attempt_id,
634
+ sequence_id=sequence_id,
635
+ start_time=current_time,
636
+ status="preparing",
637
+ worker_id=worker_id,
638
+ )
639
+
640
+ await collections.attempts.insert([attempt])
641
+
642
+ # Sync attempt status to rollout
643
+ rollout, update_fields = await self._unlocked_update_rollout_only(
644
+ collections, rollout.rollout_id, status="preparing"
645
+ )
646
+ dequeue_results.append((AttemptedRollout(**rollout.model_dump(), attempt=attempt), update_fields))
647
+
648
+ else:
649
+ # If not in queuing state, skip this rollout and continue
650
+ # (it was updated externally and should not be processed)
651
+ logger.warning(
652
+ f"Rollout {rollout.rollout_id} is not in queuing state: {rollout.status}, skipping dequeuing"
653
+ )
654
+
655
+ return dequeue_results
656
+
657
+ @tracked("dequeue_rollout")
658
+ @healthcheck_before
659
+ async def dequeue_rollout(self, worker_id: Optional[str] = None) -> Optional[AttemptedRollout]:
660
+ """Retrieves the next task from the queue without blocking.
661
+ Returns `None` if the queue is empty.
662
+
663
+ Will set the rollout status to preparing and create a new attempt.
664
+
665
+ See [`LightningStore.dequeue_rollout()`][mantisdk.LightningStore.dequeue_rollout] for semantics.
666
+ """
667
+ # Keep looking until we find a rollout that's still in queuing status
668
+ # or the queue is empty
669
+ while True:
670
+ async with self.collections.atomic(
671
+ mode="rw", snapshot=self._read_snapshot, labels=["rollout_queue"]
672
+ ) as collections:
673
+ dequeued = await collections.rollout_queue.dequeue(1)
674
+ if not dequeued:
675
+ break
676
+ rollout_id = dequeued[0]
677
+ logger.debug("Rollout ID %s has been dequeued by Worker ID %s", rollout_id, worker_id)
678
+
679
+ post_dequeue_result = await self._post_dequeue_rollouts([rollout_id], worker_id)
680
+ if post_dequeue_result:
681
+ await self._post_update_rollout(post_dequeue_result)
682
+ attempted_rollout, _ = post_dequeue_result[0]
683
+ if worker_id is not None:
684
+ await self._sync_workers_with_attempts([attempted_rollout.attempt], dequeue=True)
685
+ logger.debug("Rollout has been prepared for Worker ID %s: %s", worker_id, attempted_rollout)
686
+ return attempted_rollout
687
+
688
+ # else continue the loop
689
+
690
+ # No valid rollouts found
691
+ if worker_id is not None:
692
+ # Mark the current worker as idle
693
+ await self._dequeue_mark_worker_idle(worker_id)
694
+ return None
695
+
696
+ @tracked("dequeue_many_rollouts")
697
+ @healthcheck_before
698
+ async def dequeue_many_rollouts(
699
+ self, *, limit: int = 1, worker_id: Optional[str] = None
700
+ ) -> Sequence[AttemptedRollout]:
701
+ """Retrieves up to `limit` tasks from the queue without blocking."""
702
+ dequeued_rollouts: List[AttemptedRollout] = []
703
+ # Keep looking until we find a rollout that's still in queuing status
704
+ # or the queue is empty
705
+ while len(dequeued_rollouts) < limit:
706
+ rest_limit = limit - len(dequeued_rollouts)
707
+ async with self.collections.atomic(
708
+ mode="rw", snapshot=self._read_snapshot, labels=["rollout_queue"]
709
+ ) as collections:
710
+ dequeued = await collections.rollout_queue.dequeue(rest_limit)
711
+ if not dequeued:
712
+ # have no more rollouts in the queue; break.
713
+ break
714
+
715
+ post_dequeue_result = await self._post_dequeue_rollouts(dequeued, worker_id)
716
+ if post_dequeue_result:
717
+ await self._post_update_rollout(post_dequeue_result)
718
+ dequeued_rollouts.extend([item for item, _ in post_dequeue_result])
719
+
720
+ # else continue the loop
721
+
722
+ # Final cleanup and worker status update
723
+ if worker_id is not None:
724
+ if dequeued_rollouts:
725
+ # NOTE: One worker can currently only associated with one attempt.
726
+ # Assuming the worker is working on the last dequeued rollout.
727
+ await self._sync_workers_with_attempts([dequeued_rollouts[-1].attempt], dequeue=True)
728
+ else:
729
+ # Mark the current worker as idle
730
+ await self._dequeue_mark_worker_idle(worker_id)
731
+ return dequeued_rollouts
732
+
733
+ @tracked("start_attempt")
734
+ @healthcheck_before
735
+ async def start_attempt(self, rollout_id: str, worker_id: Optional[str] = None) -> AttemptedRollout:
736
+ """Creates a new attempt for a given rollout ID and return the attempt details.
737
+
738
+ See [`LightningStore.start_attempt()`][mantisdk.LightningStore.start_attempt] for semantics.
739
+ """
740
+
741
+ async def _create_attempt(collections: T_collections):
742
+ # Get the rollout
743
+ rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
744
+ if not rollout:
745
+ raise ValueError(f"Rollout {rollout_id} not found")
746
+
747
+ # Get existing attempts to determine sequence number
748
+ existing_attempts = await self._unlocked_query_attempts_for_rollout(collections, rollout_id)
749
+ sequence_id = len(existing_attempts) + 1
750
+
751
+ # We don't care whether the max attempts have reached or not
752
+ # This attempt is from user trigger
753
+
754
+ # Create new attempt
755
+ attempt_id = _generate_attempt_id()
756
+ current_time = time.time()
757
+
758
+ attempt = Attempt(
759
+ rollout_id=rollout_id,
760
+ attempt_id=attempt_id,
761
+ sequence_id=sequence_id,
762
+ start_time=current_time,
763
+ status="preparing",
764
+ worker_id=worker_id,
765
+ )
766
+
767
+ # Add attempt to storage
768
+ await collections.attempts.insert([attempt])
769
+
770
+ # Sync attempt status to rollout
771
+ rollout, update_fields = await self._unlocked_update_rollout_only(
772
+ collections, rollout_id, status="preparing"
773
+ )
774
+ return attempt, rollout, update_fields
775
+
776
+ attempt, rollout, update_fields = await self.collections.execute(
777
+ _create_attempt, mode="rw", snapshot=self._read_snapshot, commit=True, labels=["rollouts", "attempts"]
778
+ )
779
+ await self._post_update_rollout([(rollout, update_fields)])
780
+
781
+ if worker_id is not None:
782
+ await self._sync_workers_with_attempts([attempt])
783
+
784
+ # Return the rollout with the new attempt attached.
785
+ return AttemptedRollout(**rollout.model_dump(), attempt=attempt)
786
+
787
+ @tracked("query_rollouts")
788
+ @healthcheck_before
789
+ async def query_rollouts(
790
+ self,
791
+ *,
792
+ status_in: Optional[Sequence[RolloutStatus]] = None,
793
+ rollout_id_in: Optional[Sequence[str]] = None,
794
+ rollout_id_contains: Optional[str] = None,
795
+ filter_logic: Literal["and", "or"] = "and",
796
+ sort_by: Optional[str] = None,
797
+ sort_order: Literal["asc", "desc"] = "asc",
798
+ limit: int = -1,
799
+ offset: int = 0,
800
+ status: Optional[Sequence[RolloutStatus]] = None,
801
+ rollout_ids: Optional[Sequence[str]] = None,
802
+ ) -> PaginatedResult[Union[Rollout, AttemptedRollout]]:
803
+ """Retrieve rollouts with filtering and pagination.
804
+
805
+ See [`LightningStore.query_rollouts()`][mantisdk.LightningStore.query_rollouts] for semantics.
806
+ """
807
+ # Construct filters condition
808
+ if status_in is not None:
809
+ resolved_status = status_in
810
+ elif status is not None:
811
+ warnings.warn("status is deprecated, use status_in instead", DeprecationWarning, stacklevel=3)
812
+ resolved_status = status
813
+ else:
814
+ resolved_status = None
815
+
816
+ if rollout_id_in is not None:
817
+ resolved_rollout_ids = rollout_id_in
818
+ elif rollout_ids is not None:
819
+ warnings.warn("rollout_ids is deprecated, use rollout_id_in instead", DeprecationWarning, stacklevel=3)
820
+ resolved_rollout_ids = rollout_ids
821
+ else:
822
+ resolved_rollout_ids = None
823
+
824
+ filters: FilterOptions = {}
825
+ filters["_aggregate"] = filter_logic
826
+ if resolved_status is not None:
827
+ filters["status"] = {"within": list(resolved_status)}
828
+ if resolved_rollout_ids is not None:
829
+ rollout_id_field = cast(FilterField, filters.setdefault("rollout_id", {}))
830
+ rollout_id_field["within"] = list(resolved_rollout_ids)
831
+ if rollout_id_contains is not None:
832
+ rollout_id_field = cast(FilterField, filters.setdefault("rollout_id", {}))
833
+ rollout_id_field["contains"] = rollout_id_contains
834
+
835
+ async with self.collections.atomic(
836
+ mode="r", snapshot=self._read_snapshot, labels=["rollouts", "attempts"]
837
+ ) as collections:
838
+ rollouts = await collections.rollouts.query(
839
+ filter=filters if list(filters.keys()) != ["_aggregate"] else None,
840
+ sort=SortOptions(name=sort_by, order=sort_order) if sort_by else None,
841
+ limit=limit,
842
+ offset=offset,
843
+ )
844
+
845
+ # Attach the latest attempt to the rollout objects
846
+ attempted_rollouts = await self._unlocked_many_rollouts_to_attempted_rollouts(collections, rollouts.items)
847
+
848
+ return PaginatedResult(
849
+ items=attempted_rollouts, limit=rollouts.limit, offset=rollouts.offset, total=rollouts.total
850
+ )
851
+
852
+ @tracked("_unlocked_query_attempts_for_rollout")
853
+ async def _unlocked_query_attempts_for_rollout(self, collections: T_collections, rollout_id: str) -> List[Attempt]:
854
+ """The unlocked version of `query_attempts_for_rollout`."""
855
+ result = await collections.attempts.query(
856
+ filter={"rollout_id": {"exact": rollout_id}},
857
+ sort={"name": "sequence_id", "order": "asc"},
858
+ )
859
+ return list(result.items)
860
+
861
+ @tracked("get_rollout_by_id")
862
+ @healthcheck_before
863
+ async def get_rollout_by_id(self, rollout_id: str) -> Optional[Union[Rollout, AttemptedRollout]]:
864
+ """Retrieves a specific rollout by its ID.
865
+
866
+ See [`LightningStore.get_rollout_by_id()`][mantisdk.LightningStore.get_rollout_by_id] for semantics.
867
+
868
+ If the rollout has been attempted, the latest attempt will also be returned.
869
+ """
870
+ async with self.collections.atomic(
871
+ mode="r", snapshot=self._read_snapshot, labels=["rollouts", "attempts"]
872
+ ) as collections:
873
+ rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
874
+ if rollout is None:
875
+ return None
876
+ return await self._unlocked_rollout_to_attempted_rollout(collections, rollout)
877
+
878
+ @tracked("_unlocked_rollout_to_attempted_rollout")
879
+ async def _unlocked_rollout_to_attempted_rollout(
880
+ self, collections: T_collections, rollout: Rollout
881
+ ) -> Union[Rollout, AttemptedRollout]:
882
+ """Query the latest attempt for the rollout, and attach it to the rollout object.
883
+
884
+ If the rollout has no attempts, return the rollout object itself.
885
+ """
886
+ latest_attempt = await self._unlocked_get_latest_attempt(collections, rollout.rollout_id)
887
+ if latest_attempt is None:
888
+ return rollout
889
+ else:
890
+ return AttemptedRollout(**rollout.model_dump(), attempt=latest_attempt)
891
+
892
+ @tracked("_unlocked_many_rollouts_to_attempted_rollouts")
893
+ async def _unlocked_many_rollouts_to_attempted_rollouts(
894
+ self, collections: T_collections, rollouts: Sequence[Rollout]
895
+ ) -> List[Union[Rollout, AttemptedRollout]]:
896
+ """Query the latest attempts for the rollouts, and attach them to the rollout objects."""
897
+ # TODO: Maybe we can use asyncio.gather here to speed up the process?
898
+ return [await self._unlocked_rollout_to_attempted_rollout(collections, rollout) for rollout in rollouts]
899
+
900
+ @tracked("_unlocked_get_latest_attempt")
901
+ async def _unlocked_get_latest_attempt(self, collections: T_collections, rollout_id: str) -> Optional[Attempt]:
902
+ """The unlocked version of `get_latest_attempt`."""
903
+ return await collections.attempts.get(
904
+ filter={"rollout_id": {"exact": rollout_id}},
905
+ sort={"name": "sequence_id", "order": "desc"},
906
+ )
907
+
908
+ @tracked("query_attempts")
909
+ @healthcheck_before
910
+ async def query_attempts(
911
+ self,
912
+ rollout_id: str,
913
+ *,
914
+ sort_by: Optional[str] = "sequence_id",
915
+ sort_order: Literal["asc", "desc"] = "asc",
916
+ limit: int = -1,
917
+ offset: int = 0,
918
+ ) -> PaginatedResult[Attempt]:
919
+ """Retrieve attempts for a rollout with optional ordering/pagination."""
920
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["attempts"]) as collections:
921
+ return await collections.attempts.query(
922
+ filter={"rollout_id": {"exact": rollout_id}},
923
+ sort={"name": sort_by, "order": sort_order} if sort_by else None,
924
+ limit=limit,
925
+ offset=offset,
926
+ )
927
+
928
+ @tracked("get_latest_attempt")
929
+ @healthcheck_before
930
+ async def get_latest_attempt(self, rollout_id: str) -> Optional[Attempt]:
931
+ """Retrieves the latest attempt for a given rollout ID.
932
+
933
+ See [`LightningStore.get_latest_attempt()`][mantisdk.LightningStore.get_latest_attempt] for semantics.
934
+ """
935
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["attempts"]) as collections:
936
+ return await self._unlocked_get_latest_attempt(collections, rollout_id)
937
+
938
+ @tracked("query_resources")
939
+ async def query_resources(
940
+ self,
941
+ *,
942
+ resources_id: Optional[str] = None,
943
+ resources_id_contains: Optional[str] = None,
944
+ sort_by: Optional[str] = None,
945
+ sort_order: Literal["asc", "desc"] = "asc",
946
+ limit: int = -1,
947
+ offset: int = 0,
948
+ ) -> PaginatedResult[ResourcesUpdate]:
949
+ """Return every stored resource snapshot in insertion order."""
950
+ filters: FilterOptions = {}
951
+ if resources_id is not None:
952
+ resources_id_field = cast(FilterField, filters.setdefault("resources_id", {}))
953
+ resources_id_field["exact"] = resources_id
954
+ if resources_id_contains is not None:
955
+ resources_id_field = cast(FilterField, filters.setdefault("resources_id", {}))
956
+ resources_id_field["contains"] = resources_id_contains
957
+
958
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["resources"]) as collections:
959
+ return await collections.resources.query(
960
+ filter=filters or None,
961
+ sort={"name": sort_by, "order": sort_order} if sort_by else None,
962
+ limit=limit,
963
+ offset=offset,
964
+ )
965
+
966
+ @tracked("add_resources")
967
+ async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
968
+ """Stores a new version of named resources and sets it as the latest.
969
+
970
+ See [`LightningStore.add_resources()`][mantisdk.LightningStore.add_resources] for semantics.
971
+ """
972
+ resources_id = _generate_resources_id()
973
+ current_time = time.time()
974
+ update = ResourcesUpdate(
975
+ resources_id=resources_id,
976
+ resources=resources,
977
+ create_time=current_time,
978
+ update_time=current_time,
979
+ version=1,
980
+ )
981
+ async with self.collections.atomic(mode="w", snapshot=self._read_snapshot, labels=["resources"]) as collections:
982
+ await collections.resources.insert([update])
983
+
984
+ # Notify listeners
985
+ await self._notify("on_resource_registered", update)
986
+
987
+ return update
988
+
989
+ @tracked("update_resources")
990
+ @healthcheck_before
991
+ @_with_collections_execute(labels=["resources"])
992
+ async def update_resources(
993
+ self, collections: T_collections, resources_id: str, resources: NamedResources
994
+ ) -> ResourcesUpdate:
995
+ """
996
+ Safely stores a new version of named resources and sets it as the latest.
997
+
998
+ See [`LightningStore.update_resources()`][mantisdk.LightningStore.update_resources] for semantics.
999
+ """
1000
+ current_time = time.time()
1001
+ existing = await collections.resources.get({"resources_id": {"exact": resources_id}})
1002
+ if existing is None:
1003
+ update = ResourcesUpdate(
1004
+ resources_id=resources_id,
1005
+ resources=resources,
1006
+ create_time=current_time,
1007
+ update_time=current_time,
1008
+ version=1,
1009
+ )
1010
+ await collections.resources.insert([update])
1011
+ else:
1012
+ update = existing.model_copy(
1013
+ update={
1014
+ "resources": resources,
1015
+ "update_time": current_time,
1016
+ "version": existing.version + 1,
1017
+ }
1018
+ )
1019
+ await collections.resources.update([update])
1020
+
1021
+ # Notify listeners (note: called inside transaction context)
1022
+ await self._notify("on_resource_registered", update)
1023
+
1024
+ return update
1025
+
1026
+ @tracked("get_resources_by_id")
1027
+ async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]:
1028
+ """Retrieves a specific version of named resources by its ID.
1029
+
1030
+ See [`LightningStore.get_resources_by_id()`][mantisdk.LightningStore.get_resources_by_id] for semantics.
1031
+ """
1032
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["resources"]) as collections:
1033
+ return await collections.resources.get({"resources_id": {"exact": resources_id}})
1034
+
1035
+ @tracked("get_latest_resources")
1036
+ async def get_latest_resources(self) -> Optional[ResourcesUpdate]:
1037
+ """Retrieves the latest version of named resources.
1038
+
1039
+ See [`LightningStore.get_latest_resources()`][mantisdk.LightningStore.get_latest_resources] for semantics.
1040
+ """
1041
+ return await self._get_latest_resources()
1042
+
1043
+ @tracked("_issue_many_span_sequence_ids")
1044
+ async def _issue_many_span_sequence_ids(self, rollout_ids: List[str]) -> List[int]:
1045
+ """Issue a new span sequence ID for a given rollout."""
1046
+ if not rollout_ids:
1047
+ return []
1048
+
1049
+ request_counts: Dict[str, int] = defaultdict(int)
1050
+ for rollout_id in rollout_ids:
1051
+ request_counts[rollout_id] += 1
1052
+
1053
+ latest_values: Dict[str, int] = {}
1054
+ for rollout_id, count in request_counts.items():
1055
+ async with self.collections.atomic(mode="rw", snapshot=False, labels=["span_sequence_ids"]) as collections:
1056
+ latest_values[rollout_id] = await collections.span_sequence_ids.inc(rollout_id, count)
1057
+
1058
+ next_value_tracker: Dict[str, int] = {
1059
+ rollout_id: latest_values[rollout_id] - request_counts[rollout_id] for rollout_id in request_counts
1060
+ }
1061
+
1062
+ result: List[int] = []
1063
+ for rollout_id in rollout_ids:
1064
+ next_value_tracker[rollout_id] += 1
1065
+ result.append(next_value_tracker[rollout_id])
1066
+
1067
+ return result
1068
+
1069
+ @tracked("_sync_span_sequence_id")
1070
+ async def _sync_span_sequence_id(self, rollout_id: str, sequence_id: int) -> None:
1071
+ """Sync the span sequence ID for a given rollout from the input span sequence ID."""
1072
+ async with self.collections.atomic(mode="rw", snapshot=False, labels=["span_sequence_ids"]) as collections:
1073
+ await collections.span_sequence_ids.chmax(rollout_id, sequence_id)
1074
+
1075
+ @tracked("get_next_span_sequence_id")
1076
+ async def get_next_span_sequence_id(self, rollout_id: str, attempt_id: str) -> int:
1077
+ """Get the next span sequence ID for a given rollout and attempt.
1078
+ The number is strictly increasing for each rollout.
1079
+ The store will not issue the same sequence ID twice.
1080
+
1081
+ See [`LightningStore.get_next_span_sequence_id()`][mantisdk.LightningStore.get_next_span_sequence_id] for semantics.
1082
+ """
1083
+ ret = await self._issue_many_span_sequence_ids([rollout_id])
1084
+ return ret[0]
1085
+
1086
+ @tracked("get_many_span_sequence_ids")
1087
+ async def get_many_span_sequence_ids(self, rollout_attempt_ids: Sequence[Tuple[str, str]]) -> Sequence[int]:
1088
+ """Get the next span sequence IDs for a given list of rollout and attempt identifiers."""
1089
+ return await self._issue_many_span_sequence_ids([rollout_id for rollout_id, _ in rollout_attempt_ids])
1090
+
1091
+ @tracked("add_span")
1092
+ async def add_span(self, span: Span) -> Optional[Span]:
1093
+ """Persist a pre-converted span.
1094
+
1095
+ See [`LightningStore.add_span()`][mantisdk.LightningStore.add_span] for semantics.
1096
+ """
1097
+ # Update the sequence ID to be synced with latest input span
1098
+ await self._sync_span_sequence_id(span.rollout_id, span.sequence_id)
1099
+ successful_spans = await self._add_many_spans_helper(span.rollout_id, span.attempt_id, [span])
1100
+ return successful_spans[0] if len(successful_spans) > 0 else None
1101
+
1102
+ @tracked("add_many_spans")
1103
+ async def add_many_spans(self, spans: Sequence[Span]) -> Sequence[Span]:
1104
+ """Persist a sequence of pre-converted spans.
1105
+
1106
+ See [`LightningStore.add_many_spans()`][mantisdk.LightningStore.add_many_spans] for semantics.
1107
+ """
1108
+ # Group spans by rollout and attempt
1109
+ spans_by_rollout_attempt: Dict[Tuple[str, str], List[Span]] = defaultdict(list)
1110
+ for span in spans:
1111
+ spans_by_rollout_attempt[(span.rollout_id, span.attempt_id)].append(span)
1112
+
1113
+ # Bulk add spans for each rollout and attempt
1114
+ successful_spans: List[Span] = []
1115
+ for (rollout_id, attempt_id), spans in spans_by_rollout_attempt.items():
1116
+ await self._sync_span_sequence_id(rollout_id, max(span.sequence_id for span in spans))
1117
+ ret = await self._add_many_spans_helper(rollout_id, attempt_id, spans)
1118
+ successful_spans.extend(ret)
1119
+ return successful_spans
1120
+
1121
+ @tracked("add_otel_span")
1122
+ async def add_otel_span(
1123
+ self,
1124
+ rollout_id: str,
1125
+ attempt_id: str,
1126
+ readable_span: ReadableSpan,
1127
+ sequence_id: int | None = None,
1128
+ ) -> Optional[Span]:
1129
+ """Add an opentelemetry span to the store.
1130
+
1131
+ See [`LightningStore.add_otel_span()`][mantisdk.LightningStore.add_otel_span] for semantics.
1132
+ """
1133
+ if sequence_id is None:
1134
+ # Issue a new sequence ID for the rollout
1135
+ sequence_id = (await self._issue_many_span_sequence_ids([rollout_id]))[0]
1136
+ else:
1137
+ # A sequence ID was provided by the caller;
1138
+ # Make sure our counter is strictly increasing
1139
+ await self._sync_span_sequence_id(rollout_id, sequence_id)
1140
+
1141
+ span = Span.from_opentelemetry(
1142
+ readable_span, rollout_id=rollout_id, attempt_id=attempt_id, sequence_id=sequence_id
1143
+ )
1144
+ ret = await self._add_many_spans_helper(rollout_id, attempt_id, [span])
1145
+ return ret[0] if len(ret) > 0 else None
1146
+
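For callers that already produce OpenTelemetry spans, a common pattern is to capture the finished `ReadableSpan` objects (here via the SDK's in-memory exporter) and feed each one to `add_otel_span`. This is a usage sketch, not part of the package: it assumes a concrete store instance and an existing rollout/attempt, and the identifiers are made up.

```python
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter


async def export_traces(store) -> None:
    # Capture spans in memory so the finished ReadableSpans can be handed to the store.
    exporter = InMemorySpanExporter()
    provider = TracerProvider()
    provider.add_span_processor(SimpleSpanProcessor(exporter))
    tracer = provider.get_tracer("demo")

    with tracer.start_as_current_span("llm-call"):
        pass  # the traced work would happen here

    for readable_span in exporter.get_finished_spans():
        # sequence_id is omitted, so the store issues the next one for the rollout.
        await store.add_otel_span("rollout-1", "attempt-1", readable_span)
```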
1147
+ @tracked("_insert_spans_with_fallback")
1148
+ async def _insert_spans_with_fallback(self, spans: Sequence[Span]) -> Sequence[Span]:
1149
+ """Insert spans into the store. If the insert fails, fallback to inserting one by one."""
1150
+
1151
+ async def _add_span_fallback(collections: T_collections, span: Span) -> bool:
1152
+ try:
1153
+ await collections.spans.insert([span])
1154
+ return True
1155
+ except DuplicatedPrimaryKeyError:
1156
+ logger.error(
1157
+ f"Duplicated span added for rollout={span.rollout_id}, attempt={span.attempt_id}, span={span.span_id}. Skipping."
1158
+ )
1159
+ return False
1160
+
1161
+ successful_spans: List[Span] = []
1162
+ try:
1163
+ # This is not guarded by commit=True.
1164
+ async with self.collections.atomic(
1165
+ mode="w", snapshot=self._read_snapshot, commit=False, labels=["spans"]
1166
+ ) as collections:
1167
+ # FIXME: Part of the insertion might complete even though the full operation fails.
1168
+ # In that case, the "insert spans" return values might not be accurate.
1169
+ await collections.spans.insert(spans)
1170
+ successful_spans.extend(spans)
1171
+ except DuplicatedPrimaryKeyError:
1172
+ # At least one span has a duplicated primary key.
1173
+ # Fall back to adding the spans one by one.
1174
+ for span in spans:
1175
+ async with self.collections.atomic(
1176
+ mode="w", snapshot=self._read_snapshot, labels=["spans"]
1177
+ ) as collections:
1178
+ # No need to commit here; these are simple atomic write operations
1179
+ if await _add_span_fallback(collections, span):
1180
+ successful_spans.append(span)
1181
+
1182
+ return successful_spans
1183
+
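The fallback above follows a general pattern: attempt one all-or-nothing bulk insert, and only on a duplicate-key failure retry row by row, skipping the offenders. A self-contained sketch of that pattern with a plain dict standing in for the spans collection (all names are illustrative):

```python
from typing import Dict, List, Tuple


class DuplicatedPrimaryKeyError(Exception):
    """Stand-in for the store's duplicate-key error."""


def bulk_insert(table: Dict[str, str], rows: List[Tuple[str, str]]) -> None:
    # All-or-nothing: refuse the whole batch if any key already exists.
    if any(key in table for key, _ in rows):
        raise DuplicatedPrimaryKeyError()
    table.update(rows)


def insert_with_fallback(table: Dict[str, str], rows: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Try one bulk insert; on a duplicate key, retry row by row and skip duplicates."""
    try:
        bulk_insert(table, rows)
        return rows
    except DuplicatedPrimaryKeyError:
        inserted: List[Tuple[str, str]] = []
        for row in rows:
            try:
                bulk_insert(table, [row])
                inserted.append(row)
            except DuplicatedPrimaryKeyError:
                continue  # skip the duplicate, keep the rest
        return inserted


table = {"span-1": "old"}
ok = insert_with_fallback(table, [("span-1", "dup"), ("span-2", "new")])
print(ok)  # [('span-2', 'new')] — the duplicate is skipped, the rest land
```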
1184
+ @tracked("_add_many_spans_helper")
1185
+ async def _add_many_spans_helper(self, rollout_id: str, attempt_id: str, spans: Sequence[Span]) -> Sequence[Span]:
1186
+ """Add many spans to the store. All spans must be for the same rollout and attempt.
1187
+
1188
+ This method is divided into three parts:
1189
+
1190
+ 1. Verify the rollout and attempt exist;
1191
+ 2. Insert the spans in bulk; if the bulk insert fails, fall back to inserting one by one;
1192
+ 3. Update rollout and attempt status if necessary.
1193
+ """
1194
+
1195
+ async with self.collections.atomic(
1196
+ mode="r", snapshot=self._read_snapshot, labels=["rollouts", "attempts"]
1197
+ ) as collections:
1198
+ rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
1199
+ if not rollout:
1200
+ raise ValueError(f"Rollout {rollout_id} not found")
1201
+ current_attempt = await collections.attempts.get(
1202
+ filter={"rollout_id": {"exact": rollout_id}, "attempt_id": {"exact": attempt_id}},
1203
+ )
1204
+
1205
+ if not current_attempt:
1206
+ raise ValueError(f"Attempt {attempt_id} not found for rollout {rollout_id}")
1207
+
1208
+ successful_spans = await self._insert_spans_with_fallback(spans)
1209
+ if successful_spans:
1210
+ await self._post_add_spans(successful_spans, rollout_id, attempt_id)
1211
+
1212
+ logger.debug("Added %d spans for rollout %s, attempt %s", len(successful_spans), rollout_id, attempt_id)
1213
+
1214
+ return successful_spans
1215
+
1216
+ @tracked("_post_add_spans")
1217
+ async def _post_add_spans(self, spans: Sequence[Span], rollout_id: str, attempt_id: str) -> None:
1218
+ """Update attempt heartbeat and rollout status after spans are inserted.
1219
+
1220
+ Args:
1221
+ spans: Newly inserted spans.
1222
+ rollout_id: Identifier for the rollout receiving the spans.
1223
+ attempt_id: Identifier for the attempt receiving the spans.
1224
+
1225
+ Note:
1226
+ The method refetches the attempt/rollout inside the transactional callback to
1227
+ avoid clobbering fields that might have changed after the spans were queued.
1228
+ """
1229
+ if not spans:
1230
+ return
1231
+
1232
+ rollout_update = await self._on_attempt_heartbeat(rollout_id=rollout_id, attempt_id=attempt_id)
1233
+ if rollout_update is not None:
1234
+ await self._post_update_rollout([rollout_update])
1235
+
1236
+ # Notify listeners of all new spans
1237
+ for span in spans:
1238
+ await self._notify("on_span_created", span)
1239
+
1240
+ @tracked("_on_attempt_heartbeat")
1241
+ @_with_collections_execute(labels=["rollouts", "attempts"])
1242
+ async def _on_attempt_heartbeat(
1243
+ self, collections: T_collections, rollout_id: str, attempt_id: str
1244
+ ) -> Optional[Tuple[Rollout, Sequence[str]]]:
1245
+ attempt = await collections.attempts.get(
1246
+ {"rollout_id": {"exact": rollout_id}, "attempt_id": {"exact": attempt_id}}
1247
+ )
1248
+ if attempt is None:
1249
+ return None
1250
+ rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
1251
+ if rollout is None:
1252
+ return None
1253
+
1254
+ # Update attempt heartbeat and ensure persistence
1255
+ attempt.last_heartbeat_time = time.time()
1256
+ if attempt.status in ["preparing", "unresponsive"]:
1257
+ attempt.status = "running"
1258
+ await collections.attempts.update([attempt], update_fields=["last_heartbeat_time", "status"])
1259
+
1260
+ # If the status has already timed out or failed, do not change it (but heartbeat is still recorded)
1261
+
1262
+ # Update rollout status if it's the latest attempt
1263
+ rollout_updated: bool = False
1264
+ updated_fields: List[str] = []
1265
+ latest_attempt = await self._unlocked_get_latest_attempt(collections, rollout.rollout_id)
1266
+ if latest_attempt is not None and attempt.attempt_id == latest_attempt.attempt_id:
1267
+ if rollout.status in ["preparing", "queueing", "requeuing"]:
1268
+ # If rollout is currently preparing or queuing, set it to running
1269
+ rollout.status = "running"
1270
+ await collections.rollouts.update([rollout], update_fields=["status"])
1271
+ rollout_updated = True
1272
+ updated_fields = ["status"]
1273
+ # Otherwise the rollout is already running or finished; do nothing
1274
+ return (rollout, updated_fields) if rollout_updated else None
1275
+
1276
+ @tracked("wait_for_rollouts")
1277
+ @healthcheck_before
1278
+ async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
1279
+ """Wait for specified rollouts to complete with a timeout.
1280
+ Returns the completed rollouts; the returned list may be incomplete if the timeout is reached.
1281
+
1282
+ This method does not change the state of the store.
1283
+
1284
+ See [`LightningStore.wait_for_rollouts()`][mantisdk.LightningStore.wait_for_rollouts] for semantics.
1285
+ """
1286
+ # Wait for all rollouts concurrently
1287
+ rollouts = await asyncio.gather(
1288
+ *[self.wait_for_rollout(rid, timeout) for rid in rollout_ids], return_exceptions=True
1289
+ )
1290
+
1291
+ for rollout_id, rollout in zip(rollout_ids, rollouts):
1292
+ if isinstance(rollout, Exception):
1293
+ logger.error(f"Error waiting for rollout {rollout_id}: {rollout}")
1294
+
1295
+ # Filter out the exceptions
1296
+ ret = [rollout for rollout in rollouts if isinstance(rollout, Rollout)]
1297
+ finished_rollout_ids = set([rollout.rollout_id for rollout in ret])
1298
+ unfinished_rollout_ids = set(rollout_ids) - finished_rollout_ids
1299
+ logger.debug(
1300
+ "Waiting for rollouts. Number of finished rollouts: %d; number of unfinished rollouts: %d",
1301
+ len(finished_rollout_ids),
1302
+ len(unfinished_rollout_ids),
1303
+ )
1304
+ if len(unfinished_rollout_ids) < 30:
1305
+ logger.debug("Unfinished rollouts: %s", unfinished_rollout_ids)
1306
+ return ret
1307
+
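A typical caller enqueues work elsewhere and then blocks on completion with a time budget. Usage sketch only; it assumes `store` is a concrete store with the rollouts already created, and the IDs are made up.

```python
async def collect_finished(store) -> None:
    rollout_ids = ["rollout-1", "rollout-2", "rollout-3"]
    # Block for up to 5 minutes; rollouts that do not finish in time are simply absent.
    finished = await store.wait_for_rollouts(rollout_ids=rollout_ids, timeout=300.0)
    finished_ids = {r.rollout_id for r in finished}
    pending_ids = set(rollout_ids) - finished_ids
    print(f"finished={sorted(finished_ids)} still_pending={sorted(pending_ids)}")
```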
1308
+ @tracked("wait_for_rollout")
1309
+ async def wait_for_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
1310
+ """Wait for a specific rollout to complete with a timeout.
1311
+
1312
+ Subclasses may use advanced mechanisms such as events to accelerate this.
1313
+
1314
+ Returns the completed rollout, or None if timeout is reached.
1315
+ """
1316
+ # First check if already completed
1317
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["rollouts"]) as collections:
1318
+ rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
1319
+
1320
+ if rollout is None:
1321
+ # Rollout does not exist, return immediately
1322
+ return None
1323
+
1324
+ if is_finished(rollout):
1325
+ # Rollout is already finished, return immediately
1326
+ return rollout
1327
+
1328
+ # Non-positive timeout: do not wait, return immediately
1329
+ if timeout is not None and timeout <= 0:
1330
+ return None
1331
+
1332
+ start_time = time.time()
1333
+ deadline = start_time + timeout if timeout is not None else None
1334
+
1335
+ # If not completed, wait for completion
1336
+ while deadline is None or time.time() < deadline:
1337
+ # Poll every 10 seconds by default
1338
+ rest_time = max(0.01, min(deadline - time.time(), 10.0)) if deadline is not None else 10.0
1339
+ await asyncio.sleep(rest_time)
1340
+ async with self.collections.atomic(
1341
+ mode="r", snapshot=self._read_snapshot, labels=["rollouts"]
1342
+ ) as collections:
1343
+ rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
1344
+ # check if rollout is finished
1345
+ if rollout and is_finished(rollout):
1346
+ return rollout
1347
+
1348
+ return None
1349
+
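The wait loop above polls with a sleep interval capped at 10 seconds and clamped to the time remaining before the deadline. A self-contained sketch of that deadline arithmetic; the predicate and the short intervals are illustrative, not taken from the package.

```python
import asyncio
import time
from typing import Awaitable, Callable, Optional


async def poll_until(
    predicate: Callable[[], Awaitable[bool]],
    timeout: Optional[float],
    max_interval: float = 10.0,
) -> bool:
    """Poll `predicate` until it is true or the deadline passes; None means wait forever."""
    if timeout is not None and timeout <= 0:
        return await predicate()
    deadline = time.time() + timeout if timeout is not None else None
    while deadline is None or time.time() < deadline:
        if await predicate():
            return True
        # Sleep for the remaining time, but never longer than max_interval
        # and never a busy-loop shorter than 10 ms.
        rest = max(0.01, min(deadline - time.time(), max_interval)) if deadline is not None else max_interval
        await asyncio.sleep(rest)
    return await predicate()


async def _check(finish_at: float) -> bool:
    return time.time() >= finish_at


async def main() -> None:
    finish_at = time.time() + 0.05
    done = await poll_until(lambda: _check(finish_at), timeout=1.0, max_interval=0.02)
    print(done)  # True


asyncio.run(main())
```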
1350
+ @tracked("query_spans")
1351
+ @healthcheck_before # latest can point to a different attempt
1352
+ async def query_spans(
1353
+ self,
1354
+ rollout_id: str,
1355
+ attempt_id: str | Literal["latest"] | None = None,
1356
+ *,
1357
+ trace_id: Optional[str] = None,
1358
+ trace_id_contains: Optional[str] = None,
1359
+ span_id: Optional[str] = None,
1360
+ span_id_contains: Optional[str] = None,
1361
+ parent_id: Optional[str] = None,
1362
+ parent_id_contains: Optional[str] = None,
1363
+ name: Optional[str] = None,
1364
+ name_contains: Optional[str] = None,
1365
+ filter_logic: Literal["and", "or"] = "and",
1366
+ limit: int = -1,
1367
+ offset: int = 0,
1368
+ sort_by: Optional[str] = "sequence_id",
1369
+ sort_order: Literal["asc", "desc"] = "asc",
1370
+ ) -> PaginatedResult[Span]:
1371
+ """
1372
+ Query and retrieve spans associated with a specific rollout ID.
1373
+ Returns an empty list if no spans are found.
1374
+
1375
+ See [`LightningStore.query_spans()`][mantisdk.LightningStore.query_spans] for semantics.
1376
+ """
1377
+
1378
+ resolved_attempt_id: Optional[str]
1379
+ if attempt_id is None:
1380
+ resolved_attempt_id = None
1381
+ elif attempt_id == "latest":
1382
+ async with self.collections.atomic(
1383
+ mode="r", snapshot=self._read_snapshot, labels=["attempts"]
1384
+ ) as collections:
1385
+ latest_attempt = await self._unlocked_get_latest_attempt(collections, rollout_id)
1386
+ if not latest_attempt:
1387
+ logger.debug(f"No attempts found for rollout {rollout_id} when querying latest spans")
1388
+ return PaginatedResult(items=[], limit=limit, offset=offset, total=0)
1389
+ resolved_attempt_id = latest_attempt.attempt_id
1390
+ else:
1391
+ resolved_attempt_id = attempt_id
1392
+
1393
+ must_filter: Dict[str, FilterField] = {"rollout_id": {"exact": rollout_id}}
1394
+ if resolved_attempt_id is not None:
1395
+ must_filter["attempt_id"] = {"exact": resolved_attempt_id}
1396
+ filter_options: FilterOptions = {
1397
+ "_aggregate": filter_logic, # this can be and/or
1398
+ "_must": must_filter, # Must satisfy all the filters in the must list
1399
+ }
1400
+
1401
+ def _resolve_filter_field(
1402
+ field_name: str, filter_exact: Optional[str], filter_contains: Optional[str]
1403
+ ) -> None:
1404
+ field = cast(FilterField, filter_options.setdefault(field_name, {}))
1405
+ if filter_exact is not None:
1406
+ field["exact"] = filter_exact
1407
+ if filter_contains is not None:
1408
+ field["contains"] = filter_contains
1409
+
1410
+ _resolve_filter_field("trace_id", trace_id, trace_id_contains)
1411
+ _resolve_filter_field("span_id", span_id, span_id_contains)
1412
+ _resolve_filter_field("parent_id", parent_id, parent_id_contains)
1413
+ _resolve_filter_field("name", name, name_contains)
1414
+
1415
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["spans"]) as collections:
1416
+ return await collections.spans.query(
1417
+ filter=filter_options,
1418
+ sort={"name": sort_by, "order": sort_order} if sort_by else None,
1419
+ limit=limit,
1420
+ offset=offset,
1421
+ )
1422
+
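In practice, a caller combines the exact/contains filters with pagination and sorting to pull back a slice of a rollout's trace. Usage sketch with made-up identifiers, assuming a concrete store; only keyword arguments shown in the signature above are used.

```python
async def latest_llm_spans(store, rollout_id: str) -> None:
    # Spans of the latest attempt whose name contains "llm",
    # newest sequence numbers first, 20 at a time.
    page = await store.query_spans(
        rollout_id,
        "latest",
        name_contains="llm",
        sort_by="sequence_id",
        sort_order="desc",
        limit=20,
        offset=0,
    )
    for span in page.items:
        print(span.sequence_id, span.name)
```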
1423
+ @tracked("update_rollout")
1424
+ @healthcheck_before
1425
+ async def update_rollout(
1426
+ self,
1427
+ rollout_id: str,
1428
+ input: TaskInput | Unset = UNSET,
1429
+ mode: Optional[Literal["train", "val", "test"]] | Unset = UNSET,
1430
+ resources_id: Optional[str] | Unset = UNSET,
1431
+ status: RolloutStatus | Unset = UNSET,
1432
+ config: RolloutConfig | Unset = UNSET,
1433
+ metadata: Optional[Dict[str, Any]] | Unset = UNSET,
1434
+ ) -> Rollout:
1435
+ """Update the rollout status and related metadata.
1436
+
1437
+ See [`LightningStore.update_rollout()`][mantisdk.LightningStore.update_rollout] for semantics.
1438
+ """
1439
+ async with self.collections.atomic(mode="w", snapshot=self._read_snapshot, labels=["rollouts"]) as collections:
1440
+ rollout, update_fields = await self._unlocked_update_rollout_only(
1441
+ collections=collections,
1442
+ rollout_id=rollout_id,
1443
+ input=input,
1444
+ mode=mode,
1445
+ resources_id=resources_id,
1446
+ status=status,
1447
+ config=config,
1448
+ metadata=metadata,
1449
+ )
1450
+
1451
+ await self._post_update_rollout([(rollout, update_fields)])
1452
+
1453
+ # Notify listeners
1454
+ await self._notify("on_rollout_updated", rollout)
1455
+
1456
+ return rollout
1457
+
1458
+ @tracked("update_attempt")
1459
+ @healthcheck_before
1460
+ async def update_attempt(
1461
+ self,
1462
+ rollout_id: str,
1463
+ attempt_id: str | Literal["latest"],
1464
+ status: AttemptStatus | Unset = UNSET,
1465
+ worker_id: str | Unset = UNSET,
1466
+ last_heartbeat_time: float | Unset = UNSET,
1467
+ metadata: Optional[Dict[str, Any]] | Unset = UNSET,
1468
+ ) -> Attempt:
1469
+ """Update a specific or latest attempt for a given rollout.
1470
+
1471
+ See [`LightningStore.update_attempt()`][mantisdk.LightningStore.update_attempt] for semantics.
1472
+ """
1473
+ attempt, rollout_update, worker_sync_required = await self.collections.execute(
1474
+ lambda collections: self._unlocked_update_attempt_and_rollout(
1475
+ collections=collections,
1476
+ rollout_id=rollout_id,
1477
+ attempt_id=attempt_id,
1478
+ status=status,
1479
+ worker_id=worker_id,
1480
+ last_heartbeat_time=last_heartbeat_time,
1481
+ metadata=metadata,
1482
+ ),
1483
+ mode="rw",
1484
+ snapshot=self._read_snapshot,
1485
+ commit=True,
1486
+ labels=["rollouts", "attempts"],
1487
+ )
1488
+ if rollout_update:
1489
+ await self._post_update_rollout([rollout_update])
1490
+ if worker_sync_required:
1491
+ await self._sync_workers_with_attempts([attempt])
1492
+
1493
+ # Notify listeners
1494
+ await self._notify("on_attempt_updated", attempt, rollout_id)
1495
+
1496
+ return attempt
1497
+
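Both update methods apply partial updates: only the keyword arguments you actually pass are written, and anything left at the UNSET sentinel keeps its stored value. A usage sketch with made-up identifiers, assuming a concrete store instance:

```python
async def finish_attempt(store, rollout_id: str) -> None:
    # Mark the newest attempt as succeeded; untouched fields keep their values.
    attempt = await store.update_attempt(rollout_id, "latest", status="succeeded")

    # Attach bookkeeping metadata to the rollout without touching its status.
    rollout = await store.update_rollout(rollout_id, metadata={"finished_by": attempt.attempt_id})
    print(rollout.status, attempt.status)
```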
1498
+ @tracked("_unlocked_update_rollout_only")
1499
+ async def _unlocked_update_rollout_only(
1500
+ self,
1501
+ collections: T_collections,
1502
+ rollout_id: str,
1503
+ input: TaskInput | Unset = UNSET,
1504
+ mode: Optional[Literal["train", "val", "test"]] | Unset = UNSET,
1505
+ resources_id: Optional[str] | Unset = UNSET,
1506
+ status: RolloutStatus | Unset = UNSET,
1507
+ config: RolloutConfig | Unset = UNSET,
1508
+ metadata: Optional[Dict[str, Any]] | Unset = UNSET,
1509
+ ) -> Tuple[Rollout, Sequence[str]]:
1510
+ """Update the rollout status and related metadata only.
1511
+
1512
+ Not updating related attempts or workers.
1513
+
1514
+ There is only one update operation inside, so commit is not strictly required.
1515
+ """
1516
+ rollout_construct_params: Dict[str, Any] = {"rollout_id": rollout_id}
1517
+ # Update fields if they are not UNSET
1518
+ if not isinstance(input, Unset):
1519
+ rollout_construct_params["input"] = input
1520
+ if not isinstance(mode, Unset):
1521
+ rollout_construct_params["mode"] = mode
1522
+ if not isinstance(resources_id, Unset):
1523
+ rollout_construct_params["resources_id"] = resources_id
1524
+ if not isinstance(status, Unset):
1525
+ rollout_construct_params["status"] = status
1526
+ if not isinstance(config, Unset):
1527
+ rollout_construct_params["config"] = config
1528
+ if not isinstance(metadata, Unset):
1529
+ rollout_construct_params["metadata"] = metadata
1530
+
1531
+ # Set end time for finished rollouts
1532
+ # A rollout is only finished when it has succeeded, or has failed with no more retries.
1533
+ if not isinstance(status, Unset) and status in ["failed", "succeeded", "cancelled"]:
1534
+ rollout_construct_params["end_time"] = time.time()
1535
+
1536
+ update_fields = list(rollout_construct_params.keys())
1537
+
1538
+ # Set required fields for validation purposes.
1539
+ rollout_construct_params.setdefault("input", None)
1540
+ rollout_construct_params.setdefault("start_time", 0.0)
1541
+ rollout_obj = Rollout.model_validate(rollout_construct_params)
1542
+
1543
+ rollouts_updated = await collections.rollouts.update([rollout_obj], update_fields=update_fields)
1544
+ return rollouts_updated[0], update_fields
1545
+
1546
+ @tracked("_post_update_rollout")
1547
+ async def _post_update_rollout(
1548
+ self, rollouts: Sequence[Tuple[Rollout, Sequence[str]]], skip_enqueue: bool = False
1549
+ ) -> None:
1550
+ """Post-update logic for the rollout.
1551
+
1552
+ This method acquires locks internally, so it should not be called while the collection lock is already held.
1553
+
1554
+ Args:
1555
+ rollouts: A sequence of tuples, each containing a rollout and the fields that were updated.
1556
+ skip_enqueue: Whether to skip queueing the rollouts.
1557
+ """
1558
+ for rollout, updated_fields in rollouts:
1559
+ # Sometimes "end_time" is set but it's not really updated.
1560
+ if "end_time" in updated_fields and is_finished(rollout):
1561
+ if self._tracker is not None:
1562
+ labels = {
1563
+ "status": rollout.status,
1564
+ "mode": rollout.mode if rollout.mode is not None else "unknown",
1565
+ }
1566
+ duration = cast(float, rollout.end_time) - rollout.start_time
1567
+ await self._tracker.inc_counter("msk.rollouts.total", labels=labels)
1568
+ await self._tracker.observe_histogram(
1569
+ "msk.rollouts.duration",
1570
+ value=duration,
1571
+ labels=labels,
1572
+ )
1573
+
1574
+ if not skip_enqueue:
1575
+ # If requeuing, add back to queue.
1576
+ # Check whether the rollout is already in queue.
1577
+ candidate_requeue_rollouts = [
1578
+ rollout.rollout_id
1579
+ for rollout, updated_fields in rollouts
1580
+ if "status" in updated_fields and is_queuing(rollout)
1581
+ ]
1582
+ if candidate_requeue_rollouts:
1583
+ # Do another filter: filter out rollouts that are already in the queue.
1584
+ async with self.collections.atomic(
1585
+ mode="r", snapshot=self._read_snapshot, labels=["rollout_queue"]
1586
+ ) as collections:
1587
+ candidate_requeue_rollouts = [
1588
+ rollout_id
1589
+ for rollout_id in candidate_requeue_rollouts
1590
+ if not await collections.rollout_queue.has(rollout_id)
1591
+ ]
1592
+
1593
+ if candidate_requeue_rollouts:
1594
+ async with self.collections.atomic(
1595
+ mode="w", snapshot=self._read_snapshot, labels=["rollout_queue"]
1596
+ ) as collections:
1597
+ await collections.rollout_queue.enqueue(candidate_requeue_rollouts)
1598
+
1599
+ # NOTE: We also don't need to remove non-queuing rollouts from the queue.
1600
+
1601
+ @tracked("_unlocked_update_attempt_and_rollout")
1602
+ async def _unlocked_update_attempt_and_rollout(
1603
+ self,
1604
+ collections: T_collections,
1605
+ rollout_id: str,
1606
+ attempt_id: str | Literal["latest"],
1607
+ status: AttemptStatus | Unset = UNSET,
1608
+ worker_id: str | Unset = UNSET,
1609
+ last_heartbeat_time: float | Unset = UNSET,
1610
+ metadata: Optional[Dict[str, Any]] | Unset = UNSET,
1611
+ ) -> Tuple[Attempt, Optional[Tuple[Rollout, Sequence[str]]], bool]:
1612
+ """Update an attempt.
1613
+
1614
+ The attempt status is propagated to the rollout if the attempt is the latest attempt.
1615
+
1616
+ Returns:
1617
+ - The updated attempt
1618
+ - The updated rollout (or none if unchanged); post-rollout-update is not invoked yet
1619
+ - Whether the worker needs to be synced
1620
+ """
1621
+ # No lock, but with status propagation.
1622
+ rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
1623
+ if not rollout:
1624
+ raise ValueError(f"Rollout {rollout_id} not found")
1625
+
1626
+ latest_attempt = await self._unlocked_get_latest_attempt(collections, rollout_id)
1627
+ if not latest_attempt:
1628
+ raise ValueError(f"No attempts found for rollout {rollout_id}")
1629
+
1630
+ # Find the attempt to update
1631
+ if attempt_id == "latest":
1632
+ attempt = latest_attempt
1633
+ else:
1634
+ attempt = await collections.attempts.get(
1635
+ {"rollout_id": {"exact": rollout_id}, "attempt_id": {"exact": attempt_id}}
1636
+ )
1637
+ if not attempt:
1638
+ raise ValueError(f"Attempt {attempt_id} not found for rollout {rollout_id}")
1639
+
1640
+ worker_sync_required = False
1641
+
1642
+ # Update fields if they are not UNSET
1643
+ if not isinstance(worker_id, Unset):
1644
+ attempt.worker_id = worker_id
1645
+ worker_sync_required = worker_sync_required or bool(worker_id)
1646
+ if not isinstance(status, Unset):
1647
+ attempt.status = status
1648
+ # Also update end_time if the status indicates completion
1649
+ if status in ["failed", "succeeded"]:
1650
+ attempt.end_time = time.time()
1651
+ worker_sync_required = worker_sync_required or bool(attempt.worker_id)
1652
+ if not isinstance(last_heartbeat_time, Unset):
1653
+ attempt.last_heartbeat_time = last_heartbeat_time
1654
+ if not isinstance(metadata, Unset):
1655
+ attempt.metadata = metadata
1656
+
1657
+ # Re-validate the attempt to ensure legality
1658
+ Attempt.model_validate(attempt.model_dump())
1659
+ # Update the attempt in storage
1660
+ await collections.attempts.update([attempt])
1661
+
1662
+ rollout_update: Optional[Tuple[Rollout, Sequence[str]]] = None
1663
+ if attempt.attempt_id == latest_attempt.attempt_id:
1664
+ # Propagate the status to the rollout
1665
+ rollout_status = await rollout_status_from_attempt(attempt, rollout.config)
1666
+ if rollout_status != rollout.status:
1667
+ updated_rollout, update_fields = await self._unlocked_update_rollout_only(
1668
+ collections, rollout_id, status=rollout_status
1669
+ )
1670
+ rollout_update = (updated_rollout, update_fields)
1671
+
1672
+ return attempt, rollout_update, worker_sync_required
1673
+
1674
+ @tracked("query_workers")
1675
+ @healthcheck_before
1676
+ async def query_workers(
1677
+ self,
1678
+ *,
1679
+ status_in: Optional[Sequence[WorkerStatus]] = None,
1680
+ worker_id_contains: Optional[str] = None,
1681
+ filter_logic: Literal["and", "or"] = "and",
1682
+ sort_by: Optional[str] = None,
1683
+ sort_order: Literal["asc", "desc"] = "asc",
1684
+ limit: int = -1,
1685
+ offset: int = 0,
1686
+ ) -> PaginatedResult[Worker]:
1687
+ """Return the current snapshot of all workers."""
1688
+ filters: FilterOptions = {}
1689
+ if status_in is not None:
1690
+ filters["status"] = {"within": list(status_in)}
1691
+ if worker_id_contains is not None:
1692
+ filters["worker_id"] = {"contains": worker_id_contains}
1693
+ filters["_aggregate"] = filter_logic
1694
+
1695
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["workers"]) as collections:
1696
+ return await collections.workers.query(
1697
+ filter=filters if list(filters.keys()) != ["_aggregate"] else None,
1698
+ sort={"name": sort_by, "order": sort_order} if sort_by else None,
1699
+ limit=limit,
1700
+ offset=offset,
1701
+ )
1702
+
1703
+ @tracked("get_worker_by_id")
1704
+ async def get_worker_by_id(self, worker_id: str) -> Optional[Worker]:
1705
+ async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["workers"]) as collections:
1706
+ return await collections.workers.get({"worker_id": {"exact": worker_id}})
1707
+
1708
+ @tracked("update_worker")
1709
+ async def update_worker(
1710
+ self,
1711
+ worker_id: str,
1712
+ heartbeat_stats: Dict[str, Any] | Unset = UNSET,
1713
+ ) -> Worker:
1714
+ """Create or update a worker entry."""
1715
+ update_fields = ["last_heartbeat_time"]
1716
+ new_worker = Worker(worker_id=worker_id, last_heartbeat_time=time.time())
1717
+ if not isinstance(heartbeat_stats, Unset):
1718
+ update_fields.append("heartbeat_stats")
1719
+ new_worker.heartbeat_stats = dict(heartbeat_stats)
1720
+ return await self._update_or_insert_worker(new_worker, update_fields=update_fields)
1721
+
1722
+ @tracked("_unlocked_get_running_rollouts")
1723
+ async def _unlocked_get_running_rollouts(self, collections: T_collections) -> List[AttemptedRollout]:
1724
+ """Get all running rollouts.
1725
+
1726
+ As this is invoked very frequently (possibly on every request),
1727
+ subclasses can implement shortcuts to make it more efficient.
1728
+ It should also be unlocked and let the caller hold the lock.
1729
+ """
1730
+ filtered_rollouts = await collections.rollouts.query(filter={"status": {"within": ["preparing", "running"]}})
1731
+ running_rollouts = await self._unlocked_many_rollouts_to_attempted_rollouts(collections, filtered_rollouts)
1732
+
1733
+ running_attempted_rollouts: List[AttemptedRollout] = []
1734
+ for rollout in running_rollouts:
1735
+ if not isinstance(rollout, AttemptedRollout):
1736
+ logger.error(f"Rollout {rollout.rollout_id} is running but has no attempts")
1737
+ continue
1738
+ running_attempted_rollouts.append(rollout)
1739
+
1740
+ return running_attempted_rollouts
1741
+
1742
+ @tracked("_scan_for_unhealthy_rollouts")
1743
+ async def _scan_for_unhealthy_rollouts(self) -> None:
1744
+ """Perform healthcheck against all running rollouts in the store."""
1745
+ if not await self._should_scan_for_unhealthy_rollouts():
1746
+ return
1747
+
1748
+ rollouts, attempts_sync_required = await self._find_and_update_unhealthy_rollouts()
1749
+
1750
+ if rollouts:
1751
+ await self._post_update_rollout(rollouts)
1752
+
1753
+ # Sync worker status
1754
+ if attempts_sync_required:
1755
+ await self._sync_workers_with_attempts(attempts_sync_required)
1756
+
1757
+ @tracked("_should_scan_for_unhealthy_rollouts")
1758
+ async def _should_scan_for_unhealthy_rollouts(self) -> bool:
1759
+ """Check if the scan for unhealthy rollouts should be performed."""
1760
+ if self._scan_debounce_seconds <= 0:
1761
+ return True
1762
+
1763
+ now = time.time()
1764
+ should_scan = now - self._last_scan_entrance_time >= self._scan_debounce_seconds
1765
+ if not should_scan:
1766
+ return False
1767
+
1768
+ # Someone else may be racing for the same scan. Double-check inside the lock.
1769
+ async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["generic"]):
1770
+ now = time.time()
1771
+ if now - self._last_scan_entrance_time < self._scan_debounce_seconds:
1772
+ return False
1773
+ self._last_scan_entrance_time = now
1774
+ return True
1775
+
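The debounce above uses a cheap timestamp check first and a second check under the lock, so that concurrent callers racing past the cheap check cannot all trigger a scan. A self-contained sketch of that pattern; the class and its fields are illustrative, not the package's implementation.

```python
import asyncio
import time


class DebouncedScanner:
    """Run an expensive scan at most once per `debounce_seconds` (illustrative sketch)."""

    def __init__(self, debounce_seconds: float) -> None:
        self.debounce_seconds = debounce_seconds
        self._last_entrance = 0.0
        self._lock = asyncio.Lock()
        self.scans = 0

    async def maybe_scan(self) -> bool:
        if self.debounce_seconds <= 0:
            return await self._scan()
        # Cheap check without the lock first.
        if time.time() - self._last_entrance < self.debounce_seconds:
            return False
        # Another task may have passed the cheap check too; re-check under the lock.
        async with self._lock:
            if time.time() - self._last_entrance < self.debounce_seconds:
                return False
            self._last_entrance = time.time()
        return await self._scan()

    async def _scan(self) -> bool:
        self.scans += 1
        return True


async def main() -> None:
    scanner = DebouncedScanner(debounce_seconds=1.0)
    results = await asyncio.gather(*[scanner.maybe_scan() for _ in range(5)])
    print(results, scanner.scans)  # only one caller wins the debounce window


asyncio.run(main())
```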
1776
+ @tracked("_find_and_update_unhealthy_rollouts")
1777
+ @_with_collections_execute(labels=["rollouts", "attempts"])
1778
+ async def _find_and_update_unhealthy_rollouts(
1779
+ self, collections: T_collections
1780
+ ) -> Tuple[List[Tuple[Rollout, Sequence[str]]], List[Attempt]]:
1781
+ """Batch update the status of unhealthy attempts.
1782
+
1783
+ Returns:
1784
+ - The list of rollouts that have been updated
1785
+ - The list of attempts that need worker-sync
1786
+ """
1787
+ running_rollouts = await self._unlocked_get_running_rollouts(collections)
1788
+
1789
+ candidate_updates = await scan_unhealthy_rollouts(running_rollouts)
1790
+ if not candidate_updates:
1791
+ return [], []
1792
+
1793
+ rollouts: List[Tuple[Rollout, Sequence[str]]] = []
1794
+ attempts: List[Attempt] = []
1795
+ for (rollout_id, attempt_id), status in candidate_updates.items():
1796
+ attempt, rollout_update, worker_sync_required = await self._unlocked_update_attempt_and_rollout(
1797
+ collections, rollout_id, attempt_id, status=status
1798
+ )
1799
+ if rollout_update:
1800
+ rollouts.append(rollout_update)
1801
+ if worker_sync_required:
1802
+ attempts.append(attempt)
1803
+ return rollouts, attempts
1804
+
1805
+
1806
+ # _scan_for_unhealthy_rollouts is standalone and invoked automatically, so it is treated as public here.
1807
+ COLLECTION_STORE_PUBLIC_METHODS = frozenset(
1808
+ [name for name in LightningStore.__dict__ if not name.startswith("_")] + ["_scan_for_unhealthy_rollouts"]
1809
+ )
1810
+
1811
+ COLLECTION_STORE_ALL_METHODS = frozenset([name for name in CollectionBasedLightningStore.__dict__])
1812
+
1813
+
1814
+ def get_current_store_methods() -> Tuple[str, str]:
1815
+ """Get the current store method names from ContextVars.
1816
+
1817
+ This is a fast O(1) replacement for stack introspection. The ContextVars are
1818
+ set by the @tracked decorator when entering store methods.
1819
+
1820
+ Returns:
1821
+ A tuple of (public_method_name, private_method_name).
1822
+ """
1823
+ return _current_public_store_method.get(), _current_private_store_method.get()
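The function above reads ContextVars rather than walking the call stack. A minimal, self-contained sketch of that ContextVar pattern: the decorator body here only shows the ContextVar bookkeeping and is illustrative, not the package's `@tracked` implementation (which also does tracking/metrics work).

```python
import asyncio
import functools
from contextvars import ContextVar
from typing import Any, Callable, Tuple

# Illustrative stand-ins for the module-level ContextVars set by the decorator.
_current_public_method: ContextVar[str] = ContextVar("public_method", default="")
_current_private_method: ContextVar[str] = ContextVar("private_method", default="")


def tracked(name: str) -> Callable[..., Any]:
    """Record which store method is currently executing (sketch only)."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            var = _current_private_method if name.startswith("_") else _current_public_method
            token = var.set(name)
            try:
                return await func(*args, **kwargs)
            finally:
                var.reset(token)

        return wrapper

    return decorator


def current_methods() -> Tuple[str, str]:
    # O(1) lookup instead of stack introspection.
    return _current_public_method.get(), _current_private_method.get()


@tracked("add_span")
async def add_span() -> Tuple[str, str]:
    return current_methods()


print(asyncio.run(add_span()))  # ('add_span', '')
```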