mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/types/core.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
"""Core data models shared across Mantisdk components."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import (
|
|
8
|
+
TYPE_CHECKING,
|
|
9
|
+
Any,
|
|
10
|
+
Callable,
|
|
11
|
+
Dict,
|
|
12
|
+
Generic,
|
|
13
|
+
Iterator,
|
|
14
|
+
List,
|
|
15
|
+
Literal,
|
|
16
|
+
Mapping,
|
|
17
|
+
Optional,
|
|
18
|
+
Protocol,
|
|
19
|
+
Sequence,
|
|
20
|
+
SupportsIndex,
|
|
21
|
+
TypedDict,
|
|
22
|
+
TypeVar,
|
|
23
|
+
Union,
|
|
24
|
+
cast,
|
|
25
|
+
overload,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
from opentelemetry.sdk.trace import ReadableSpan
|
|
29
|
+
from pydantic import BaseModel, Field, model_validator
|
|
30
|
+
|
|
31
|
+
from .tracer import Span, SpanCoreFields
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from mantisdk.litagent import LitAgent
|
|
35
|
+
from mantisdk.runner.base import Runner
|
|
36
|
+
from mantisdk.tracer.base import Tracer
|
|
37
|
+
|
|
38
|
+
__all__ = [
|
|
39
|
+
"Triplet",
|
|
40
|
+
"RolloutLegacy",
|
|
41
|
+
"Task",
|
|
42
|
+
"TaskInput",
|
|
43
|
+
"TaskIfAny",
|
|
44
|
+
"RolloutRawResultLegacy",
|
|
45
|
+
"RolloutRawResult",
|
|
46
|
+
"RolloutMode",
|
|
47
|
+
"GenericResponse",
|
|
48
|
+
"ParallelWorkerBase",
|
|
49
|
+
"Dataset",
|
|
50
|
+
"AttemptStatus",
|
|
51
|
+
"RolloutStatus",
|
|
52
|
+
"RolloutConfig",
|
|
53
|
+
"Rollout",
|
|
54
|
+
"Attempt",
|
|
55
|
+
"AttemptedRollout",
|
|
56
|
+
"EnqueueRolloutRequest",
|
|
57
|
+
"Hook",
|
|
58
|
+
"Worker",
|
|
59
|
+
"WorkerStatus",
|
|
60
|
+
"PaginatedResult",
|
|
61
|
+
"FilterOptions",
|
|
62
|
+
"SortOptions",
|
|
63
|
+
"FilterField",
|
|
64
|
+
]
|
|
65
|
+
|
|
66
|
+
# Covariant type variable; parameterizes the item type of the `Dataset` protocol.
T_co = TypeVar("T_co", covariant=True)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class Triplet(BaseModel):
    """One prompt/response turn recorded for reinforcement learning, with an optional reward."""

    prompt: Any
    """Prompt of the turn; the type is not constrained."""
    response: Any
    """Response of the turn; the type is not constrained."""
    reward: Optional[float] = None
    """Reward attached to the turn, or `None` when no reward is set."""
    metadata: Dict[str, Any] = Field(default_factory=dict)
    """Free-form extra information; defaults to a fresh empty dict."""
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class RolloutLegacy(BaseModel):
    """Reporting payload used by the deprecated HTTP server.

    !!! warning "Deprecated"
        Use [`Rollout`][mantisdk.Rollout] instead.
    """

    rollout_id: str
    """Identifier of the rollout being reported."""

    task: Optional[Task] = None
    """Echo of the input task, if provided."""

    final_reward: Optional[float] = None
    """Primary, high-level feedback for the rollout."""

    triplets: Optional[List[Triplet]] = None
    """Structured, sequential feedback intended for RL-style optimization."""

    # Optional rich-context data for deeper analysis.
    trace: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description=(
            "A list of spans that conform to the OpenTelemetry JSON format. "
            "Users of the opentelemetry-sdk can generate this by calling "
            "json.loads(readable_span.to_json())."
        ),
    )
    logs: Optional[List[str]] = None
    """Optional list of log lines."""

    metadata: Dict[str, Any] = Field(default_factory=dict)
    """Catch-all bucket for any other relevant information."""
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
RolloutStatus = Literal[
    "queuing",  # initial status
    "preparing",  # set after the trace is claimed
    "running",  # set after the first trace is received
    "failed",  # crashed
    "succeeded",  # finished with status OK
    "cancelled",  # cancelled by the user (or a watchdog)
    "requeuing",  # waiting to be retried
]
"""The status of a rollout."""

AttemptStatus = Literal[
    # These values describe the execution itself, so scheduling/management
    # statuses such as "queuing" or "cancelled" are intentionally not included.
    "preparing",
    "running",
    "failed",
    "succeeded",
    "unresponsive",  # the worker has not reported results for a while
    "timeout",  # the worker is still emitting new logs but has been working on the task for too long
]
"""The status of an attempt."""

RolloutMode = Literal["train", "val", "test"]
"""Possible rollout modes."""
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class Attempt(BaseModel):
    """A single execution attempt of a rollout, carrying the bookkeeping needed for retries."""

    rollout_id: str
    """Identifier of the rollout this attempt belongs to."""
    attempt_id: str
    """Universal identifier of this attempt."""
    sequence_id: int
    """Sequence number of the attempt; numbering starts from 1."""
    start_time: float
    """Time at which the attempt started."""
    end_time: Optional[float] = None
    """Time at which the attempt ended; `None` until an end time is recorded."""
    status: AttemptStatus = "preparing"
    """Current status of the attempt."""
    worker_id: Optional[str] = None
    """Rollout worker executing this attempt, if any."""

    last_heartbeat_time: Optional[float] = None
    """Last time the worker reported progress (i.e., a span)."""

    metadata: Optional[Dict[str, Any]] = None
    """Catch-all bucket for any other relevant information."""
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class RolloutConfig(BaseModel):
    """Retry and timeout settings that apply to a rollout."""

    timeout_seconds: Optional[float] = None
    """Rollout timeout in seconds; `None` means no timeout."""
    unresponsive_seconds: Optional[float] = None
    """Unresponsive timeout in seconds; `None` means no unresponsive timeout."""
    max_attempts: int = Field(default=1, ge=1)
    """Maximum number of attempts, counting the first one; must be at least 1."""
    # The cast only narrows the static type of the `list` factory; at runtime the default is a new empty list.
    retry_condition: List[AttemptStatus] = Field(default_factory=cast(Callable[[], List[AttemptStatus]], list))
    """Attempt statuses that should trigger a retry."""
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class Rollout(BaseModel):
    """Rollout record: a task input together with its lifecycle timestamps, mode, resources,
    status, retry/timeout configuration and metadata."""

    rollout_id: str
    """Unique identifier for the rollout."""

    input: TaskInput
    """Task input used to generate the rollout."""

    # Timestamps tracking the lifecycle of the rollout.
    start_time: float
    """Timestamp when the rollout started."""
    end_time: Optional[float] = None
    """Timestamp when the rollout ended; `None` until an end time is recorded."""

    mode: Optional[RolloutMode] = None
    """Execution mode such as `"train"`, `"val"` or `"test"`. See [`RolloutMode`][mantisdk.RolloutMode]."""
    resources_id: Optional[str] = None
    """Identifier of the resources required to execute the rollout."""

    status: RolloutStatus = "queuing"
    """Latest status emitted by the controller."""

    config: RolloutConfig = Field(default_factory=RolloutConfig)
    """Retry and timeout configuration associated with the rollout; defaults to a fresh `RolloutConfig()`."""

    metadata: Optional[Dict[str, Any]] = None
    """Additional metadata attached to the rollout."""
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class AttemptedRollout(Rollout):
    """A rollout bundled with the attempt that is currently active for it."""

    attempt: Attempt
    """The attempt that is currently processing the rollout."""

    @model_validator(mode="after")
    def check_consistency(self) -> AttemptedRollout:
        """Ensure the embedded attempt refers to this rollout.

        Returns:
            The validated instance itself.

        Raises:
            ValueError: If `attempt.rollout_id` differs from `rollout_id`.
        """
        ids_match = self.attempt.rollout_id == self.rollout_id
        if ids_match:
            return self
        raise ValueError("Inconsistent rollout_id between Rollout and Attempt")
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class EnqueueRolloutRequest(BaseModel):
    """Request payload for queuing a rollout through [`enqueue_rollout`][mantisdk.LightningStore.enqueue_rollout].

    Carries the subset of [`Rollout`][mantisdk.Rollout] fields needed to queue a new rollout.
    """

    input: TaskInput
    """Task input used to generate the rollout."""
    mode: Optional[RolloutMode] = None
    """Execution mode such as `"train"`, `"val"` or `"test"`. See [`RolloutMode`][mantisdk.RolloutMode]."""
    resources_id: Optional[str] = None
    """Identifier of the resources required to execute the rollout."""
    config: Optional[RolloutConfig] = None
    """Retry and timeout configuration for the rollout; `None` when not specified."""
    metadata: Optional[Dict[str, Any]] = None
    """Additional metadata attached to the rollout."""
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# Allowed values for `Worker.status`.
WorkerStatus = Literal["idle", "busy", "unknown"]
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class Worker(BaseModel):
    """Information about a worker. This is effectively the same as Runner info."""

    worker_id: str
    """Identifier of the worker."""
    status: WorkerStatus = "unknown"
    """Current status of the worker."""
    heartbeat_stats: Optional[Dict[str, Any]] = None
    """Statistics about the worker's heartbeat."""
    last_heartbeat_time: Optional[float] = None
    """Last time the worker reported its stats."""
    last_dequeue_time: Optional[float] = None
    """Last time the worker tried to dequeue a rollout."""
    last_busy_time: Optional[float] = None
    """Last time the worker started an attempt and became busy."""
    last_idle_time: Optional[float] = None
    """Last time the worker triggered the end of an attempt and became idle."""
    current_rollout_id: Optional[str] = None
    """Identifier of the rollout the worker is currently processing."""
    current_attempt_id: Optional[str] = None
    """Identifier of the attempt the worker is currently processing."""
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
TaskInput = Any
"""Alias for the task input type; any payload is accepted."""
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class Task(BaseModel):
    """A rollout request handed out to client agents.

    !!! warning "Deprecated"
        This model is still used by the legacy HTTP client/server stack. Prefer
        [`LightningStore`][mantisdk.LightningStore] APIs for new workflows.
    """

    rollout_id: str
    """Identifier of the rollout this task belongs to."""
    input: TaskInput
    """Task input payload."""

    mode: Optional[RolloutMode] = None
    """Optional rollout mode."""
    resources_id: Optional[str] = None
    """Optional identifier of the associated resources."""

    # Optional fields for tracking the task lifecycle.
    create_time: Optional[float] = None
    last_claim_time: Optional[float] = None
    num_claims: Optional[int] = None

    metadata: Dict[str, Any] = Field(default_factory=dict)
    """Additional metadata; defaults to a fresh empty dict."""
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class TaskIfAny(BaseModel):
    """Either a task, or a signal that no task is currently available.

    !!! warning "Deprecated"
        Use [`LightningStore`][mantisdk.LightningStore] APIs for new workflows.
    """

    is_available: bool
    """Whether a task is available."""
    task: Optional[Task] = None
    """The task, if one is provided; defaults to `None`."""
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
RolloutRawResultLegacy = Union[
    None,
    float,
    List[Triplet],
    List[Dict[str, Any]],
    List[ReadableSpan],
    RolloutLegacy,
]
"""Legacy type of a rollout result.

!!! warning "Deprecated"
    Use [`RolloutRawResult`][mantisdk.RolloutRawResult] instead.
"""

RolloutRawResult = Union[
    None,  # no explicit result (relies on the tracer)
    float,  # final reward only
    List[ReadableSpan],  # OTEL spans constructed by the user
    List[Span],  # Span objects constructed by the user
    List[SpanCoreFields],  # SpanCoreFields objects constructed by the user
]
"""Type of a rollout result.

These are the possible return values of [`rollout`][mantisdk.LitAgent.rollout].
"""
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
class GenericResponse(BaseModel):
    """Generic server response returned by compatibility endpoints.

    !!! warning "Deprecated"
        The new [`LightningStore`][mantisdk.LightningStore] APIs no longer
        use this response.

    Attributes:
        status: Status string describing the outcome of the request; defaults to `"success"`.
        message: Optional human-readable explanation.
        data: Arbitrary payload serialized as JSON.
    """

    status: str = "success"
    message: Optional[str] = None
    data: Optional[Dict[str, Any]] = None
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class ParallelWorkerBase:
    """Base class for workloads that run across multiple worker processes.

    The main process drives the lifecycle in this order:

    * [`init()`][mantisdk.ParallelWorkerBase.init] prepares shared state.
    * Each worker calls [`init_worker()`][mantisdk.ParallelWorkerBase.init_worker] during start-up.
    * [`run()`][mantisdk.ParallelWorkerBase.run] performs the parallel workload.
    * Workers call [`teardown_worker()`][mantisdk.ParallelWorkerBase.teardown_worker] before exiting.
    * The main process finalizes through [`teardown()`][mantisdk.ParallelWorkerBase.teardown].

    Subclasses must implement [`run()`][mantisdk.ParallelWorkerBase.run]
    and can override other lifecycle hooks.
    """

    def __init__(self) -> None:
        """Set `worker_id` to `None`. Subclasses may override this method."""
        self.worker_id: Optional[int] = None

    def init(self, *args: Any, **kwargs: Any) -> None:
        """Initialize before the workers are spawned. No-op by default; subclasses may override."""

    def init_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Initialize the worker by recording `worker_id` on the instance. Subclasses may override.

        Args:
            worker_id: Identifier stored as `self.worker_id`.
        """
        self.worker_id = worker_id

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Run the workload. The base implementation does nothing and returns `None`."""
        return None

    def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Tear down the worker. No-op by default; subclasses may override."""

    def teardown(self, *args: Any, **kwargs: Any) -> None:
        """Tear down after the workers have exited. No-op by default; subclasses may override."""
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
class Dataset(Protocol, Generic[T_co]):
    """General dataset interface.

    Defined as a protocol whose interface resembles `torch.utils.data.Dataset`.
    Inheriting from this class is not required; a plain list works as well.
    """

    def __getitem__(self, index: SupportsIndex, /) -> T_co:
        """Return the item at `index`."""
        ...

    def __len__(self) -> int:
        """Return the number of items."""
        ...
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
class Hook(ParallelWorkerBase):
    """Base class for hooks into the agent runner's lifecycle.

    Every hook method defined here is a no-op; subclasses override the ones they need.
    """

    async def on_trace_start(
        self,
        *,
        agent: LitAgent[Any],
        runner: Runner[Any],
        tracer: Tracer,
        rollout: Rollout,
    ) -> None:
        """Called right after the tracer enters the trace context, before the rollout begins.

        Override to add custom logic such as logging, metric collection,
        or resource setup. The default implementation does nothing.

        Args:
            agent: The [`LitAgent`][mantisdk.LitAgent] instance associated with the runner.
            runner: The [`Runner`][mantisdk.Runner] managing the rollout.
            tracer: The [`Tracer`][mantisdk.Tracer] instance associated with the runner.
            rollout: The [`Rollout`][mantisdk.Rollout] object that will be processed.
        """

    async def on_trace_end(
        self,
        *,
        agent: LitAgent[Any],
        runner: Runner[Any],
        tracer: Tracer,
        rollout: Rollout,
    ) -> None:
        """Called right after the rollout completes, before the tracer exits the trace context.

        Override to add custom logic such as logging, metric collection,
        or resource cleanup. The default implementation does nothing.

        Args:
            agent: The [`LitAgent`][mantisdk.LitAgent] instance associated with the runner.
            runner: The [`Runner`][mantisdk.Runner] managing the rollout.
            tracer: The [`Tracer`][mantisdk.Tracer] instance associated with the runner.
            rollout: The [`Rollout`][mantisdk.Rollout] object that has been processed.
        """

    async def on_rollout_start(
        self,
        *,
        agent: LitAgent[Any],
        runner: Runner[Any],
        rollout: Rollout,
    ) -> None:
        """Called immediately before a rollout *attempt* begins.

        Override to add custom logic such as logging, metric collection,
        or resource setup. The default implementation does nothing.

        Args:
            agent: The [`LitAgent`][mantisdk.LitAgent] instance associated with the runner.
            runner: The [`Runner`][mantisdk.Runner] managing the rollout.
            rollout: The [`Rollout`][mantisdk.Rollout] object that will be processed.
        """

    async def on_rollout_end(
        self,
        *,
        agent: LitAgent[Any],
        runner: Runner[Any],
        rollout: Rollout,
        spans: Union[List[ReadableSpan], List[Span]],
    ) -> None:
        """Called after a rollout *attempt* completes.

        Override for cleanup or additional logging. The default
        implementation does nothing.

        Args:
            agent: The [`LitAgent`][mantisdk.LitAgent] instance associated with the runner.
            runner: The [`Runner`][mantisdk.Runner] managing the rollout.
            rollout: The [`Rollout`][mantisdk.Rollout] object that has been processed.
            spans: The spans that have been added to the store.
        """
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
class FilterField(TypedDict, total=False):
    """Operator dict applied to a single field; every key is optional."""

    exact: Any
    """Value for the "exact" operator."""
    within: Sequence[Any]
    """Sequence of values for the "within" operator."""
    contains: str
    """String for the "contains" operator."""
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
FilterOptions = Mapping[
    Union[str, Literal["_aggregate", "_must"]],
    Union[FilterField, Literal["and", "or"], Mapping[str, FilterField]],
]
"""Mapping from field name to operator dict.

An operator dict may contain any of:

- "exact": value for exact equality.
- "within": iterable of allowed values.
- "contains": substring to search for in string fields.

Two special keys are also recognized.

"_aggregate" selects the logic used to combine the results of the filters:

- "and": all conditions must match. This is the default when the key is not specified.
- "or": at least one condition must match.

All conditions, both within a field and across different fields, are
stored in a unified pool and combined using `_aggregate`.

"_must" is a mapping of filters that must all match, regardless of whether
the aggregate logic is "and" or "or".

Example:

```json
{
    "_aggregate": "or",
    "_must": {
        "city": {"exact": "New York"},
        "timezone": {"within": ["America/New_York", "America/Los_Angeles"]}
    },
    "status": {"exact": "active"},
    "id": {"within": [1, 2, 3]},
    "name": {"contains": "foo"}
}
```
"""
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
class SortOptions(TypedDict):
    """Sorting options for a collection."""

    name: str
    """Name of the field to sort by."""
    order: Literal["asc", "desc"]
    """Sort order: `"asc"` or `"desc"`."""
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
# Type variable for the item type held by `PaginatedResult`.
T_item = TypeVar("T_item")
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
class PaginatedResult(BaseModel, Sequence[T_item]):
    """One page of query results plus its pagination metadata.

    Acts as a sequence over `items` (length, indexing, slicing, iteration)
    while also exposing `limit`, `offset` and `total`.
    """

    items: Sequence[T_item]
    """Items in the result."""
    limit: int
    """Limit of the result. `__repr__` renders a limit of `-1` as an open-ended range."""
    offset: int
    """Offset of the result."""
    total: int
    """Total number of items in the collection."""

    def __len__(self) -> int:
        """Return the number of items in this result."""
        page = self.items
        return len(page)

    @overload
    def __getitem__(self, index: int) -> T_item: ...

    @overload
    def __getitem__(self, index: slice) -> Sequence[T_item]: ...

    def __getitem__(self, index: Union[int, slice]) -> Union[T_item, Sequence[T_item]]:
        """Index or slice `items` directly."""
        return self.items[index]

    # Defining __iter__ makes list(paginated_result) yield the items. It also
    # replaces Pydantic's default iteration behavior, so iterating an instance
    # no longer goes over its fields.
    def __iter__(self) -> Iterator[T_item]:  # type: ignore
        return iter(self.items)

    def __repr__(self) -> str:
        """Return a compact summary: the offset window, the total, and a preview of the first item."""
        if self.items:
            preview = repr(self.items[0])
        else:
            preview = "empty"
        # Only the first item is shown; an ellipsis marks that more follow.
        if len(self.items) > 1:
            preview = f"[{preview}, ...]"

        if self.limit == -1:
            window = f"{self.offset}:"
        else:
            window = f"{self.offset}:{self.offset + self.limit}"
        return f"<PaginatedResult ({window} of {self.total}) {preview}>"
|