mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,845 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ """Agent runner implementation for executing agent rollouts.
4
+
5
+ This module provides the concrete implementation of the runner interface,
6
+ handling the execution of agent rollouts with support for tracing, hooks,
7
+ and distributed worker coordination.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ import logging
14
+ import random
15
+ import threading
16
+ import time
17
+ from contextlib import suppress
18
+ from typing import (
19
+ TYPE_CHECKING,
20
+ Any,
21
+ Awaitable,
22
+ Callable,
23
+ List,
24
+ Literal,
25
+ Optional,
26
+ Sequence,
27
+ TypeVar,
28
+ cast,
29
+ )
30
+
31
+ from opentelemetry.sdk.trace import ReadableSpan
32
+
33
+ from mantisdk.litagent import LitAgent
34
+ from mantisdk.reward import emit_reward, find_final_reward
35
+ from mantisdk.store.base import LightningStore
36
+ from mantisdk.tracer.base import Tracer
37
+ from mantisdk.tracer.otel import OtelTracer
38
+ from mantisdk.types import (
39
+ AttemptedRollout,
40
+ Hook,
41
+ NamedResources,
42
+ Rollout,
43
+ RolloutMode,
44
+ RolloutRawResult,
45
+ Span,
46
+ SpanCoreFields,
47
+ )
48
+ from mantisdk.utils.system_snapshot import system_snapshot
49
+
50
+ if TYPE_CHECKING:
51
+ from mantisdk.execution.events import ExecutionEvent
52
+
53
+ from .base import Runner
54
+
55
# Generic type parameter for the task objects that LitAgentRunner (and the LitAgent it drives)
# is parameterized over.
T_task = TypeVar("T_task")

# Module-level logger named after this module; used by the runner for all diagnostics.
logger = logging.getLogger(__name__)
58
+
59
+
60
+ class LitAgentRunner(Runner[T_task]):
61
+ """Execute [`LitAgent`][mantisdk.LitAgent] tasks with tracing support.
62
+
63
+ This runner manages the complete lifecycle of agent rollout execution,
64
+ including task polling, resource management, tracing, and hooks. It supports
65
+ both continuous iteration over tasks from the store and single-step execution.
66
+
67
+ Attributes:
68
+ worker_id: Identifier for the active worker process, if any.
69
+ """
70
+
71
def __init__(
    self,
    tracer: Tracer,
    max_rollouts: Optional[int] = None,
    poll_interval: float = 5.0,
    heartbeat_interval: float = 10.0,
    interval_jitter: float = 0.5,
    heartbeat_launch_mode: Literal["asyncio", "thread"] = "thread",
    heartbeat_include_gpu: bool = False,
) -> None:
    """Configure the runner; the agent, hooks and store are attached later.

    The agent and hooks are supplied through [`init`][mantisdk.LitAgentRunner.init],
    the store and worker id through [`init_worker`][mantisdk.LitAgentRunner.init_worker].

    Args:
        tracer: [`Tracer`][mantisdk.Tracer] used for rollout spans.
        max_rollouts: Optional cap on iterations processed by
            [`iter`][mantisdk.LitAgentRunner.iter].
        poll_interval: Seconds to wait between store polls when no work is available.
        heartbeat_interval: Seconds between heartbeats sent to the store. A value
            ``<= 0`` disables the heartbeat loop.
        interval_jitter: Half-width of the uniform random offset added to the waiting
            intervals (poll and heartbeat), i.e. each wait is drawn around the nominal
            interval within ``[-interval_jitter, +interval_jitter]``. This keeps many
            runners from hitting the store in lockstep.
        heartbeat_launch_mode: ``"thread"`` (default, recommended under load because the
            snapshot and store update run off the caller's event loop) or ``"asyncio"``
            (simpler; suited to low worker counts). Unsupported values are only rejected
            when the heartbeat loop is started.
        heartbeat_include_gpu: Whether heartbeat snapshots include GPU stats. Disabled by
            default because querying GPU stats can be slow under load.
    """
    super().__init__()

    # Collaborators.
    self._tracer = tracer
    self._random_state = random.Random()  # source of jitter for all wait intervals

    # Polling / scheduling knobs.
    self._max_rollouts = max_rollouts
    self._poll_interval = poll_interval
    self._interval_jitter = interval_jitter

    # Heartbeat knobs.
    self._heartbeat_interval = heartbeat_interval
    self._heartbeat_launch_mode = heartbeat_launch_mode
    self._heartbeat_include_gpu = heartbeat_include_gpu

    # Populated by init() / init_worker(); reset again by teardown().
    self._agent: Optional[LitAgent[T_task]] = None
    self._hooks: Sequence[Hook] = []
    self._store: Optional[LightningStore] = None
    self.worker_id: Optional[int] = None
113
+
114
def init(self, agent: LitAgent[T_task], *, hooks: Optional[Sequence[Hook]] = None, **kwargs: Any) -> None:
    """Attach ``agent`` to this runner, register ``hooks`` and initialize the tracer.

    The agent is told about its runner via ``agent.set_runner(self)``; the hooks are
    copied into a fresh list so later mutation of the caller's sequence has no effect.

    Args:
        agent: [`LitAgent`][mantisdk.LitAgent] instance executed by the runner.
        hooks: Optional sequence of [`Hook`][mantisdk.Hook] callbacks invoked around
            tracing and rollout boundaries. ``None`` means no hooks.
        **kwargs: Accepted for interface compatibility; ignored.
    """
    self._agent = agent
    agent.set_runner(self)

    if hooks is None:
        self._hooks = []
    else:
        self._hooks = list(hooks)

    self._tracer.init()
131
+
132
def init_worker(self, worker_id: int, store: LightningStore, **kwargs: Any) -> None:
    """Bind this runner to one worker: record its id and store, then init the tracer for it.

    Called once per worker in a distributed setup.

    Args:
        worker_id: Unique identifier for this worker process.
        store: [`LightningStore`][mantisdk.LightningStore] used for task coordination
            and persistence.
        **kwargs: Accepted for interface compatibility; ignored.
    """
    self.worker_id = worker_id
    self._store = store
    # The tracer receives the same worker id and store.
    self._tracer.init_worker(worker_id, store)
148
+
149
+ def teardown(self, *args: Any, **kwargs: Any) -> None:
150
+ """Teardown the runner and clean up all resources.
151
+
152
+ This method resets all internal state including the agent, store,
153
+ hooks, and worker ID, and calls the tracer's teardown method.
154
+
155
+ Args:
156
+ *args: Additional teardown arguments (currently unused).
157
+ **kwargs: Additional teardown keyword arguments (currently unused).
158
+ """
159
+ self._agent = None
160
+ self._store = None
161
+ self.worker_id = None
162
+ self._hooks = []
163
+
164
+ self._tracer.teardown()
165
+
166
def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
    """Detach this runner from ``worker_id`` and let the tracer clean up for that worker.

    Only the worker id is cleared here; agent, hooks and store are left untouched.

    Args:
        worker_id: Unique identifier of the worker being torn down.
        *args: Accepted for interface compatibility; ignored.
        **kwargs: Accepted for interface compatibility; ignored.
    """
    self.worker_id = None
    tracer = self._tracer
    tracer.teardown_worker(worker_id)
179
+
180
@property
def tracer(self) -> Tracer:
    """Read-only access to the tracer this runner records rollout spans with.

    Returns:
        The [`Tracer`][mantisdk.Tracer] passed to the constructor.
    """
    return self._tracer
188
+
189
def get_agent(self) -> LitAgent[T_task]:
    """Return the agent attached by [`init`][mantisdk.LitAgentRunner.init].

    Returns:
        The LitAgent instance managed by this runner.

    Raises:
        ValueError: If no agent is attached (``init`` not called, or after ``teardown``).
    """
    agent = self._agent
    if agent is not None:
        return agent
    raise ValueError("Agent not initialized. Call init() first.")
201
+
202
def get_store(self) -> LightningStore:
    """Return the store attached by [`init_worker`][mantisdk.LitAgentRunner.init_worker].

    Returns:
        The LightningStore instance for this worker.

    Raises:
        ValueError: If no store is attached (``init_worker`` not called, or after ``teardown``).
    """
    store = self._store
    if store is not None:
        return store
    raise ValueError("Store not initialized. Call init_worker() first.")
214
+
215
def get_worker_id(self) -> str:
    """Return a printable worker label.

    Returns:
        ``"Worker-<id>"`` (e.g. ``"Worker-0"``) once a worker id is set, otherwise
        ``"Worker-Unknown"``.
    """
    if self.worker_id is None:
        return "Worker-Unknown"
    return f"Worker-{self.worker_id}"
223
+
224
def _log_prefix(self, rollout_id: Optional[str] = None) -> str:
    """Build the bracketed tag prepended to this runner's log messages.

    Args:
        rollout_id: Rollout to mention; falsy values (``None``, ``""``) are omitted.

    Returns:
        ``"[Worker 0 | Rollout xyz]"``, ``"[Worker 0]"``, ``"[Rollout xyz]"`` or
        ``"[Default Worker]"`` depending on which of worker id / rollout id are present.
    """
    segments: List[str] = []
    if self.worker_id is not None:
        segments.append(f"Worker {self.worker_id}")
    if rollout_id:
        segments.append(f"Rollout {rollout_id}")

    if not segments:
        return "[Default Worker]"
    return "[" + " | ".join(segments) + "]"
245
+
246
async def _trigger_hooks(
    self,
    hook_type: Literal["on_trace_start", "on_trace_end", "on_rollout_start", "on_rollout_end"],
    *args: Any,
    **kwargs: Any,
) -> None:
    """Await ``hook_type`` on every registered hook, in registration order.

    Each hook runs in isolation: any ``Exception`` it raises (including a missing
    attribute or a non-awaitable return) is logged with its traceback and the next
    hook still runs. Cancellation (a ``BaseException``) is not intercepted.

    Args:
        hook_type: Name of the hook method to call, one of ``"on_trace_start"``,
            ``"on_trace_end"``, ``"on_rollout_start"``, ``"on_rollout_end"``.
        *args: Positional arguments forwarded to each hook method.
        **kwargs: Keyword arguments forwarded to each hook method.
    """
    for registered in self._hooks:
        try:
            handler = getattr(registered, hook_type)
            await handler(*args, **kwargs)
        except Exception:
            logger.exception(f"{self._log_prefix()} Exception during {hook_type} hook {registered}.")
269
+
270
async def _post_process_rollout_result(
    self, rollout: AttemptedRollout, raw_result: RolloutRawResult
) -> List[ReadableSpan] | List[Span]:
    """Normalize the agent's return value and persist whatever it implies to the store.

    Accepted shapes of ``raw_result``:

    * ``None``: spans were captured by the tracer; nothing extra is written.
    * ``float`` (``bool``/``int`` are converted with a warning): the final reward. A reward
      span is built from ``emit_reward(..., propagate=False)`` and added to the store
      manually, reusing the trace id of the tracer's last trace when one is available.
    * non-empty ``list`` of ``ReadableSpan``: each is added via ``store.add_otel_span``.
    * non-empty ``list`` of ``Span``: each is added as-is via ``store.add_span``.
    * non-empty ``list`` of ``SpanCoreFields``: sequence ids are allocated, then the spans
      are added via ``store.add_many_spans``.
    * empty ``list``: accepted with a warning.

    Args:
        rollout: The rollout object for the current task.
        raw_result: The output from the agent's rollout method.

    Returns:
        The spans that are assumed to be added to the store. This only serves as an
        estimation for logging purposes. For precise tracking, use the store directly.

    Raises:
        ValueError: If ``raw_result`` is a non-empty list whose elements are not all of a
            single supported span type.
        TypeError: If ``raw_result`` is of any other unsupported type.
    """
    store = self.get_store()
    rollout_id = rollout.rollout_id
    attempt_id = rollout.attempt.attempt_id
    prefix = self._log_prefix(rollout_id)

    def _trace_id_of(span: Any) -> Optional[str]:
        """Return the trace id of a mantisdk ``Span`` or an OpenTelemetry ``ReadableSpan``."""
        # mantisdk spans expose ``trace_id`` directly; keep using it verbatim (even if None).
        if hasattr(span, "trace_id"):
            return span.trace_id
        # ReadableSpan has no ``trace_id`` attribute: the id is an int on its SpanContext.
        # Render it in the standard 32-char lowercase hex form used by OpenTelemetry.
        # NOTE(review): assumes Span.trace_id stores ids in this hex form -- confirm in types.
        raw_id = getattr(getattr(span, "context", None), "trace_id", None)
        if isinstance(raw_id, int) and raw_id != 0:
            return f"{raw_id:032x}"
        return None

    # Case 0: nothing returned -> the tracer already captured the spans.
    if raw_result is None:
        return self._tracer.get_last_trace()

    # Case 1: a scalar is interpreted as the final reward.
    if isinstance(raw_result, (bool, int, float)):
        if isinstance(raw_result, (bool, int)):
            logger.warning(
                f"{prefix} Reward is not a float, got: {type(raw_result)}. "
                "Auto converting to float."
            )
        reward = float(raw_result)
        # Snapshot the tracer's spans before anything else is recorded.
        trace_spans = list(self._tracer.get_last_trace())
        # propagate=False: the reward is NOT emitted through the tracer; only its
        # core fields come back, and we add the span to the store ourselves.
        reward_span_core_fields = emit_reward(reward, propagate=False)
        sequence_id = await store.get_next_span_sequence_id(rollout_id, attempt_id)
        # Reuse the existing trace id so the reward span links to the same trace as the
        # OTLP-exported spans. get_last_trace() may yield ReadableSpan objects, which do not
        # have a ``trace_id`` attribute, hence the helper.
        existing_trace_id = _trace_id_of(trace_spans[0]) if trace_spans else None
        reward_span = Span.from_core_fields(
            reward_span_core_fields,
            rollout_id=rollout_id,
            attempt_id=attempt_id,
            sequence_id=sequence_id,
            trace_id=existing_trace_id,
        )
        await store.add_span(reward_span)
        return trace_spans

    # Cases 2-4: a list of spans of one kind.
    if isinstance(raw_result, list):
        if not raw_result:
            logger.warning(
                f"{prefix} The rollout returns an empty list. "
                "Please check your rollout implementation."
            )
            return []

        # Case 2: OpenTelemetry spans.
        if all(isinstance(item, ReadableSpan) for item in raw_result):
            if isinstance(self._tracer, OtelTracer):
                logger.warning(
                    f"{prefix} Tracer is already an OpenTelemetry tracer. "
                    "The traces should have already been added to the store. "
                    "Returning the traces from the rollout will result in duplicate spans."
                )
            added_spans: list[Span] = []
            for otel_span in raw_result:
                added_span = await store.add_otel_span(rollout_id, attempt_id, cast(ReadableSpan, otel_span))
                if added_span is not None:
                    added_spans.append(added_span)
                else:
                    logger.error(f"{prefix} Failed to add OpenTelemetry span to the store: {otel_span}")
            return added_spans

        # Case 3: mantisdk spans, stored as-is.
        if all(isinstance(item, Span) for item in raw_result):
            spans = [cast(Span, item) for item in raw_result]
            for span in spans:
                await store.add_span(span)
            return spans

        # Case 4: bare span core fields; sequence ids must be allocated first.
        if all(isinstance(item, SpanCoreFields) for item in raw_result):
            sequence_ids = await store.get_many_span_sequence_ids([(rollout_id, attempt_id)] * len(raw_result))
            spans = [
                Span.from_core_fields(
                    cast(SpanCoreFields, core_fields),
                    rollout_id=rollout_id,
                    attempt_id=attempt_id,
                    sequence_id=seq_id,
                )
                for core_fields, seq_id in zip(raw_result, sequence_ids, strict=True)
            ]
            await store.add_many_spans(spans)
            return spans

        type_names = [type(item).__name__ for item in raw_result[:10]]
        ellipsis = "..." if len(raw_result) > 10 else ""
        raise ValueError(
            "Invalid raw result type. It's expected to be a list of ReadableSpan, Span, or SpanCoreFields, "
            f"but got: {', '.join(type_names)}{ellipsis}"
        )

    raise TypeError(
        "Invalid raw result type. It's expected to be None, a float, or a list of "
        f"ReadableSpan, Span, or SpanCoreFields, but got: {type(raw_result).__name__}"
    )
391
+
392
async def _emit_heartbeat(self, store: LightningStore) -> None:
    """Take one system snapshot and push it to the store as this worker's heartbeat.

    Both steps are bounded by ``self._heartbeat_interval`` seconds. The snapshot is taken
    in a worker thread via ``asyncio.to_thread``. A timeout or error in either step is
    logged and the heartbeat is skipped. ``asyncio.CancelledError`` is always re-raised.

    Args:
        store: The lightning store to update.
    """
    logger.debug(f"{self._log_prefix()} Preparing to emit heartbeat.")
    worker_id = self.get_worker_id()
    budget = self._heartbeat_interval

    # Step 1: collect the snapshot off the event loop.
    try:
        snapshot = await asyncio.wait_for(
            asyncio.to_thread(system_snapshot, self._heartbeat_include_gpu),
            timeout=budget,
        )
    except asyncio.CancelledError:
        raise  # never swallow cancellation
    except asyncio.TimeoutError:
        logger.warning(
            "%s Heartbeat snapshot acquisition timed out after %.1fs, skipping.",
            self._log_prefix(),
            budget,
        )
        return
    except Exception:
        logger.exception("%s Unable to acquire heartbeat snapshot.", self._log_prefix())
        return
    logger.debug(f"{self._log_prefix()} Heartbeat snapshot acquired.")

    # Step 2: report it to the store.
    try:
        await asyncio.wait_for(store.update_worker(worker_id, snapshot), timeout=budget)
    except asyncio.CancelledError:
        raise  # never swallow cancellation
    except asyncio.TimeoutError:
        logger.warning(
            "%s update worker heartbeat timed out after %.1fs, skipping.",
            self._log_prefix(),
            budget,
        )
    except Exception:
        logger.exception("%s Unable to update worker heartbeat.", self._log_prefix())
    else:
        logger.debug(f"{self._log_prefix()} Heartbeat updated successfully.")
435
+
436
+ def _start_heartbeat_loop(self, store: LightningStore) -> Optional[Callable[[], Awaitable[None]]]:
437
+ """Start a background heartbeat loop and return an async stopper."""
438
+
439
+ if self._heartbeat_interval <= 0:
440
+ return None
441
+
442
+ if self.worker_id is None:
443
+ logger.warning("%s Cannot start heartbeat loop without worker_id.", self._log_prefix())
444
+ return None
445
+
446
+ if self._heartbeat_launch_mode == "asyncio":
447
+ return self._start_heartbeat_asyncio_loop(store)
448
+ if self._heartbeat_launch_mode == "thread":
449
+ return self._start_heartbeat_thread_loop(store)
450
+ raise ValueError(f"Unsupported heartbeat launch mode: {self._heartbeat_launch_mode}")
451
+
452
def _start_heartbeat_asyncio_loop(self, store: LightningStore) -> Optional[Callable[[], Awaitable[None]]]:
    """Run the heartbeat as an ``asyncio`` task on the current event loop.

    The task awaits ``self._emit_heartbeat(store)`` and then sleeps a jittered interval
    (never less than 10 ms), until stopped. It must be started from within a running
    event loop because it uses ``asyncio.create_task``.

    Args:
        store: The lightning store to update.

    Returns:
        An async stopper. It signals the loop and waits for the task to finish, including
        any heartbeat already in flight.
    """
    stop_requested = asyncio.Event()

    def next_delay() -> float:
        offset = self._random_state.uniform(-self._interval_jitter, self._interval_jitter)
        return max(self._heartbeat_interval + offset, 0.01)

    async def beat_until_stopped() -> None:
        while not stop_requested.is_set():
            # _emit_heartbeat bounds its own work; it only offloads the snapshot to a thread.
            try:
                await self._emit_heartbeat(store)
            except Exception:
                logger.exception("%s Heartbeat failed.", self._log_prefix())
            # Sleep, but wake up immediately if stop() is called.
            try:
                await asyncio.wait_for(stop_requested.wait(), timeout=next_delay())
            except asyncio.TimeoutError:
                pass

    beat_task = asyncio.create_task(beat_until_stopped(), name=f"{self.get_worker_id()}-heartbeat")

    async def stop() -> None:
        stop_requested.set()
        try:
            await beat_task
        except asyncio.CancelledError:
            pass

    return stop
487
+
488
def _start_heartbeat_thread_loop(self, store: LightningStore) -> Optional[Callable[[], Awaitable[None]]]:
    """Start a background heartbeat loop backed by two daemon threads.

    A producer thread calls ``system_snapshot`` synchronously every jittered interval
    and publishes the result under a lock. A consumer thread owns a private event loop.
    It periodically forwards the most recent snapshot to ``store.update_worker`` unless
    that snapshot is older than ``stale_after`` seconds. Splitting the work keeps a slow
    snapshot or a slow store call from blocking the caller's event loop (or each other).

    Args:
        store: The lightning store to update.

    Returns:
        An async stopper that signals both threads and waits, off the caller's event loop,
        for them to exit.
    """
    stop_evt = threading.Event()
    lock = threading.Lock()

    # Shared between producer and consumer; guarded by ``lock``.
    latest_snapshot: Any = None
    latest_ts = 0.0  # time.monotonic() when snapshot was captured

    # Consider snapshot stale after ~1 interval plus jitter slack.
    stale_after = self._heartbeat_interval + self._interval_jitter + 1.0

    worker_id = self.get_worker_id()

    def next_wait() -> float:
        """Return the next jittered wait in seconds, floored at 10 ms."""
        jitter = self._random_state.uniform(-self._interval_jitter, self._interval_jitter)
        return max(self._heartbeat_interval + jitter, 0.01)

    def producer() -> None:
        nonlocal latest_snapshot, latest_ts
        while not stop_evt.is_set():
            try:
                logger.debug("%s Heartbeat producer: acquiring snapshot.", self._log_prefix())
                snap = system_snapshot(self._heartbeat_include_gpu)  # sync; may be slow under load
                logger.debug("%s Heartbeat producer: snapshot acquired.", self._log_prefix())
                ts = time.monotonic()
                with lock:
                    latest_snapshot = snap
                    latest_ts = ts
            except Exception:
                logger.warning("%s Heartbeat producer: system_snapshot failed.", self._log_prefix(), exc_info=True)

            stop_evt.wait(next_wait())

    def shutdown_private_loop(loop: asyncio.AbstractEventLoop) -> None:
        """Finalize the consumer's private loop like ``asyncio.run`` does, then close it."""
        try:
            # Cancel anything store.update_worker may have left scheduled on this loop.
            leftovers = [task for task in asyncio.all_tasks(loop) if not task.done()]
            for task in leftovers:
                task.cancel()
            if leftovers:
                loop.run_until_complete(asyncio.gather(*leftovers, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
        except Exception:
            logger.debug(
                "%s Heartbeat consumer: error while shutting down event loop.", self._log_prefix(), exc_info=True
            )
        finally:
            asyncio.set_event_loop(None)
            with suppress(Exception):
                loop.close()

    def consumer() -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        last_warned_ts: Optional[float] = None  # staleness is reported once per snapshot
        try:
            while not stop_evt.is_set():
                with lock:
                    snap = latest_snapshot
                    ts = latest_ts

                wait_interval = next_wait()

                if snap is None:
                    # Most likely the producer has not finished its first snapshot yet.
                    logger.debug("%s Heartbeat consumer: no snapshot yet; skipping update.", self._log_prefix())
                    stop_evt.wait(wait_interval)
                    continue

                age = time.monotonic() - ts
                if age > stale_after:
                    # Only warn once per stale snapshot (keyed by its capture timestamp).
                    if last_warned_ts != ts:
                        logger.warning(
                            "%s Heartbeat consumer: snapshot stale (age=%.2fs > %.2fs); skipping update.",
                            self._log_prefix(),
                            age,
                            stale_after,
                        )
                        last_warned_ts = ts
                    stop_evt.wait(wait_interval)
                    continue

                try:
                    logger.debug("%s Heartbeat consumer: updating worker.", self._log_prefix())
                    loop.run_until_complete(
                        asyncio.wait_for(
                            store.update_worker(worker_id, snap),
                            timeout=self._heartbeat_interval,
                        )
                    )
                    logger.debug("%s Heartbeat consumer: worker updated.", self._log_prefix())
                except asyncio.TimeoutError:
                    logger.warning(
                        "%s Heartbeat consumer: update timed out after %.1fs.",
                        self._log_prefix(),
                        self._heartbeat_interval,
                    )
                except Exception:
                    logger.warning("%s Heartbeat consumer: update failed.", self._log_prefix(), exc_info=True)

                stop_evt.wait(wait_interval)
        finally:
            # loop.stop() is a no-op on a loop that is not running; finalize properly instead.
            shutdown_private_loop(loop)

    t_prod = threading.Thread(target=producer, name=f"{worker_id}-heartbeat-producer", daemon=True)
    t_cons = threading.Thread(target=consumer, name=f"{worker_id}-heartbeat-consumer", daemon=True)
    t_prod.start()
    t_cons.start()

    async def stop() -> None:
        stop_evt.set()
        # Thread.join blocks, so wait for it in a worker thread instead of on the event loop.
        await asyncio.to_thread(t_prod.join)
        await asyncio.to_thread(t_cons.join)

    return stop
602
+
603
async def _sleep_until_next_poll(self, event: Optional[ExecutionEvent] = None) -> None:
    """Sleep for one jittered poll interval, optionally cut short by an event.

    The interval is ``self._poll_interval`` plus a uniform jitter in
    ``[-self._interval_jitter, self._interval_jitter]``, clamped to at least 0.01s.

    If an event is provided, it is checked before each short sleep (at most 0.1s
    each), and the method returns as soon as the event is set.

    Args:
        event: Optional [`ExecutionEvent`][mantisdk.ExecutionEvent] object that can be used to interrupt the sleep.
            If set before or during the sleep period, the method returns early.
    """
    interval = self._poll_interval + self._random_state.uniform(-self._interval_jitter, self._interval_jitter)
    interval = max(interval, 0.01)
    if event is None:
        await asyncio.sleep(interval)
        return

    # Use a monotonic clock so wall-clock adjustments (NTP, manual changes) cannot
    # stretch or cut short the wait.
    deadline = time.monotonic() + interval
    while not event.is_set():
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return
        # Never sleep past the deadline; wake at least every 0.1s to re-check the event.
        await asyncio.sleep(min(0.1, remaining))
624
+
625
async def _step_impl(self, next_rollout: AttemptedRollout, raise_on_exception: bool = False) -> str:
    """Execute a single rollout attempt end to end.

    Steps, in order:

    1. Fetch resources: by ``next_rollout.resources_id`` when it is set, otherwise
       the latest resources in the store.
    2. Fire the ``on_rollout_start`` hook, enter the tracer's trace context and fire
       ``on_trace_start``.
    3. Invoke the agent's training rollout (``mode == "train"``) or validation
       rollout, using the async variant when ``agent.is_async()``; then fire
       ``on_trace_end``.
    4. Post-process the result into spans and log the final reward.
    5. Always fire ``on_rollout_end`` and record the attempt status in the store.

    The attempt is recorded as ``"succeeded"`` only when steps 2-4 run to completion.
    Any other exit, including ``BaseException`` subclasses such as
    ``asyncio.CancelledError``, records it as ``"failed"``.

    Args:
        next_rollout: The rollout to execute, containing input data, mode,
            and resources information.
        raise_on_exception: If True, exceptions during rollout execution are
            re-raised, and missing resources raise ``RuntimeError``. If False,
            they are logged but not propagated.

    Returns:
        The rollout id of ``next_rollout``.

    Raises:
        RuntimeError: If no resources could be fetched and ``raise_on_exception`` is True.
        Exception: Any exception from steps 2-4 when ``raise_on_exception`` is True.
            Errors raised by the store while fetching resources occur before the
            guarded section and always propagate.
    """
    store = self.get_store()
    agent = self.get_agent()

    rollout_id = next_rollout.rollout_id

    resources_id = next_rollout.resources_id
    resources_update = None
    if resources_id:
        resources_update = await store.get_resources_by_id(resources_id)
    else:
        logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.")
        resources_update = await store.get_latest_resources()
    if not resources_update:
        if raise_on_exception:
            raise RuntimeError(f"{self._log_prefix(rollout_id)} Failed to fetch resources")
        logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
        # NOTE(review): the attempt status is not updated on this path.
        return rollout_id

    logger.debug(f"{self._log_prefix(rollout_id)} Resources fetched (id={resources_update.resources_id}).")

    trace_spans: List[ReadableSpan] | List[Span] = []
    # Set to True only at the very end of the guarded section. Deriving the final
    # status from this flag (rather than from a flag set in ``except Exception``)
    # keeps BaseException exits such as asyncio.CancelledError from being
    # recorded as "succeeded".
    rollout_succeeded: bool = False

    try:
        await self._trigger_hooks(hook_type="on_rollout_start", agent=agent, runner=self, rollout=next_rollout)

        start_time = time.perf_counter()
        logger.debug(f"{self._log_prefix(rollout_id)} Prepared for trace context.")
        async with self._tracer.trace_context(
            name=rollout_id, rollout_id=rollout_id, attempt_id=next_rollout.attempt.attempt_id
        ):
            logger.debug(f"{self._log_prefix(rollout_id)} Entered trace context.")
            await self._trigger_hooks(
                hook_type="on_trace_start", agent=agent, runner=self, tracer=self._tracer, rollout=next_rollout
            )

            # NOTE: This is the most costly step in the whole function
            # If the rollout method becomes unresponsive or timeouts, there is nothing we can do within the runner.
            # We might need some mechanisms in execution strategy to restart the runner. But that's a future work.
            if agent.is_async():
                rollout_method = (
                    agent.training_rollout_async if next_rollout.mode == "train" else agent.validation_rollout_async
                )
                logger.debug(f"{self._log_prefix(rollout_id)} Starting async rollout method.")
                result = await rollout_method(
                    next_rollout.input, resources=resources_update.resources, rollout=next_rollout
                )
                logger.debug(f"{self._log_prefix(rollout_id)} Async rollout method completed.")
            else:
                rollout_method = agent.training_rollout if next_rollout.mode == "train" else agent.validation_rollout
                logger.debug(f"{self._log_prefix(rollout_id)} Starting sync rollout method.")
                # Called directly on the event-loop thread: it blocks the loop until it returns.
                result = rollout_method(next_rollout.input, resources=resources_update.resources, rollout=next_rollout)
                logger.debug(f"{self._log_prefix(rollout_id)} Sync rollout method completed.")

            await self._trigger_hooks(
                hook_type="on_trace_end", agent=agent, runner=self, tracer=self._tracer, rollout=next_rollout
            )

        logger.debug(f"{self._log_prefix(rollout_id)} Trace context exited.")

        # Possible exceptions in post_process will be caught in the overall exception handler
        trace_spans = await self._post_process_rollout_result(next_rollout, result)
        last_reward = find_final_reward(trace_spans)

        elapsed = time.perf_counter() - start_time
        logger.info(
            f"{self._log_prefix(rollout_id)} Completed in "
            f"{elapsed:.2f}s. Collected {len(trace_spans)} span(s). "
            f"Final reward: {last_reward}"
        )
        rollout_succeeded = True

    except Exception:
        logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")

        if raise_on_exception:
            raise
    finally:
        try:
            await self._trigger_hooks(
                hook_type="on_rollout_end", agent=agent, runner=self, rollout=next_rollout, spans=trace_spans
            )
        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")

        try:
            if rollout_succeeded:
                await store.update_attempt(rollout_id, next_rollout.attempt.attempt_id, status="succeeded")
            else:
                # Covers ordinary exceptions as well as cancellation / other BaseExceptions.
                await store.update_attempt(rollout_id, next_rollout.attempt.attempt_id, status="failed")
        except Exception:
            logger.exception(
                f"{self._log_prefix(rollout_id)} Exception during update_attempt. Giving up the update."
            )

    return rollout_id
740
+
741
async def iter(self, *, event: Optional[ExecutionEvent] = None) -> None:
    """Run the runner, continuously iterating over tasks in the store.

    This method polls the store for new rollouts and executes them until:

    - The event is set (if provided)
    - The max_rollouts limit is reached (if configured)

    While the store has no rollout to hand out, the runner sleeps for a poll
    interval and polls again, so an empty queue alone does not end the loop.

    Each rollout is executed with ``raise_on_exception=False``, so failures inside
    the rollout are logged and the runner continues with subsequent tasks. Errors
    raised directly by store calls made here (e.g. ``dequeue_rollout``) are not
    caught and will end the loop.

    A heartbeat loop is started before polling begins and is always stopped on
    exit, and a summary line is logged once the loop finishes.

    Args:
        event: Optional ExecutionEvent object to signal the runner to stop. The runner
            will check this event periodically and stop gracefully when set.
    """
    num_tasks_processed = 0
    logger.info(f"{self._log_prefix()} Started async rollouts (max: {self._max_rollouts or 'unlimited'}).")
    store = self.get_store()

    stop_heartbeat = self._start_heartbeat_loop(store)

    try:
        while not (event is not None and event.is_set()) and (
            self._max_rollouts is None or num_tasks_processed < self._max_rollouts
        ):
            # Poll until a rollout is handed out or the stop event fires.
            next_rollout: Optional[Rollout] = None
            while not (event is not None and event.is_set()):
                logger.debug(f"{self._log_prefix()} Try to poll for next rollout.")
                next_rollout = await store.dequeue_rollout(worker_id=self.get_worker_id())
                logger.debug(f"{self._log_prefix()} Next rollout retrieved: {next_rollout}")
                if next_rollout is None:
                    logger.debug(
                        f"{self._log_prefix()} No rollout to poll. Waiting for {self._poll_interval} seconds."
                    )
                    await self._sleep_until_next_poll(event)
                else:
                    break

            if next_rollout is None:
                # Only reachable when the event was set while polling. Break (instead of
                # returning) so the final summary below is still logged.
                break

            # Execute the step
            await self._step_impl(next_rollout)

            num_tasks_processed += 1
            if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
                logger.info(
                    f"{self._log_prefix()} Progress: {num_tasks_processed}/{self._max_rollouts or 'unlimited'}"
                )
    finally:
        if stop_heartbeat is not None:
            await stop_heartbeat()

    logger.info(f"{self._log_prefix()} Finished async rollouts. Processed {num_tasks_processed} tasks.")
797
+
798
async def step(
    self,
    input: T_task,
    *,
    resources: Optional[NamedResources] = None,
    mode: Optional[RolloutMode] = None,
    event: Optional[ExecutionEvent] = None,
) -> Rollout:
    """Execute a single task directly, bypassing the task queue.

    This method starts a new rollout in the store for the given input and
    executes it immediately. Unlike [`iter()`][mantisdk.LitAgentRunner.iter],
    exceptions are propagated to the caller.

    Args:
        input: The task input to be processed by the agent.
        resources: Optional named resources to be used for this specific task.
            If provided, a new resources entry will be created in the store.
            If not provided, the latest resources from the store will be used.
        mode: Optional rollout mode ("train" or "validation"). Forwarded unchanged
            to ``store.start_rollout``; how ``None`` is resolved is up to the store.
        event: Optional ExecutionEvent object to signal interruption (currently unused
            but included for interface consistency).

    Returns:
        The completed rollout, re-read from the store after execution.

    Raises:
        RuntimeError: If the completed rollout cannot be found in the store afterwards,
            or if no resources are available for the rollout.
        Exception: Any exception that occurs during rollout execution will be
            re-raised to the caller.
    """
    store = self.get_store()

    resources_id = None
    if resources is not None:
        resources_update = await store.add_resources(resources)
        resources_id = resources_update.resources_id

    # Reuse the same store handle for every call in this method.
    attempted_rollout = await store.start_rollout(
        input=input, mode=mode, resources_id=resources_id, worker_id=self.get_worker_id()
    )
    rollout_id = await self._step_impl(attempted_rollout, raise_on_exception=True)

    completed_rollout = await store.get_rollout_by_id(rollout_id)
    if completed_rollout is None:
        raise RuntimeError(f"{self._log_prefix()} Failed to fetch completed rollout by id after step: {rollout_id}")
    return completed_rollout