mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/runner/agent.py
ADDED
|
@@ -0,0 +1,845 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
"""Agent runner implementation for executing agent rollouts.
|
|
4
|
+
|
|
5
|
+
This module provides the concrete implementation of the runner interface,
|
|
6
|
+
handling the execution of agent rollouts with support for tracing, hooks,
|
|
7
|
+
and distributed worker coordination.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import logging
|
|
14
|
+
import random
|
|
15
|
+
import threading
|
|
16
|
+
import time
|
|
17
|
+
from contextlib import suppress
|
|
18
|
+
from typing import (
|
|
19
|
+
TYPE_CHECKING,
|
|
20
|
+
Any,
|
|
21
|
+
Awaitable,
|
|
22
|
+
Callable,
|
|
23
|
+
List,
|
|
24
|
+
Literal,
|
|
25
|
+
Optional,
|
|
26
|
+
Sequence,
|
|
27
|
+
TypeVar,
|
|
28
|
+
cast,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
from opentelemetry.sdk.trace import ReadableSpan
|
|
32
|
+
|
|
33
|
+
from mantisdk.litagent import LitAgent
|
|
34
|
+
from mantisdk.reward import emit_reward, find_final_reward
|
|
35
|
+
from mantisdk.store.base import LightningStore
|
|
36
|
+
from mantisdk.tracer.base import Tracer
|
|
37
|
+
from mantisdk.tracer.otel import OtelTracer
|
|
38
|
+
from mantisdk.types import (
|
|
39
|
+
AttemptedRollout,
|
|
40
|
+
Hook,
|
|
41
|
+
NamedResources,
|
|
42
|
+
Rollout,
|
|
43
|
+
RolloutMode,
|
|
44
|
+
RolloutRawResult,
|
|
45
|
+
Span,
|
|
46
|
+
SpanCoreFields,
|
|
47
|
+
)
|
|
48
|
+
from mantisdk.utils.system_snapshot import system_snapshot
|
|
49
|
+
|
|
50
|
+
if TYPE_CHECKING:
|
|
51
|
+
from mantisdk.execution.events import ExecutionEvent
|
|
52
|
+
|
|
53
|
+
from .base import Runner
|
|
54
|
+
|
|
55
|
+
# Generic type of the task payload a LitAgentRunner hands to its agent.
T_task = TypeVar("T_task")

# Module-scoped logger named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class LitAgentRunner(Runner[T_task]):
|
|
61
|
+
"""Execute [`LitAgent`][mantisdk.LitAgent] tasks with tracing support.
|
|
62
|
+
|
|
63
|
+
This runner manages the complete lifecycle of agent rollout execution,
|
|
64
|
+
including task polling, resource management, tracing, and hooks. It supports
|
|
65
|
+
both continuous iteration over tasks from the store and single-step execution.
|
|
66
|
+
|
|
67
|
+
Attributes:
|
|
68
|
+
worker_id: Identifier for the active worker process, if any.
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
def __init__(
    self,
    tracer: Tracer,
    max_rollouts: Optional[int] = None,
    poll_interval: float = 5.0,
    heartbeat_interval: float = 10.0,
    interval_jitter: float = 0.5,
    heartbeat_launch_mode: Literal["asyncio", "thread"] = "thread",
    heartbeat_include_gpu: bool = False,
) -> None:
    """Configure the runner; the agent, hooks, store and worker id are attached later.

    Args:
        tracer: [`Tracer`][mantisdk.Tracer] used for rollout spans.
        max_rollouts: Optional cap on iterations processed by
            [`iter`][mantisdk.LitAgentRunner.iter].
        poll_interval: Seconds to wait between store polls when no work is available.
        heartbeat_interval: Seconds between heartbeats sent to the store. A value
            ``<= 0`` disables the heartbeat loop entirely.
        interval_jitter: Maximum random offset (seconds) applied to an interval so that
            many runners do not fall into lock-step and overload the store. The poll
            interval lies between ``poll_interval - interval_jitter`` and
            ``poll_interval + interval_jitter``; the heartbeat interval is jittered the
            same way.
        heartbeat_launch_mode: ``"thread"`` (default and recommended; keeps heartbeat work
            from blocking the event loop under load) or ``"asyncio"`` (simpler, suited to
            deployments with few workers).
        heartbeat_include_gpu: Whether heartbeat snapshots include GPU stats. Disabled by
            default because querying GPU stats can be slow under load.
    """
    super().__init__()

    # Core execution / polling settings.
    self._tracer = tracer
    self._max_rollouts = max_rollouts
    self._poll_interval = poll_interval

    # Heartbeat settings.
    self._heartbeat_interval = heartbeat_interval
    self._heartbeat_launch_mode = heartbeat_launch_mode
    self._heartbeat_include_gpu = heartbeat_include_gpu

    # Jitter configuration and the private RNG that draws the jitter offsets.
    self._interval_jitter = interval_jitter
    self._random_state = random.Random()

    # Filled in by init() / init_worker() and reset again by teardown().
    self._agent: Optional[LitAgent[T_task]] = None
    self._hooks: Sequence[Hook] = []
    self._store: Optional[LightningStore] = None
    self.worker_id: Optional[int] = None
|
|
113
|
+
|
|
114
|
+
def init(self, agent: LitAgent[T_task], *, hooks: Optional[Sequence[Hook]] = None, **kwargs: Any) -> None:
    """Attach ``agent`` (and optional ``hooks``) to this runner, then initialize the tracer.

    The agent is stored on the runner and told about the runner via ``set_runner``.
    The hooks are copied into a fresh list so later changes to the caller's sequence
    do not affect the runner.

    Args:
        agent: [`LitAgent`][mantisdk.LitAgent] instance executed by the runner.
        hooks: Optional sequence of [`Hook`][mantisdk.Hook] callbacks invoked around
            tracing and rollout boundaries. ``None`` means no hooks.
        **kwargs: Additional initialization arguments (currently unused).
    """
    self._agent = agent
    agent.set_runner(self)

    self._hooks = [] if hooks is None else list(hooks)

    self._tracer.init()
|
|
131
|
+
|
|
132
|
+
def init_worker(self, worker_id: int, store: LightningStore, **kwargs: Any) -> None:
    """Bind this runner to a worker slot and its store, then initialize the tracer for it.

    Intended to be called once per worker in a distributed setup.

    Args:
        worker_id: Unique identifier for this worker process.
        store: [`LightningStore`][mantisdk.LightningStore] used for task coordination
            and persistence.
        **kwargs: Additional worker-specific initialization arguments (currently unused).
    """
    self._store, self.worker_id = store, worker_id
    # The tracer receives the same id/store pair so its spans are attributed to this worker.
    self._tracer.init_worker(worker_id, store)
|
|
148
|
+
|
|
149
|
+
def teardown(self, *args: Any, **kwargs: Any) -> None:
|
|
150
|
+
"""Teardown the runner and clean up all resources.
|
|
151
|
+
|
|
152
|
+
This method resets all internal state including the agent, store,
|
|
153
|
+
hooks, and worker ID, and calls the tracer's teardown method.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
*args: Additional teardown arguments (currently unused).
|
|
157
|
+
**kwargs: Additional teardown keyword arguments (currently unused).
|
|
158
|
+
"""
|
|
159
|
+
self._agent = None
|
|
160
|
+
self._store = None
|
|
161
|
+
self.worker_id = None
|
|
162
|
+
self._hooks = []
|
|
163
|
+
|
|
164
|
+
self._tracer.teardown()
|
|
165
|
+
|
|
166
|
+
def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
    """Detach this runner from worker ``worker_id`` and tear down the tracer for it.

    Only ``worker_id`` is cleared on the runner; the store reference set by
    ``init_worker`` is left in place (``teardown`` is what clears it).

    Args:
        worker_id: Unique identifier of the worker being torn down.
        *args: Additional teardown arguments (currently unused).
        **kwargs: Additional teardown keyword arguments (currently unused).
    """
    tracer = self._tracer
    self.worker_id = None
    # Use the id passed to this call: self.worker_id has just been reset to None.
    tracer.teardown_worker(worker_id)
|
|
179
|
+
|
|
180
|
+
@property
def tracer(self) -> Tracer:
    """Read-only access to the [`Tracer`][mantisdk.Tracer] this runner was constructed with.

    Returns:
        The tracer instance passed to ``__init__``.
    """
    configured_tracer = self._tracer
    return configured_tracer
|
|
188
|
+
|
|
189
|
+
def get_agent(self) -> LitAgent[T_task]:
    """Return the agent attached by ``init``.

    Returns:
        The LitAgent instance managed by this runner.

    Raises:
        ValueError: If no agent is attached, i.e. [`init`][mantisdk.LitAgentRunner.init]
            has not been called (or ``teardown`` has since cleared it).
    """
    agent = self._agent
    if agent is not None:
        return agent
    raise ValueError("Agent not initialized. Call init() first.")
|
|
201
|
+
|
|
202
|
+
def get_store(self) -> LightningStore:
    """Return the store bound to this worker by ``init_worker``.

    Returns:
        The LightningStore instance for this worker.

    Raises:
        ValueError: If no store is bound, i.e.
            [`init_worker`][mantisdk.LitAgentRunner.init_worker] has not been called
            (or ``teardown`` has since cleared it).
    """
    store = self._store
    if store is not None:
        return store
    raise ValueError("Store not initialized. Call init_worker() first.")
|
|
214
|
+
|
|
215
|
+
def get_worker_id(self) -> str:
    """Return a printable worker label.

    Returns:
        ``"Worker-<id>"`` (e.g. ``"Worker-0"``) once a worker id is set, otherwise
        ``"Worker-Unknown"``.
    """
    if self.worker_id is None:
        return "Worker-Unknown"
    return f"Worker-{self.worker_id}"
|
|
223
|
+
|
|
224
|
+
def _log_prefix(self, rollout_id: Optional[str] = None) -> str:
|
|
225
|
+
"""Generate a standardized log prefix for the current worker.
|
|
226
|
+
|
|
227
|
+
This creates a consistent prefix format for log messages to identify
|
|
228
|
+
which worker and rollout the message is associated with.
|
|
229
|
+
|
|
230
|
+
Args:
|
|
231
|
+
rollout_id: Optional rollout ID to include in the prefix.
|
|
232
|
+
|
|
233
|
+
Returns:
|
|
234
|
+
A formatted log prefix string like "[Worker 0 | Rollout xyz]",
|
|
235
|
+
"[Worker 0]", "[Rollout xyz]", or "[Default Worker]".
|
|
236
|
+
"""
|
|
237
|
+
if self.worker_id is not None:
|
|
238
|
+
if rollout_id:
|
|
239
|
+
return f"[Worker {self.worker_id} | Rollout {rollout_id}]"
|
|
240
|
+
else:
|
|
241
|
+
return f"[Worker {self.worker_id}]"
|
|
242
|
+
if rollout_id:
|
|
243
|
+
return f"[Rollout {rollout_id}]"
|
|
244
|
+
return "[Default Worker]"
|
|
245
|
+
|
|
246
|
+
async def _trigger_hooks(
|
|
247
|
+
self,
|
|
248
|
+
hook_type: Literal["on_trace_start", "on_trace_end", "on_rollout_start", "on_rollout_end"],
|
|
249
|
+
*args: Any,
|
|
250
|
+
**kwargs: Any,
|
|
251
|
+
) -> None:
|
|
252
|
+
"""Trigger all registered hooks of a specific type.
|
|
253
|
+
|
|
254
|
+
This method calls the specified hook method on all registered hooks,
|
|
255
|
+
catching and logging any exceptions that occur during hook execution
|
|
256
|
+
to prevent them from disrupting the main execution flow.
|
|
257
|
+
|
|
258
|
+
Args:
|
|
259
|
+
hook_type: The type of hook to trigger. Valid values are:
|
|
260
|
+
"on_trace_start", "on_trace_end", "on_rollout_start", "on_rollout_end".
|
|
261
|
+
*args: Positional arguments to pass to the hook methods.
|
|
262
|
+
**kwargs: Keyword arguments to pass to the hook methods.
|
|
263
|
+
"""
|
|
264
|
+
for hook in self._hooks:
|
|
265
|
+
try:
|
|
266
|
+
await getattr(hook, hook_type)(*args, **kwargs)
|
|
267
|
+
except Exception:
|
|
268
|
+
logger.exception(f"{self._log_prefix()} Exception during {hook_type} hook {hook}.")
|
|
269
|
+
|
|
270
|
+
async def _post_process_rollout_result(
    self, rollout: AttemptedRollout, raw_result: RolloutRawResult
) -> List[ReadableSpan] | List[Span]:
    """Standardize the agent's return value and report what still needs reporting to the store.

    Accepted shapes of ``raw_result``:

    * ``None``: the tracer's last trace is returned as-is; nothing extra is written
      to the store.
    * ``bool`` / ``int`` / ``float``: the final reward. Non-float values are converted
      to ``float`` with a warning. A reward span is built with
      ``emit_reward(..., propagate=False)``, given the next span sequence id, added to
      the store with ``store.add_span``, and appended to the returned spans.
    * Non-empty list of OpenTelemetry ``ReadableSpan``: each span is added with
      ``store.add_otel_span``. If the tracer is an ``OtelTracer``, a warning is logged
      because those spans should already be in the store (duplicates are likely).
    * Non-empty list of ``Span``: each span is added with ``store.add_span``.
    * Non-empty list of ``SpanCoreFields``: converted to ``Span`` objects using a batch
      of sequence ids and added with ``store.add_many_spans``.
    * Empty list: accepted with a warning; nothing is stored.

    Args:
        rollout: The rollout (with its current attempt) the result belongs to.
        raw_result: The value returned by the agent's rollout method.

    Returns:
        The spans that are assumed to have been added to the store for this rollout.
        In the numeric case this includes the reward span. This is an estimate for
        logging and reward lookup; query the store directly for precise tracking.

    Raises:
        ValueError: If ``raw_result`` is a non-empty list whose items are not all of a
            single supported span type.
        TypeError: If ``raw_result`` is of any other unsupported type.
    """
    store = self.get_store()
    trace_spans: List[ReadableSpan | Span]

    if raw_result is None:
        # Case 0: nothing returned; report whatever the tracer captured.
        trace_spans = self._tracer.get_last_trace()

    elif isinstance(raw_result, (bool, int, float)):
        # Case 1: a scalar final reward.
        if isinstance(raw_result, (bool, int)):
            logger.warning(
                f"{self._log_prefix(rollout.rollout_id)} Reward is not a float, got: {type(raw_result)}. "
                "Auto converting to float."
            )
            raw_result = float(raw_result)
        # Copy the tracer's spans so the reward span can be appended to our own list.
        trace_spans = list(self._tracer.get_last_trace())
        # propagate=False: only build the reward span's fields; nothing is emitted via the tracer,
        # so the span has to be added to the store manually with its own sequence id.
        reward_span_core_fields = emit_reward(raw_result, propagate=False)
        sequence_id = await store.get_next_span_sequence_id(rollout.rollout_id, rollout.attempt.attempt_id)
        # Reuse the trace id of the existing spans so the reward span joins the same trace.
        # getattr guards against span objects without a ``trace_id`` attribute (OpenTelemetry's
        # ReadableSpan exposes it via ``context`` instead) rather than failing the whole rollout.
        # NOTE(review): confirm which span type Tracer.get_last_trace() returns here.
        existing_trace_id = getattr(trace_spans[0], "trace_id", None) if trace_spans else None
        reward_span = Span.from_core_fields(
            reward_span_core_fields,
            rollout_id=rollout.rollout_id,
            attempt_id=rollout.attempt.attempt_id,
            sequence_id=sequence_id,
            trace_id=existing_trace_id,
        )
        await store.add_span(reward_span)
        # Return the reward span too, so callers inspecting the result (e.g. to find the
        # final reward) can see the reward that was just stored.
        trace_spans.append(reward_span)

    elif isinstance(raw_result, list):
        if not raw_result:
            # Empty list: accepted, but almost certainly a mistake in the agent.
            logger.warning(
                f"{self._log_prefix(rollout.rollout_id)} The rollout returns an empty list. "
                "Please check your rollout implementation."
            )
            trace_spans = []

        elif all(isinstance(item, ReadableSpan) for item in raw_result):
            # Case 2: OpenTelemetry spans.
            if isinstance(self._tracer, OtelTracer):
                logger.warning(
                    f"{self._log_prefix(rollout.rollout_id)} Tracer is already an OpenTelemetry tracer. "
                    "The traces should have already been added to the store. "
                    "Returning the traces from the rollout will result in duplicate spans."
                )
            trace_spans = []
            for span in raw_result:
                added_span = await store.add_otel_span(
                    rollout.rollout_id, rollout.attempt.attempt_id, cast(ReadableSpan, span)
                )
                if added_span is not None:
                    trace_spans.append(added_span)
                else:
                    logger.error(
                        f"{self._log_prefix(rollout.rollout_id)} Failed to add OpenTelemetry span to the store: {span}"
                    )

        elif all(isinstance(item, Span) for item in raw_result):
            # Case 3: fully-formed mantisdk spans are stored as-is.
            spans = [cast(Span, item) for item in raw_result]
            for span in spans:
                await store.add_span(span)
            trace_spans = spans

        elif all(isinstance(item, SpanCoreFields) for item in raw_result):
            # Case 4: core fields need a sequence id each before they become Spans.
            sequence_ids = await store.get_many_span_sequence_ids(
                [(rollout.rollout_id, rollout.attempt.attempt_id) for _ in range(len(raw_result))]
            )
            trace_spans = [
                Span.from_core_fields(
                    cast(SpanCoreFields, span_core_fields),
                    rollout_id=rollout.rollout_id,
                    attempt_id=rollout.attempt.attempt_id,
                    sequence_id=sequence_id,
                )
                for span_core_fields, sequence_id in zip(raw_result, sequence_ids, strict=True)
            ]
            await store.add_many_spans(trace_spans)

        else:
            item_types = [type(item).__name__ for item in raw_result][:10]
            raise ValueError(
                "Invalid raw result type. It's expected to be a list of ReadableSpan, Span, or SpanCoreFields, "
                f"but got: {', '.join(item_types)}..."
            )

    else:
        raise TypeError(
            "Invalid raw result type. It's expected to be None, a number, or a list of ReadableSpan, Span, "
            f"or SpanCoreFields, but got: {type(raw_result).__name__}..."
        )

    return trace_spans
|
|
391
|
+
|
|
392
|
+
async def _emit_heartbeat(self, store: LightningStore) -> None:
    """Collect one system snapshot and report it to ``store`` as this worker's heartbeat.

    The heartbeat has two steps, each bounded by ``heartbeat_interval`` seconds:

    1. ``system_snapshot`` runs in a worker thread via ``asyncio.to_thread`` so it does
       not block the event loop.
    2. ``store.update_worker`` is awaited with the formatted worker id and the snapshot.

    A timeout or any other exception in either step is logged and the heartbeat is
    skipped. ``asyncio.CancelledError`` is always re-raised.

    Args:
        store: The lightning store to update.
    """
    logger.debug(f"{self._log_prefix()} Preparing to emit heartbeat.")
    worker_name = self.get_worker_id()

    # Step 1: acquire the snapshot off the event loop.
    # Note: a wait_for timeout abandons the await but cannot stop the worker thread itself.
    try:
        snapshot = await asyncio.wait_for(
            asyncio.to_thread(system_snapshot, self._heartbeat_include_gpu),
            timeout=self._heartbeat_interval,
        )
        logger.debug(f"{self._log_prefix()} Heartbeat snapshot acquired.")
    except asyncio.CancelledError:
        raise  # never swallow cancellation
    except asyncio.TimeoutError:
        logger.warning(
            "%s Heartbeat snapshot acquisition timed out after %.1fs, skipping.",
            self._log_prefix(),
            self._heartbeat_interval,
        )
        return
    except Exception:
        logger.exception("%s Unable to acquire heartbeat snapshot.", self._log_prefix())
        return

    # Step 2: publish the snapshot to the store.
    try:
        await asyncio.wait_for(store.update_worker(worker_name, snapshot), timeout=self._heartbeat_interval)
        logger.debug(f"{self._log_prefix()} Heartbeat updated successfully.")
    except asyncio.CancelledError:
        raise  # never swallow cancellation
    except asyncio.TimeoutError:
        logger.warning(
            "%s update worker heartbeat timed out after %.1fs, skipping.",
            self._log_prefix(),
            self._heartbeat_interval,
        )
    except Exception:
        logger.exception("%s Unable to update worker heartbeat.", self._log_prefix())
|
|
435
|
+
|
|
436
|
+
def _start_heartbeat_loop(self, store: LightningStore) -> Optional[Callable[[], Awaitable[None]]]:
|
|
437
|
+
"""Start a background heartbeat loop and return an async stopper."""
|
|
438
|
+
|
|
439
|
+
if self._heartbeat_interval <= 0:
|
|
440
|
+
return None
|
|
441
|
+
|
|
442
|
+
if self.worker_id is None:
|
|
443
|
+
logger.warning("%s Cannot start heartbeat loop without worker_id.", self._log_prefix())
|
|
444
|
+
return None
|
|
445
|
+
|
|
446
|
+
if self._heartbeat_launch_mode == "asyncio":
|
|
447
|
+
return self._start_heartbeat_asyncio_loop(store)
|
|
448
|
+
if self._heartbeat_launch_mode == "thread":
|
|
449
|
+
return self._start_heartbeat_thread_loop(store)
|
|
450
|
+
raise ValueError(f"Unsupported heartbeat launch mode: {self._heartbeat_launch_mode}")
|
|
451
|
+
|
|
452
|
+
def _start_heartbeat_asyncio_loop(self, store: LightningStore) -> Optional[Callable[[], Awaitable[None]]]:
|
|
453
|
+
"""Start a background heartbeat loop using asyncio.
|
|
454
|
+
|
|
455
|
+
Args:
|
|
456
|
+
store: The lightning store to update.
|
|
457
|
+
|
|
458
|
+
Returns:
|
|
459
|
+
An async stopper function that can be used to stop the heartbeat loop.
|
|
460
|
+
"""
|
|
461
|
+
|
|
462
|
+
stop_event = asyncio.Event()
|
|
463
|
+
|
|
464
|
+
async def heartbeat_loop() -> None:
|
|
465
|
+
while not stop_event.is_set():
|
|
466
|
+
try:
|
|
467
|
+
# Run _emit_heartbeat in thread pool to avoid blocking the event loop.
|
|
468
|
+
# Timeout at the interval - if it takes longer, the data is stale anyway.
|
|
469
|
+
await self._emit_heartbeat(store)
|
|
470
|
+
except Exception:
|
|
471
|
+
logger.exception("%s Heartbeat failed.", self._log_prefix())
|
|
472
|
+
with suppress(asyncio.TimeoutError):
|
|
473
|
+
interval = self._heartbeat_interval + self._random_state.uniform(
|
|
474
|
+
-self._interval_jitter, self._interval_jitter
|
|
475
|
+
)
|
|
476
|
+
interval = max(interval, 0.01)
|
|
477
|
+
await asyncio.wait_for(stop_event.wait(), timeout=interval)
|
|
478
|
+
|
|
479
|
+
task = asyncio.create_task(heartbeat_loop(), name=f"{self.get_worker_id()}-heartbeat")
|
|
480
|
+
|
|
481
|
+
async def stop() -> None:
|
|
482
|
+
stop_event.set()
|
|
483
|
+
with suppress(asyncio.CancelledError):
|
|
484
|
+
await task
|
|
485
|
+
|
|
486
|
+
return stop
|
|
487
|
+
|
|
488
|
+
def _start_heartbeat_thread_loop(self, store: LightningStore) -> Optional[Callable[[], Awaitable[None]]]:
    """Start a background heartbeat loop using threading.

    Two daemon threads are used so that neither step blocks the caller's event loop:

    - a producer that periodically calls ``system_snapshot`` and stores the latest result;
    - a consumer that periodically pushes the latest snapshot to ``store.update_worker``,
      running the coroutine on its own private event loop.

    Snapshots older than roughly one interval plus jitter (plus 1s slack) are treated as
    stale and are not sent.

    Args:
        store: The lightning store to update.

    Returns:
        An async stopper function that signals both threads to exit and waits for them.
    """
    stop_evt = threading.Event()
    lock = threading.Lock()

    # State shared between producer (writer) and consumer (reader); guarded by `lock`.
    latest_snapshot = None
    latest_ts = 0.0  # time.monotonic() when snapshot was captured

    # Consider snapshot stale after ~1 interval plus jitter slack.
    stale_after = self._heartbeat_interval + self._interval_jitter + 1.0

    worker_id = self.get_worker_id()

    def next_wait() -> float:
        # Jittered interval, clamped so a large negative jitter cannot produce a busy loop.
        jitter = self._random_state.uniform(-self._interval_jitter, self._interval_jitter)
        return max(self._heartbeat_interval + jitter, 0.01)

    def producer() -> None:
        nonlocal latest_snapshot, latest_ts
        while not stop_evt.is_set():
            try:
                logger.debug("%s Heartbeat producer: acquiring snapshot.", self._log_prefix())
                snap = system_snapshot(self._heartbeat_include_gpu)  # sync
                logger.debug("%s Heartbeat producer: snapshot acquired.", self._log_prefix())
                ts = time.monotonic()
                with lock:
                    latest_snapshot = snap
                    latest_ts = ts
            except Exception:
                logger.warning("%s Heartbeat producer: system_snapshot failed.", self._log_prefix(), exc_info=True)

            stop_evt.wait(next_wait())

    def consumer() -> None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        last_warned_ts = None  # Track which snapshot we've already warned about
        try:
            while not stop_evt.is_set():
                with lock:
                    snap = latest_snapshot
                    ts = latest_ts

                wait_interval = next_wait()

                if snap is None:
                    # probably just started
                    logger.debug("%s Heartbeat consumer: no snapshot yet; skipping update.", self._log_prefix())
                    stop_evt.wait(wait_interval)
                    continue

                age = time.monotonic() - ts
                if age > stale_after:
                    # Only warn once per stale snapshot (check if we haven't warned about this timestamp yet)
                    if last_warned_ts != ts:
                        logger.warning(
                            "%s Heartbeat consumer: snapshot stale (age=%.2fs > %.2fs); skipping update.",
                            self._log_prefix(),
                            age,
                            stale_after,
                        )
                        last_warned_ts = ts
                    stop_evt.wait(wait_interval)
                    continue

                try:
                    logger.debug("%s Heartbeat consumer: updating worker.", self._log_prefix())
                    loop.run_until_complete(
                        asyncio.wait_for(
                            store.update_worker(worker_id, snap),
                            timeout=self._heartbeat_interval,
                        )
                    )
                    logger.debug("%s Heartbeat consumer: worker updated.", self._log_prefix())
                except asyncio.TimeoutError:
                    logger.warning(
                        "%s Heartbeat consumer: update timed out after %.1fs.",
                        self._log_prefix(),
                        self._heartbeat_interval,
                    )
                except Exception:
                    logger.warning("%s Heartbeat consumer: update failed.", self._log_prefix(), exc_info=True)

                stop_evt.wait(wait_interval)
        finally:
            # Mirror asyncio.run()'s teardown: finalize async generators, detach, then close the loop.
            with suppress(Exception):
                loop.run_until_complete(loop.shutdown_asyncgens())
            with suppress(Exception):
                asyncio.set_event_loop(None)
            with suppress(Exception):
                loop.close()

    t_prod = threading.Thread(target=producer, name=f"{worker_id}-heartbeat-producer", daemon=True)
    t_cons = threading.Thread(target=consumer, name=f"{worker_id}-heartbeat-consumer", daemon=True)
    t_prod.start()
    t_cons.start()

    async def stop() -> None:
        stop_evt.set()
        # Join both threads concurrently: shutdown waits for the slower thread, not the sum of both.
        await asyncio.gather(asyncio.to_thread(t_prod.join), asyncio.to_thread(t_cons.join))

    return stop
|
|
602
|
+
|
|
603
|
+
async def _sleep_until_next_poll(self, event: Optional[ExecutionEvent] = None) -> None:
    """Sleep until the next poll interval, with optional event-based interruption.

    The sleep length is ``self._poll_interval`` plus a uniform jitter in
    ``[-self._interval_jitter, self._interval_jitter]``, clamped to at least 0.01s.

    If an event is provided, it is checked before sleeping and then at least every 0.1s.
    The method returns early as soon as the event is observed set. The final check step
    is shortened so the total sleep does not overshoot the deadline.

    Args:
        event: Optional [`ExecutionEvent`][mantisdk.ExecutionEvent] object that can be used to interrupt the sleep.
            If set during the sleep period, the method returns within ~0.1s.
    """
    interval = self._poll_interval + self._random_state.uniform(-self._interval_jitter, self._interval_jitter)
    interval = max(interval, 0.01)
    if event is None:
        await asyncio.sleep(interval)
        return

    # A monotonic clock is immune to wall-clock adjustments (NTP, manual changes) that
    # could otherwise stretch or cut the sleep.
    deadline = time.monotonic() + interval
    while not event.is_set():
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return
        # Never sleep past the deadline just to complete a full 0.1s check period.
        await asyncio.sleep(min(0.1, remaining))
|
|
624
|
+
|
|
625
|
+
async def _step_impl(self, next_rollout: AttemptedRollout, raise_on_exception: bool = False) -> str:
    """Execute a single rollout implementation.

    This is the core method that handles the execution of a single rollout,
    including resource fetching, hook triggering, agent invocation, tracing,
    and result processing.

    The attempt is finalized in the store on every path. It is marked ``"succeeded"`` only
    when the whole execution path completed, and ``"failed"`` otherwise. That includes
    resource-fetch failures and task cancellation: ``asyncio.CancelledError`` is not an
    ``Exception`` subclass, so it must not be recorded as a success.

    Args:
        next_rollout: The rollout to execute, containing input data, mode,
            and resources information.
        raise_on_exception: If True, exceptions during rollout execution will
            be re-raised. If False, exceptions are logged but not propagated.

    Returns:
        The ID of the processed rollout, whether or not execution succeeded.

    Raises:
        RuntimeError: If no resources could be fetched and ``raise_on_exception`` is True.
        Exception: Any exception from resource fetching, hooks, the agent, or
            post-processing, when ``raise_on_exception`` is True.
    """
    store = self.get_store()
    agent = self.get_agent()

    rollout_id = next_rollout.rollout_id
    attempt_id = next_rollout.attempt.attempt_id

    async def finalize_attempt(status: str) -> None:
        # Best effort: a failed status update is logged but must not mask the rollout's own outcome.
        try:
            await store.update_attempt(rollout_id, attempt_id, status=status)
        except Exception:
            logger.exception(
                f"{self._log_prefix(rollout_id)} Exception during update_attempt. Giving up the update."
            )

    resources_id = next_rollout.resources_id
    try:
        if resources_id:
            resources_update = await store.get_resources_by_id(resources_id)
        else:
            logger.debug(f"{self._log_prefix(rollout_id)} No 'resources_id'. Fetching latest resources.")
            resources_update = await store.get_latest_resources()
    except Exception:
        # Previously a store error here escaped the caller (e.g. iter()) and left the
        # attempt without a final status.
        logger.exception(f"{self._log_prefix(rollout_id)} Exception while fetching resources.")
        await finalize_attempt("failed")
        if raise_on_exception:
            raise
        return rollout_id

    if not resources_update:
        # Give the skipped attempt a final status instead of leaving it unfinalized.
        await finalize_attempt("failed")
        if raise_on_exception:
            raise RuntimeError(f"{self._log_prefix(rollout_id)} Failed to fetch resources")
        logger.error(f"{self._log_prefix(rollout_id)} Failed to fetch resources. Skipping.")
        return rollout_id

    logger.debug(f"{self._log_prefix(rollout_id)} Resources fetched (id={resources_update.resources_id}).")

    trace_spans: List[ReadableSpan] | List[Span] = []
    # Set only at the very end of the try body, so any early exit (exception OR cancellation)
    # is recorded as a failure.
    succeeded = False

    try:
        await self._trigger_hooks(hook_type="on_rollout_start", agent=agent, runner=self, rollout=next_rollout)

        start_time = time.time()
        logger.debug(f"{self._log_prefix(rollout_id)} Prepared for trace context.")
        async with self._tracer.trace_context(name=rollout_id, rollout_id=rollout_id, attempt_id=attempt_id):
            logger.debug(f"{self._log_prefix(rollout_id)} Entered trace context.")
            await self._trigger_hooks(
                hook_type="on_trace_start", agent=agent, runner=self, tracer=self._tracer, rollout=next_rollout
            )

            # NOTE: This is the most costly step in the whole function
            # If the rollout method becomes unresponsive or timeouts, there is nothing we can do within the runner.
            # We might need some mechanisms in execution strategy to restart the runner. But that's a future work.
            if agent.is_async():
                rollout_method = (
                    agent.training_rollout_async if next_rollout.mode == "train" else agent.validation_rollout_async
                )
                logger.debug(f"{self._log_prefix(rollout_id)} Starting async rollout method.")
                result = await rollout_method(
                    next_rollout.input, resources=resources_update.resources, rollout=next_rollout
                )
                logger.debug(f"{self._log_prefix(rollout_id)} Async rollout method completed.")
            else:
                rollout_method = agent.training_rollout if next_rollout.mode == "train" else agent.validation_rollout
                logger.debug(f"{self._log_prefix(rollout_id)} Starting sync rollout method.")
                result = rollout_method(next_rollout.input, resources=resources_update.resources, rollout=next_rollout)
                logger.debug(f"{self._log_prefix(rollout_id)} Sync rollout method completed.")

            await self._trigger_hooks(
                hook_type="on_trace_end", agent=agent, runner=self, tracer=self._tracer, rollout=next_rollout
            )

        logger.debug(f"{self._log_prefix(rollout_id)} Trace context exited.")

        # Possible exceptions in post_process will be caught in the overall exception handler
        trace_spans = await self._post_process_rollout_result(next_rollout, result)
        last_reward = find_final_reward(trace_spans)

        end_time = time.time()
        logger.info(
            f"{self._log_prefix(rollout_id)} Completed in "
            f"{end_time - start_time:.2f}s. Collected {len(trace_spans)} span(s). "
            f"Final reward: {last_reward}"
        )
        succeeded = True

    except Exception:
        logger.exception(f"{self._log_prefix(rollout_id)} Exception during rollout.")

        if raise_on_exception:
            raise
    finally:
        try:
            await self._trigger_hooks(
                hook_type="on_rollout_end", agent=agent, runner=self, rollout=next_rollout, spans=trace_spans
            )
        except Exception:
            logger.exception(f"{self._log_prefix(rollout_id)} Exception during on_rollout_end hook.")

        # Anything short of full completion - an exception, or a cancellation (which is a
        # BaseException and bypasses the handler above) - is recorded as a failure.
        await finalize_attempt("succeeded" if succeeded else "failed")

    return rollout_id
|
|
740
|
+
|
|
741
|
+
async def iter(self, *, event: Optional[ExecutionEvent] = None) -> None:
    """Run the runner, continuously iterating over tasks in the store.

    This method polls the store for new rollouts and executes them until:

    - The event is set (if provided)
    - The max_rollouts limit is reached (if configured)

    An empty queue does not end the loop. When no rollout is available, the runner sleeps
    for a jittered poll interval and polls again. Without an event or ``max_rollouts``, it
    therefore runs indefinitely.

    Rollouts are executed via ``_step_impl`` with ``raise_on_exception=False``, so exceptions
    raised during rollout execution are logged rather than propagated. This lets the runner
    continue with subsequent tasks. Errors raised by ``store.dequeue_rollout`` are not caught
    here.

    A heartbeat loop is started for the duration of the call and stopped on exit.

    Args:
        event: Optional ExecutionEvent object to signal the runner to stop. The runner
            will check this event periodically and stop gracefully when set.
    """
    num_tasks_processed = 0
    logger.info(f"{self._log_prefix()} Started async rollouts (max: {self._max_rollouts or 'unlimited'}).")
    store = self.get_store()

    stop_heartbeat = self._start_heartbeat_loop(store)

    try:
        while not (event is not None and event.is_set()) and (
            self._max_rollouts is None or num_tasks_processed < self._max_rollouts
        ):
            # Poll until a rollout is available or the stop event is set.
            next_rollout: Optional[AttemptedRollout] = None
            while not (event is not None and event.is_set()):
                # Lazy %-formatting: these run on every poll, usually with DEBUG disabled.
                logger.debug("%s Try to poll for next rollout.", self._log_prefix())
                next_rollout = await store.dequeue_rollout(worker_id=self.get_worker_id())
                logger.debug("%s Next rollout retrieved: %s", self._log_prefix(), next_rollout)
                if next_rollout is not None:
                    break
                logger.debug(
                    "%s No rollout to poll. Waiting for %s seconds.", self._log_prefix(), self._poll_interval
                )
                await self._sleep_until_next_poll(event)

            if next_rollout is None:
                # Only reachable when the stop event was set while polling.
                return

            # Execute the step
            await self._step_impl(next_rollout)

            num_tasks_processed += 1
            if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
                logger.info(
                    f"{self._log_prefix()} Progress: {num_tasks_processed}/{self._max_rollouts or 'unlimited'}"
                )
    finally:
        if stop_heartbeat is not None:
            await stop_heartbeat()

        logger.info(f"{self._log_prefix()} Finished async rollouts. Processed {num_tasks_processed} tasks.")
|
|
797
|
+
|
|
798
|
+
async def step(
    self,
    input: T_task,
    *,
    resources: Optional[NamedResources] = None,
    mode: Optional[RolloutMode] = None,
    event: Optional[ExecutionEvent] = None,
) -> Rollout:
    """Execute a single task directly, bypassing the task queue.

    This method creates a new rollout for the given input and executes it
    immediately. Unlike [`iter()`][mantisdk.LitAgentRunner.iter],
    exceptions are propagated to the caller.

    Args:
        input: The task input to be processed by the agent.
        resources: Optional named resources to be used for this specific task.
            If provided, a new resources entry will be created in the store.
            If not provided, the latest resources from the store will be used.
        mode: Optional rollout mode ("train" or "validation"). It is passed as-is to
            ``store.start_rollout``. When None, presumably the store/agent default
            applies (not determined here).
        event: Optional ExecutionEvent object to signal interruption (currently unused
            but included for interface consistency).

    Returns:
        The completed rollout, re-read from the store after execution.

    Raises:
        RuntimeError: If the completed rollout cannot be fetched from the store.
        Exception: Any exception that occurs during rollout execution will be
            re-raised to the caller.
    """
    store = self.get_store()

    resources_id = None
    if resources is not None:
        resources_update = await store.add_resources(resources)
        resources_id = resources_update.resources_id

    # Reuse the store fetched above instead of calling get_store() a second time.
    attempted_rollout = await store.start_rollout(
        input=input, mode=mode, resources_id=resources_id, worker_id=self.get_worker_id()
    )
    rollout_id = await self._step_impl(attempted_rollout, raise_on_exception=True)

    completed_rollout = await store.get_rollout_by_id(rollout_id)
    if completed_rollout is None:
        raise RuntimeError(f"{self._log_prefix()} Failed to fetch completed rollout by id after step: {rollout_id}")
    return completed_rollout
|