mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,618 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import functools
|
|
5
|
+
import logging
|
|
6
|
+
import warnings
|
|
7
|
+
from typing import Any, Callable, Dict, Optional, Sequence, TypeVar, Union
|
|
8
|
+
|
|
9
|
+
from mantisdk.adapter import TraceAdapter, TracerTraceToTriplet
|
|
10
|
+
from mantisdk.algorithm import Algorithm, Baseline, FastAlgorithm
|
|
11
|
+
from mantisdk.client import MantisdkClient
|
|
12
|
+
from mantisdk.execution.base import ExecutionStrategy
|
|
13
|
+
from mantisdk.execution.client_server import ClientServerExecutionStrategy
|
|
14
|
+
from mantisdk.execution.events import ExecutionEvent
|
|
15
|
+
from mantisdk.litagent import LitAgent
|
|
16
|
+
from mantisdk.llm_proxy import LLMProxy
|
|
17
|
+
from mantisdk.runner import LitAgentRunner, Runner
|
|
18
|
+
from mantisdk.store.base import LightningStore
|
|
19
|
+
from mantisdk.store.memory import InMemoryLightningStore
|
|
20
|
+
from mantisdk.tracer.agentops import AgentOpsTracer
|
|
21
|
+
from mantisdk.tracer.base import Tracer
|
|
22
|
+
from mantisdk.tracer.otel import OtelTracer
|
|
23
|
+
from mantisdk.types import Dataset, Hook, NamedResources
|
|
24
|
+
|
|
25
|
+
from .init_utils import build_component, instantiate_component
|
|
26
|
+
from .legacy import TrainerLegacy
|
|
27
|
+
from .registry import ExecutionStrategyRegistry
|
|
28
|
+
|
|
29
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Covariant type variable; `Trainer.fit` uses it to tie the agent's task type
# to the element type of the train/validation datasets.
T_co = TypeVar("T_co", covariant=True)
# Plain type variable standing for the component type inside `ComponentSpec`.
T = TypeVar("T")

# The forms in which a component may be handed to `Trainer`: a ready instance,
# a class, a zero-argument factory, a string, a configuration dictionary, or
# `None`. How each form is interpreted is decided by `build_component`.
ComponentSpec = Union[T, type[T], Callable[[], T], str, Dict[str, Any], None]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Trainer(TrainerLegacy):
|
|
38
|
+
"""High-level orchestration layer that wires Algorithm <-> Runner <-> Store.
|
|
39
|
+
|
|
40
|
+
A [`Trainer`][mantisdk.Trainer] packages the moving parts of Agent-Lightning's
|
|
41
|
+
training loop into a single entry point:
|
|
42
|
+
|
|
43
|
+
* **Algorithm lifecycle:** Instantiates or accepts an [`Algorithm`][mantisdk.Algorithm],
|
|
44
|
+
attaches the current [`LightningStore`][mantisdk.LightningStore], adapter, and
|
|
45
|
+
initial resources, then executes the algorithm role inside the configured execution strategy.
|
|
46
|
+
* **Runner fleet:** Spawns one or more [`Runner`][mantisdk.Runner] instances (defaulting
|
|
47
|
+
to [`LitAgentRunner`][mantisdk.LitAgentRunner]) that hydrate a [`LitAgent`][mantisdk.LitAgent],
|
|
48
|
+
claim rollouts, stream spans, and respect graceful termination signals from the execution strategy.
|
|
49
|
+
* **Execution strategy:** Delegates process management to an
|
|
50
|
+
[`ExecutionStrategy`][mantisdk.ExecutionStrategy] (shared memory, client/server, etc.),
|
|
51
|
+
so advanced users can swap orchestration backends without changing trainer code.
|
|
52
|
+
* **Telemetry plumbing:** Ensures tracers, adapters, and optional [`LLMProxy`][mantisdk.LLMProxy]
|
|
53
|
+
are wired into both algorithm and runners so telemetry flows back into the store.
|
|
54
|
+
|
|
55
|
+
The trainer exposes two convenience entry points:
|
|
56
|
+
[`fit()`][mantisdk.Trainer.fit] for full training and
|
|
57
|
+
[`dev()`][mantisdk.Trainer.dev] for fast, reproducible dry-runs. See the
|
|
58
|
+
[Train the First Agent](../how-to/train-first-agent.md) and
|
|
59
|
+
[Write the First Algorithm](../how-to/write-first-algorithm.md) tutorials for the broader context.
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
algorithm: Optional[Algorithm]
|
|
63
|
+
"""An instance of [`Algorithm`][mantisdk.Algorithm] to use for training."""
|
|
64
|
+
|
|
65
|
+
store: LightningStore
|
|
66
|
+
"""An instance of [`LightningStore`][mantisdk.LightningStore] to use for storing tasks and traces."""
|
|
67
|
+
|
|
68
|
+
runner: Runner[Any]
|
|
69
|
+
"""An instance of [`Runner`][mantisdk.Runner] to use for running the agent."""
|
|
70
|
+
|
|
71
|
+
initial_resources: Optional[NamedResources]
|
|
72
|
+
"""An instance of [`NamedResources`][mantisdk.NamedResources] to use for bootstrapping the fit/dev process.
|
|
73
|
+
|
|
74
|
+
The resources will be handed over to the algorithm. Note that not all algorithms support seeding resources.
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
n_runners: int
|
|
78
|
+
"""Number of agent runners to run in parallel."""
|
|
79
|
+
|
|
80
|
+
max_rollouts: Optional[int]
|
|
81
|
+
"""Maximum number of rollouts to process per runner. If None, workers run until no more rollouts are available."""
|
|
82
|
+
|
|
83
|
+
strategy: ExecutionStrategy
|
|
84
|
+
"""An instance of [`ExecutionStrategy`][mantisdk.ExecutionStrategy] to use for spawning the algorithm and runners."""
|
|
85
|
+
|
|
86
|
+
tracer: Tracer
|
|
87
|
+
"""A tracer instance, or a string pointing to the class full name or a dictionary with a 'type' key
|
|
88
|
+
that specifies the class full name and other initialization parameters.
|
|
89
|
+
If None, a default tracer is created: `OtelTracer` when the store reports OTLP trace support, otherwise [`AgentOpsTracer`][mantisdk.AgentOpsTracer] with the current settings."""
|
|
90
|
+
|
|
91
|
+
hooks: Sequence[Hook]
|
|
92
|
+
"""A sequence of [`Hook`][mantisdk.Hook] instances to be called at various lifecycle stages (e.g., `on_trace_start`,
|
|
93
|
+
`on_trace_end`, `on_rollout_start`, `on_rollout_end`)."""
|
|
94
|
+
|
|
95
|
+
adapter: TraceAdapter[Any]
|
|
96
|
+
"""An instance of [`TraceAdapter`][mantisdk.TraceAdapter] to export data consumable by algorithms from traces."""
|
|
97
|
+
|
|
98
|
+
llm_proxy: Optional[LLMProxy]
|
|
99
|
+
"""An instance of [`LLMProxy`][mantisdk.LLMProxy] to use for intercepting the LLM calls.
|
|
100
|
+
If not provided, algorithm may create one on its own."""
|
|
101
|
+
|
|
102
|
+
n_workers: int
|
|
103
|
+
"""Number of agent workers to run in parallel. Deprecated in favor of `n_runners`."""
|
|
104
|
+
|
|
105
|
+
max_tasks: Optional[int]
|
|
106
|
+
"""Maximum number of tasks to process per runner. Deprecated in favor of `max_rollouts`."""
|
|
107
|
+
|
|
108
|
+
daemon: bool
|
|
109
|
+
"""Whether worker processes should be daemons. Daemon processes
|
|
110
|
+
are terminated automatically when the main process exits. Deprecated.
|
|
111
|
+
Only have effect with `fit_v0`."""
|
|
112
|
+
|
|
113
|
+
triplet_exporter: TraceAdapter[Any]
|
|
114
|
+
"""An instance of [`TracerTraceToTriplet`][mantisdk.TracerTraceToTriplet] to export triplets from traces,
|
|
115
|
+
or a dictionary with the initialization parameters for the exporter.
|
|
116
|
+
Deprecated. Use [`adapter`][mantisdk.Trainer.adapter] instead."""
|
|
117
|
+
|
|
118
|
+
port: Optional[int]
|
|
119
|
+
"""Port forwarded to [`ClientServerExecutionStrategy`][mantisdk.ClientServerExecutionStrategy]."""
|
|
120
|
+
|
|
121
|
+
def __init__(
    self,
    *,
    dev: bool = False,
    n_runners: Optional[int] = None,
    max_rollouts: Optional[int] = None,
    initial_resources: Optional[NamedResources] = None,
    tracer: ComponentSpec[Tracer] = None,
    adapter: ComponentSpec[TraceAdapter[Any]] = None,
    store: ComponentSpec[LightningStore] = None,
    runner: ComponentSpec[Runner[Any]] = None,
    strategy: ComponentSpec[ExecutionStrategy] = None,
    port: Optional[int] = None,
    algorithm: ComponentSpec[Algorithm] = None,
    llm_proxy: ComponentSpec[LLMProxy] = None,
    n_workers: Optional[int] = None,
    max_tasks: Optional[int] = None,
    daemon: bool = True,
    triplet_exporter: ComponentSpec[TracerTraceToTriplet] = None,
    hooks: Optional[Union[Hook, Sequence[Hook]]] = None,
):
    """Set up the trainer by resolving every component specification.

    Component keywords (``tracer``, ``adapter``, ``store``, ``runner``,
    ``strategy``, ``algorithm``, ``llm_proxy``) may each be a ready instance,
    a class, a callable factory, a registry string, or a small configuration
    dictionary; see
    [`build_component()`][mantisdk.trainer.init_utils.build_component].

    Components are resolved in a fixed order: adapter, algorithm, strategy,
    store, tracer, runner, LLM proxy, hooks. The store is built after the
    strategy because its default depends on the strategy type, and the tracer
    is built after the store so the store can influence the default tracer.

    A ``port``, when given, is forwarded to the
    [`ClientServerExecutionStrategy`][mantisdk.ClientServerExecutionStrategy]
    that is constructed for (or supplied to) the trainer.

    Deprecated inputs: ``dev=True``, ``n_workers`` and ``max_tasks`` each emit
    a `DeprecationWarning`; ``n_workers`` / ``max_tasks`` only take effect
    when ``n_runners`` / ``max_rollouts`` are not given. ``triplet_exporter``
    is used only when ``adapter`` is not given, and a warning is emitted when
    both are passed.
    """
    # TrainerLegacy.__init__ is deliberately NOT invoked through super():
    # running the legacy initialisation here is not intended.
    self.worker_id: Optional[int] = None

    if dev:
        warnings.warn(
            "Trainer(dev=True) is deprecated and will be removed in future versions. "
            "Please use Trainer.dev(...) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    self._dev = dev
    self.daemon = daemon
    self._client: MantisdkClient | None = None  # populated later by fit or fit_v0

    # --- runner count: `n_runners` wins over the deprecated `n_workers` ---
    workers_given = n_workers is not None
    if workers_given:
        warnings.warn(
            "`n_workers` is deprecated. Please use `n_runners`.",
            DeprecationWarning,
            stacklevel=2,
        )

    if n_runners is None:
        n_runners = n_workers if workers_given else 1
    elif workers_given and n_workers != n_runners:
        warnings.warn(
            "`n_workers` is ignored when `n_runners` is provided.",
            DeprecationWarning,
            stacklevel=2,
        )

    self.n_runners = n_runners
    self.n_workers = n_runners  # mirror kept for fit_v0

    # --- rollout cap: `max_rollouts` wins over the deprecated `max_tasks` ---
    tasks_given = max_tasks is not None
    if tasks_given:
        warnings.warn(
            "`max_tasks` is deprecated. Please use `max_rollouts`.",
            DeprecationWarning,
            stacklevel=2,
        )

    if max_rollouts is None:
        max_rollouts = max_tasks
    elif tasks_given and max_tasks != max_rollouts:
        warnings.warn(
            "`max_tasks` is ignored when `max_rollouts` is provided.",
            DeprecationWarning,
            stacklevel=2,
        )

    self.max_rollouts = max_rollouts
    self.max_tasks = max_rollouts if max_tasks is None else max_tasks

    # The tracer spec is only remembered here; the tracer itself is built
    # once the store exists (see below).
    self._tracer_spec = tracer

    # --- adapter: `adapter` wins over the deprecated `triplet_exporter` ---
    if adapter is None:
        adapter_spec = triplet_exporter
    else:
        adapter_spec = adapter
        if triplet_exporter is not None:
            warnings.warn(
                "`triplet_exporter` is deprecated and ignored because `adapter` is provided.",
                DeprecationWarning,
                stacklevel=2,
            )

    self.adapter = self._make_adapter(adapter_spec)
    self.triplet_exporter = self.adapter  # legacy alias of the same object

    self.algorithm = self._make_algorithm(algorithm)

    # Only a single resources bundle is supported for now.
    self.initial_resources = initial_resources

    self.port = port

    self.strategy = self._make_strategy(strategy, n_runners=self.n_runners, port=port)

    # Store used by the current execution context; its default depends on the strategy.
    self.store = self._make_store(store, self.strategy)

    # Built after the store so an OTLP-capable store can be detected.
    self.tracer = self._make_tracer(self._tracer_spec, self.store)

    self.runner = self._make_runner(runner)

    # A strategy exposing a positive integer `n_runners` overrides the trainer's own count.
    strategy_runners = getattr(self.strategy, "n_runners", None)
    if isinstance(strategy_runners, int) and strategy_runners > 0:
        self.n_runners = strategy_runners
        self.n_workers = strategy_runners

    self.llm_proxy = self._make_llm_proxy(llm_proxy, store=self.store)

    self.hooks = self._normalize_hooks(hooks)

    if not self.daemon:
        logger.warning(
            "daemon=False. Worker processes are non-daemonic. "
            "The worker processes will NOT be terminated when the main process exits. "
            "The cleanup must be handled manually."
        )
|
|
257
|
+
|
|
258
|
+
def _make_tracer(
    self, tracer: ComponentSpec[Tracer], store: Optional[LightningStore] = None
) -> Tracer:
    """Resolve the tracer component from user input.

    The spec is handed to `build_component` together with a default factory.
    That factory returns an `OtelTracer` when the store looks OTLP-capable,
    and otherwise an `AgentOpsTracer` configured with this trainer's
    `daemon` flag.

    A store counts as OTLP-capable when its `capabilities` mapping has a
    truthy `"otlp_traces"` entry, or, failing that, when any object in its
    `_listeners` collection has such a `capabilities` entry.

    Args:
        tracer: Tracer specification (instance, class, factory, string, dict
            with a `type` key, or `None`).
        store: Store inspected for OTLP support; `None` disables detection.

    Returns:
        The resolved tracer.
    """
    store_supports_otlp = False
    if store is not None:
        if hasattr(store, "capabilities"):
            store_supports_otlp = bool(store.capabilities.get("otlp_traces", False))
        # Listeners are consulted whenever the store itself does not advertise
        # OTLP support. This used to be an `elif` on the `capabilities` check,
        # which skipped listeners for every store that has a `capabilities`
        # attribute, contradicting the intent to "also" check them.
        if not store_supports_otlp:
            for listener in getattr(store, "_listeners", []):
                if hasattr(listener, "capabilities") and listener.capabilities.get("otlp_traces", False):
                    store_supports_otlp = True
                    break

    def default_factory() -> Tracer:
        # The message is logged here rather than up front, so it appears only
        # when the default tracer is really being built, not when the caller
        # supplied their own tracer.
        if store_supports_otlp:
            logger.info("OTLP-capable store detected. Using OtelTracer for native trace export.")
            return OtelTracer()
        return AgentOpsTracer(
            agentops_managed=True,
            instrument_managed=True,
            daemon=self.daemon,
        )

    return build_component(
        tracer,
        expected_type=Tracer,
        spec_name="tracer",
        default_factory=default_factory,
        dict_requires_type=True,
        invalid_spec_error_fmt="Invalid tracer type: {actual_type}. Expected Tracer, str, dict, or None.",
        type_error_fmt="Tracer factory returned {type_name}, which is not a Tracer subclass.",
    )
|
|
299
|
+
|
|
300
|
+
def _make_algorithm(self, algorithm: ComponentSpec[Algorithm]) -> Optional[Algorithm]:
    """Turn the algorithm spec into an `Algorithm`, or `None` when none is configured.

    `None` is an accepted outcome (``allow_none=True``), which is what
    dev-mode dry runs rely on.
    """
    bad_spec_msg = "Invalid algorithm type: {actual_type}. Expected Algorithm, str, dict, or None."
    bad_type_msg = "Algorithm factory returned {type_name}, which is not a Algorithm subclass."
    resolved = build_component(
        algorithm,
        expected_type=Algorithm,
        spec_name="algorithm",
        allow_none=True,
        invalid_spec_error_fmt=bad_spec_msg,
        type_error_fmt=bad_type_msg,
    )
    return resolved
|
|
310
|
+
|
|
311
|
+
def _make_adapter(self, adapter: ComponentSpec[TraceAdapter[Any]]) -> TraceAdapter[Any]:
    """Turn the adapter spec into the `TraceAdapter` that converts spans for algorithms.

    `TracerTraceToTriplet` is both the default factory and the default class
    for dictionary specs, which are not required to carry a type
    (``dict_requires_type=False``).
    """
    bad_spec_msg = "Invalid adapter type: {actual_type}. Expected TraceAdapter, dict, or None."
    bad_type_msg = "Adapter factory returned {type_name}, which is not a TraceAdapter subclass."
    resolved = build_component(
        adapter,
        expected_type=TraceAdapter,
        spec_name="adapter",
        default_factory=TracerTraceToTriplet,
        dict_requires_type=False,
        dict_default_cls=TracerTraceToTriplet,
        invalid_spec_error_fmt=bad_spec_msg,
        type_error_fmt=bad_type_msg,
    )
    return resolved
|
|
323
|
+
|
|
324
|
+
def _make_store(self, store: ComponentSpec[LightningStore], strategy: ExecutionStrategy) -> LightningStore:
    """Turn the store spec into the `LightningStore` holding rollouts, attempts, spans and resources.

    The default is an `InMemoryLightningStore`. Its ``thread_safe`` flag is
    set exactly when ``strategy`` is a `ClientServerExecutionStrategy`.
    """
    needs_thread_safety = isinstance(strategy, ClientServerExecutionStrategy)

    def in_memory_default() -> LightningStore:
        return InMemoryLightningStore(thread_safe=needs_thread_safety)

    bad_spec_msg = "Invalid store type: {actual_type}. Expected LightningStore, str, dict, or None."
    bad_type_msg = "Store factory returned {type_name}, which is not a LightningStore subclass."
    return build_component(
        store,
        expected_type=LightningStore,
        spec_name="store",
        default_factory=in_memory_default,
        invalid_spec_error_fmt=bad_spec_msg,
        type_error_fmt=bad_type_msg,
    )
|
|
340
|
+
|
|
341
|
+
def _make_strategy(
    self,
    strategy: ComponentSpec[ExecutionStrategy],
    *,
    n_runners: int,
    port: Optional[int] = None,
) -> ExecutionStrategy:
    """Turn the strategy spec into an `ExecutionStrategy`.

    A ready-made strategy instance is returned as is; if it is a
    `ClientServerExecutionStrategy` and ``port`` is given, its
    ``server_port`` is overwritten first. For any other spec,
    ``n_runners`` (and ``port`` as ``server_port``, when given) are offered
    as optional defaults, and the fallback strategy is a
    `ClientServerExecutionStrategy`.
    """
    if isinstance(strategy, ExecutionStrategy):
        if isinstance(strategy, ClientServerExecutionStrategy) and port is not None:
            strategy.server_port = port
        return strategy

    port_given = port is not None

    seeded: Dict[str, Callable[[], Any]] = {"n_runners": lambda: n_runners}
    if port_given:
        seeded["server_port"] = lambda: port

    def default_factory() -> ExecutionStrategy:
        # Pass `server_port` only when a port was actually requested.
        ctor_kwargs: Dict[str, Any] = {"n_runners": n_runners}
        if port_given:
            ctor_kwargs["server_port"] = port
        return ClientServerExecutionStrategy(**ctor_kwargs)

    bad_spec_msg = "Invalid strategy type: {actual_type}. Expected ExecutionStrategy, str, dict, or None."
    bad_type_msg = "Strategy factory returned {type_name}, which is not an ExecutionStrategy subclass."
    return build_component(
        strategy,
        expected_type=ExecutionStrategy,
        spec_name="strategy",
        default_factory=default_factory,
        optional_defaults=seeded,
        invalid_spec_error_fmt=bad_spec_msg,
        type_error_fmt=bad_type_msg,
        registry=ExecutionStrategyRegistry,
    )
|
|
372
|
+
|
|
373
|
+
def _make_llm_proxy(
    self,
    llm_proxy: ComponentSpec[LLMProxy],
    *,
    store: LightningStore,
) -> Optional[LLMProxy]:
    """Turn the LLM proxy spec into an `LLMProxy`, or `None` when none is configured.

    A ready-made proxy is returned untouched. Otherwise the trainer's store
    is offered as a default, along with the store's OTLP endpoint and headers
    when its `capabilities` report ``"otlp_traces"``. Failing to read the
    OTLP settings is logged and otherwise ignored. Dictionary specs are
    copied and filled with those values for keys they do not already define.
    """
    if isinstance(llm_proxy, LLMProxy):
        return llm_proxy

    seeded: Dict[str, Callable[[], Any]] = {"store": lambda: store}

    advertises_otlp = hasattr(store, "capabilities") and store.capabilities.get("otlp_traces", False)
    if advertises_otlp:
        # Best effort: a store that cannot provide its endpoint must not
        # prevent the trainer from being constructed.
        try:
            otlp_endpoint = store.otlp_traces_endpoint()
            if hasattr(store, "get_otlp_headers"):
                otlp_headers = store.get_otlp_headers()
            else:
                otlp_headers = {}
            seeded["otlp_endpoint"] = lambda: otlp_endpoint
            seeded["otlp_headers"] = lambda: otlp_headers
            logger.info(f"Trainer: Store provides OTLP endpoint for proxy: {otlp_endpoint}")
        except Exception as e:
            logger.warning(f"Trainer: Failed to get OTLP endpoint from store: {e}")

    if isinstance(llm_proxy, dict):
        # Work on a copy so the caller's dictionary is left unchanged.
        proxy_cfg = dict(llm_proxy)
        proxy_cfg.setdefault("store", store)
        for key in ("otlp_endpoint", "otlp_headers"):
            if key in seeded and key not in proxy_cfg:
                proxy_cfg[key] = seeded[key]()
        llm_proxy = proxy_cfg

    bad_spec_msg = "Invalid llm_proxy type: {actual_type}. Expected LLMProxy, dict, str, or None."
    bad_type_msg = "llm_proxy factory returned {type_name}, which is not an LLMProxy subclass."
    return build_component(
        llm_proxy,
        expected_type=LLMProxy,
        spec_name="llm_proxy",
        allow_none=True,
        optional_defaults=seeded,
        invalid_spec_error_fmt=bad_spec_msg,
        type_error_fmt=bad_type_msg,
    )
|
|
414
|
+
|
|
415
|
+
def _make_runner(self, runner: ComponentSpec[Runner[Any]]) -> Runner[Any]:
    """Turn the runner spec into the `Runner` that executes the agent in each worker.

    The trainer's tracer, and ``max_rollouts`` when it is set, are offered as
    optional defaults. The fallback runner is a `LitAgentRunner` instantiated
    with those same defaults.
    """
    seeded: Dict[str, Callable[[], Any]] = {"tracer": lambda: self.tracer}
    if self.max_rollouts is not None:
        seeded["max_rollouts"] = lambda: self.max_rollouts

    default_runner_factory = lambda: instantiate_component(LitAgentRunner, optional_defaults=seeded)

    bad_spec_msg = "Invalid runner type: {actual_type}. Expected Runner, callable, str, dict, or None."
    bad_type_msg = "Runner factory returned {type_name}, which is not a Runner subclass."
    return build_component(
        runner,
        expected_type=Runner,
        spec_name="runner",
        default_factory=default_runner_factory,
        optional_defaults=seeded,
        invalid_spec_error_fmt=bad_spec_msg,
        type_error_fmt=bad_type_msg,
    )
|
|
433
|
+
|
|
434
|
+
def _normalize_hooks(self, hooks: Optional[Union[Hook, Sequence[Hook]]]) -> Sequence[Hook]:
    """Return the supplied hooks as a tuple.

    Args:
        hooks: `None`, a single `Hook`, or a sequence of hooks.

    Returns:
        An empty tuple for `None`, a one-element tuple for a single `Hook`,
        and `tuple(hooks)` for anything else.
    """
    if hooks is None:
        return ()
    # A lone Hook is wrapped; any other input is treated as an iterable of hooks.
    return (hooks,) if isinstance(hooks, Hook) else tuple(hooks)
|
|
441
|
+
|
|
442
|
+
def fit(
    self,
    agent: LitAgent[T_co],
    train_dataset: Optional[Dataset[T_co]] = None,
    *,
    val_dataset: Optional[Dataset[T_co]] = None,
) -> None:
    """Run the complete algorithm/runner training loop.

    [`Trainer.fit`][mantisdk.Trainer.fit] wraps the algorithm bundle and the runner
    bundle as partials and passes both, together with the store, to the active
    [`ExecutionStrategy`][mantisdk.ExecutionStrategy]. The strategy rarely returns until
    one of the following happens:

    * The algorithm has exhausted the dataset(s) and no longer enqueues rollouts.
    * `max_rollouts` makes the individual runners exit.
    * An exception or interrupt cancels the shared [`ExecutionEvent`][mantisdk.ExecutionEvent].

    Args:
        agent: [`LitAgent`][mantisdk.LitAgent] implementation executed by runners.
        train_dataset: Optional iterable of rollout inputs consumed by the algorithm.
            A string is still accepted for backward compatibility: a warning is logged
            and the call is delegated to `fit_v0`.
        val_dataset: Optional iterable consumed by validation passes.
    """
    if isinstance(train_dataset, str):
        # Legacy call style: warn, then hand everything to the v0 entry point.
        logger.warning(
            "Trainer.fit will no longer accepts a string URL in future version. "
            "To continue using a string URL, please use Trainer.fit_v0 instead. "
            "See documentation for how to migrate to latest version: https://microsoft.github.io/mantisdk/stable/"
        )
        return self.fit_v0(agent, train_dataset, val_dataset)  # type: ignore

    agent.set_trainer(self)

    # Pre-bind everything except the arguments the strategy supplies when it calls the bundles.
    algorithm_entry = functools.partial(
        self._algorithm_bundle,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        algorithm=self.algorithm,
    )
    runner_entry = functools.partial(self._runner_bundle, agent=agent)

    self.strategy.execute(algorithm_entry, runner_entry, self.store)
|
|
487
|
+
|
|
488
|
+
def dev(
    self,
    agent: LitAgent[T_co],
    train_dataset: Optional[Dataset[T_co]] = None,
    *,
    val_dataset: Optional[Dataset[T_co]] = None,
) -> None:
    """Dry-run the infrastructure with a fast, synchronous algorithm.

    [`Trainer.dev`][mantisdk.Trainer.dev] follows the same path as
    [`fit()`][mantisdk.Trainer.fit] but only accepts an
    [`Algorithm`][mantisdk.Algorithm] that is also a
    [`FastAlgorithm`][mantisdk.FastAlgorithm]. The loop stays responsive for debugging
    while the same store, runners, hooks, and tracer plumbing are exercised.

    When the trainer has no algorithm configured, a default
    [`Baseline`][mantisdk.Baseline] algorithm is used instead.

    Args:
        agent: [`LitAgent`][mantisdk.LitAgent] implementation to execute.
        train_dataset: Optional iterable passed to the algorithm.
        val_dataset: Optional iterable passed to the algorithm.

    Raises:
        TypeError: If the configured algorithm does not inherit from `FastAlgorithm`.
    """
    agent.set_trainer(self)

    # Pick the algorithm (falling back to Baseline), then verify it is dev-compatible.
    algorithm = Baseline() if self.algorithm is None else self.algorithm
    if not isinstance(algorithm, FastAlgorithm):
        raise TypeError(
            "Trainer.dev() requires an algorithm that inherits from FastAlgorithm. "
            f"Received {type(algorithm).__name__}."
        )

    # Pre-bind everything except the arguments the strategy supplies when it calls the bundles.
    algorithm_entry = functools.partial(
        self._algorithm_bundle,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        algorithm=algorithm,
    )
    runner_entry = functools.partial(self._runner_bundle, agent=agent)
    self.strategy.execute(algorithm_entry, runner_entry, self.store)
|
|
534
|
+
|
|
535
|
+
async def _algorithm_bundle(
    self,
    store: LightningStore,
    event: ExecutionEvent,
    train_dataset: Optional[Dataset[T_co]],
    val_dataset: Optional[Dataset[T_co]],
    algorithm: Optional[Algorithm],
) -> None:
    """Strategy-invoked entry point for the algorithm role.

    With an algorithm present, the coroutine binds its dependencies (trainer, store,
    adapter, and -- when configured -- initial resources and the LLM proxy) and then
    calls [`Algorithm.run`][mantisdk.Algorithm.run], awaiting the result only when the
    algorithm reports itself as async.

    With `algorithm` set to `None`, nothing is bound; the coroutine just polls the shared
    `event` until it is set, so runners can keep executing (useful for manual queue
    seeding or external algorithms).

    Args:
        store: Store handed to the algorithm and, if present, to the LLM proxy.
        event: Shared execution event; only consulted when `algorithm` is `None`.
        train_dataset: Forwarded to `Algorithm.run`.
        val_dataset: Forwarded to `Algorithm.run`.
        algorithm: Algorithm to drive, or `None` to idle until shutdown.

    Raises:
        Exception: Any error raised by the algorithm's `run` is logged and re-raised.
    """
    if algorithm is None:
        # Nothing to run: idle until the shared event signals shutdown.
        while not event.is_set():
            await asyncio.sleep(0.1)
        return

    # Bind dependencies before running.
    algorithm.set_trainer(self)
    algorithm.set_store(store)
    algorithm.set_adapter(self.adapter)
    if self.initial_resources is not None:
        algorithm.set_initial_resources(self.initial_resources)
    if self.llm_proxy is not None:
        self.llm_proxy.set_store(store)
        algorithm.set_llm_proxy(self.llm_proxy)

    try:
        runs_async = algorithm.is_async()
        # For a synchronous algorithm this call blocks the event loop, which is
        # intentional to maximize the debugging experience; enabling async execution
        # is the responsibility of the execution strategy.
        outcome = algorithm.run(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
        )
        if runs_async:
            await outcome  # type: ignore
    except Exception:
        logger.exception("Algorithm bundle encountered an error.")
        raise
|
|
582
|
+
|
|
583
|
+
async def _runner_bundle(
    self, store: LightningStore, worker_id: int, event: ExecutionEvent, agent: LitAgent[T_co]
) -> None:
    """Strategy-invoked entry point for a single runner role.

    Takes the configured runner, initializes it with the agent and hooks, attaches the
    worker to the shared store, and awaits the runner's [`iter`][mantisdk.Runner.iter]
    loop. Cleanup then undoes only the setup steps that completed, worker teardown first
    and runner teardown second.

    Args:
        store: Store passed to `init_worker`.
        worker_id: Identifier of this worker; also included in log messages.
        event: Execution event forwarded to the runner's `iter`.
        agent: Agent bound to the runner through `init`.

    Raises:
        Exception: Errors from setup or from the `iter` loop are logged and re-raised.
            Errors raised during teardown are logged but not propagated.
    """
    active_runner: Runner[Any] | None = None
    init_done = False
    worker_init_done = False
    try:
        # Original author's note: if not using the shm execution strategy,
        # we are already in the forked process.
        active_runner = self.runner
        active_runner.init(agent=agent, hooks=self.hooks)
        init_done = True
        active_runner.init_worker(worker_id, store)
        worker_init_done = True
        await active_runner.iter(event=event)
    except Exception:
        logger.exception("Runner bundle encountered an error (worker_id=%s).", worker_id)
        raise
    finally:
        # Tear down in reverse order of setup, skipping steps that never completed.
        # Each teardown is guarded separately so one failure cannot skip the other.
        if active_runner is not None and worker_init_done:
            try:
                active_runner.teardown_worker(worker_id)
            except Exception:
                logger.exception("Error during runner worker teardown (worker_id=%s).", worker_id)
        if active_runner is not None and init_done:
            try:
                active_runner.teardown()
            except Exception:
                logger.exception("Error during runner teardown (worker_id=%s).", worker_id)
|