mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic. Click here for more details.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/types/core.py ADDED
@@ -0,0 +1,553 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ """Core data models shared across Mantisdk components."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import (
8
+ TYPE_CHECKING,
9
+ Any,
10
+ Callable,
11
+ Dict,
12
+ Generic,
13
+ Iterator,
14
+ List,
15
+ Literal,
16
+ Mapping,
17
+ Optional,
18
+ Protocol,
19
+ Sequence,
20
+ SupportsIndex,
21
+ TypedDict,
22
+ TypeVar,
23
+ Union,
24
+ cast,
25
+ overload,
26
+ )
27
+
28
+ from opentelemetry.sdk.trace import ReadableSpan
29
+ from pydantic import BaseModel, Field, model_validator
30
+
31
+ from .tracer import Span, SpanCoreFields
32
+
33
+ if TYPE_CHECKING:
34
+ from mantisdk.litagent import LitAgent
35
+ from mantisdk.runner.base import Runner
36
+ from mantisdk.tracer.base import Tracer
37
+
38
+ __all__ = [
39
+ "Triplet",
40
+ "RolloutLegacy",
41
+ "Task",
42
+ "TaskInput",
43
+ "TaskIfAny",
44
+ "RolloutRawResultLegacy",
45
+ "RolloutRawResult",
46
+ "RolloutMode",
47
+ "GenericResponse",
48
+ "ParallelWorkerBase",
49
+ "Dataset",
50
+ "AttemptStatus",
51
+ "RolloutStatus",
52
+ "RolloutConfig",
53
+ "Rollout",
54
+ "Attempt",
55
+ "AttemptedRollout",
56
+ "EnqueueRolloutRequest",
57
+ "Hook",
58
+ "Worker",
59
+ "WorkerStatus",
60
+ "PaginatedResult",
61
+ "FilterOptions",
62
+ "SortOptions",
63
+ "FilterField",
64
+ ]
65
+
66
# Covariant element type, used by the read-only ``Dataset`` protocol below.
T_co = TypeVar("T_co", covariant=True)
67
+
68
+
69
class Triplet(BaseModel):
    """One prompt/response turn recorded for reinforcement learning, with an optional reward."""

    prompt: Any
    """The input of this turn."""
    response: Any
    """The output produced for this turn."""
    reward: Optional[float] = None
    """Reward attributed to this turn, if any."""
    metadata: Dict[str, Any] = Field(default_factory=dict)
    """Free-form extra information about the turn (empty by default)."""
76
+
77
+
78
class RolloutLegacy(BaseModel):
    """Result payload reported to the deprecated HTTP server.

    !!! warning "Deprecated"
        Use [`Rollout`][mantisdk.Rollout] instead.
    """

    rollout_id: str
    """Identifier of the rollout being reported."""

    task: Optional[Task] = None
    """Echo of the input task, if the reporter includes it."""

    final_reward: Optional[float] = None
    """Primary, high-level feedback: the overall reward."""

    triplets: Optional[List[Triplet]] = None
    """Structured, sequential feedback for RL-style optimization."""

    trace: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="A list of spans that conform to the OpenTelemetry JSON format. "
        "Users of the opentelemetry-sdk can generate this by calling "
        "json.loads(readable_span.to_json()).",
    )
    """Optional rich-context span data for deep analysis."""

    logs: Optional[List[str]] = None
    """Optional log lines associated with the rollout."""

    metadata: Dict[str, Any] = Field(default_factory=dict)
    """Bucket for any other relevant information (empty by default)."""
107
+
108
+
109
RolloutStatus = Literal[
    # Initial status: created and waiting in the queue.
    "queuing",
    # The rollout has been claimed but no trace has arrived yet.
    "preparing",
    # The first trace has been received.
    "running",
    # The rollout crashed.
    "failed",
    # The rollout finished with an OK status.
    "succeeded",
    # Stopped by a user (or by a watchdog).
    "cancelled",
    # Queued again in order to retry.
    "requeuing",
]
"""Lifecycle status of a rollout."""
119
+
120
AttemptStatus = Literal[
    # An attempt status describes the execution process itself, so it deliberately
    # has no scheduling/management statuses such as "queuing" or "cancelled"
    # (those exist only on RolloutStatus).
    "preparing",
    "running",
    "failed",
    "succeeded",
    "unresponsive",  # the worker has not reported results for a while
    "timeout",  # the worker is still emitting new logs, but has been working on the task for too long
]
"""The status of an attempt."""
131
+
132
RolloutMode = Literal[
    "train",
    "val",
    "test",
]
"""Execution mode of a rollout: `"train"`, `"val"` or `"test"`."""
134
+
135
+
136
class Attempt(BaseModel):
    """One try at executing a rollout; retries of the same rollout produce further attempts."""

    rollout_id: str
    """Identifier of the rollout this attempt belongs to."""
    attempt_id: str
    """Universal identifier of this attempt."""
    sequence_id: int
    """Position of this attempt within its rollout, starting from 1."""
    start_time: float
    """Time at which the attempt started."""
    end_time: Optional[float] = None
    """Time at which the attempt ended, if it has ended."""
    status: AttemptStatus = "preparing"
    """Current status of the attempt (`"preparing"` by default)."""
    worker_id: Optional[str] = None
    """Identifier of the rollout worker executing this attempt, if assigned."""

    last_heartbeat_time: Optional[float] = None
    """Most recent time the worker reported progress (i.e., a span)."""

    metadata: Optional[Dict[str, Any]] = None
    """Bucket for any other relevant information."""
159
+
160
+
161
class RolloutConfig(BaseModel):
    """Retry and timeout policy applied to a rollout."""

    timeout_seconds: Optional[float] = None
    """Overall timeout for the rollout in seconds; `None` means no timeout."""
    unresponsive_seconds: Optional[float] = None
    """Unresponsiveness timeout for the rollout in seconds; `None` means no unresponsive timeout."""
    max_attempts: int = Field(default=1, ge=1)
    """Upper bound on attempts, counting the first one; must be at least 1."""
    retry_condition: List[AttemptStatus] = Field(default_factory=cast(Callable[[], List[AttemptStatus]], list))
    """Attempt statuses that should trigger a retry (empty by default)."""
172
+
173
+
174
class Rollout(BaseModel):
    """A task input queued for execution, tracked through its lifecycle.

    A rollout records the input to process, its start/end timestamps, its current
    `RolloutStatus`, and the retry/timeout policy held in `RolloutConfig`.
    Individual executions of a rollout are represented by `Attempt` objects, which
    point back to it through `rollout_id`.
    """

    rollout_id: str
    """Unique identifier for the rollout."""

    input: TaskInput
    """Task input used to generate the rollout."""

    # Time to track the lifecycle of the rollout
    start_time: float
    """Timestamp when the rollout started."""
    end_time: Optional[float] = None
    """Timestamp when the rollout ended."""

    mode: Optional[RolloutMode] = None
    """Execution mode such as `"train"`, `"val"` or `"test"`. See [`RolloutMode`][mantisdk.RolloutMode]."""
    resources_id: Optional[str] = None
    """Identifier of the resources required to execute the rollout."""

    status: RolloutStatus = "queuing"
    """Latest status emitted by the controller."""

    config: RolloutConfig = Field(default_factory=RolloutConfig)
    """Retry and timeout configuration associated with the rollout."""

    metadata: Optional[Dict[str, Any]] = None
    """Additional metadata attached to the rollout."""
200
+
201
+
202
class AttemptedRollout(Rollout):
    """A rollout bundled with the attempt that is currently executing it."""

    attempt: Attempt
    """The attempt now in charge of processing this rollout."""

    @model_validator(mode="after")
    def check_consistency(self) -> AttemptedRollout:
        """Ensure the attached attempt refers back to this very rollout.

        Raises:
            ValueError: If ``attempt.rollout_id`` differs from ``rollout_id``
                (pydantic surfaces this as a validation error).
        """
        if self.attempt.rollout_id == self.rollout_id:
            return self
        raise ValueError("Inconsistent rollout_id between Rollout and Attempt")
213
+
214
+
215
class EnqueueRolloutRequest(BaseModel):
    """Request body for queuing a new rollout via [`enqueue_rollout`][mantisdk.LightningStore.enqueue_rollout].

    It carries only the subset of [`Rollout`][mantisdk.Rollout] fields that a caller supplies when queuing.
    """

    input: TaskInput
    """Payload the rollout will be generated from."""
    mode: Optional[RolloutMode] = None
    """Optional execution mode (`"train"`, `"val"` or `"test"`). See [`RolloutMode`][mantisdk.RolloutMode]."""
    resources_id: Optional[str] = None
    """Optional identifier of the resources required to execute the rollout."""
    config: Optional[RolloutConfig] = None
    """Optional retry and timeout policy for the rollout."""
    metadata: Optional[Dict[str, Any]] = None
    """Optional extra metadata to attach to the rollout."""
231
+
232
+
233
WorkerStatus = Literal["idle", "busy", "unknown"]
"""The status of a worker, as stored in `Worker.status` (`"unknown"` by default there)."""
234
+
235
+
236
class Worker(BaseModel):
    """Bookkeeping record for a rollout worker (equivalent to the runner's info)."""

    worker_id: str
    """Identifier of the worker."""
    status: WorkerStatus = "unknown"
    """Current status of the worker (defaults to `"unknown"`)."""
    heartbeat_stats: Optional[Dict[str, Any]] = None
    """Statistics reported with the worker's heartbeat."""
    last_heartbeat_time: Optional[float] = None
    """Most recent time the worker reported its stats."""
    last_dequeue_time: Optional[float] = None
    """Most recent time the worker tried to dequeue a rollout."""
    last_busy_time: Optional[float] = None
    """Most recent time the worker started an attempt and became busy."""
    last_idle_time: Optional[float] = None
    """Most recent time the worker triggered the end of an attempt and became idle."""
    current_rollout_id: Optional[str] = None
    """Rollout the worker is currently processing, if any."""
    current_attempt_id: Optional[str] = None
    """Attempt the worker is currently processing, if any."""
257
+
258
+
259
# Task payloads are intentionally untyped: any object may serve as a task input.
TaskInput = Any
"""Type of a task's input; arbitrary payloads are accepted."""
261
+
262
+
263
class Task(BaseModel):
    """A rollout request served to client agents.

    !!! warning "Deprecated"
        Only the legacy HTTP client/server stack still relies on this model. New
        workflows should use the [`LightningStore`][mantisdk.LightningStore] APIs.
    """

    rollout_id: str
    """Identifier of the rollout this task represents."""
    input: TaskInput
    """Arbitrary payload describing the work to do."""

    mode: Optional[RolloutMode] = None
    """Execution mode, if specified."""
    resources_id: Optional[str] = None
    """Identifier of the resources to use, if specified."""

    # Optional lifecycle-tracking fields.
    create_time: Optional[float] = None
    last_claim_time: Optional[float] = None
    num_claims: Optional[int] = None

    # Free-form additional metadata (empty by default).
    metadata: Dict[str, Any] = Field(default_factory=dict)
284
+
285
+
286
class TaskIfAny(BaseModel):
    """Either a task, or an indication that no task is currently available.

    !!! warning "Deprecated"
        Prefer the [`LightningStore`][mantisdk.LightningStore] APIs for new workflows.
    """

    is_available: bool
    """Whether a task is available."""
    task: Optional[Task] = None
    """The available task, if any (defaults to `None`)."""
296
+
297
+
298
RolloutRawResultLegacy = Union[
    None,
    float,
    List[Triplet],
    List[Dict[str, Any]],
    List[ReadableSpan],
    RolloutLegacy,
]
"""Legacy rollout result type: every return value the deprecated rollout API accepts.

!!! warning "Deprecated"
    Use [`RolloutRawResult`][mantisdk.RolloutRawResult] instead.
"""
304
+
305
RolloutRawResult = Union[
    # Return nothing and rely on the tracer.
    None,
    # Return only the final reward.
    float,
    # Return OTEL spans constructed by the user.
    List[ReadableSpan],
    # Return `Span` objects constructed by the user.
    List[Span],
    # Return `SpanCoreFields` objects constructed by the user.
    List[SpanCoreFields],
]
"""Rollout result type.

The accepted return values of [`rollout`][mantisdk.LitAgent.rollout].
"""
316
+
317
+
318
class GenericResponse(BaseModel):
    """Catch-all server response used by compatibility endpoints.

    !!! warning "Deprecated"
        The newer [`LightningStore`][mantisdk.LightningStore] APIs no longer use this response.

    Attributes:
        status: Short string describing the outcome of the request; defaults to `"success"`.
        message: Optional human-readable explanation.
        data: Optional arbitrary payload serialized as JSON.
    """

    status: str = "success"
    message: Optional[str] = None
    data: Optional[Dict[str, Any]] = None
334
+
335
+
336
class ParallelWorkerBase:
    """Base class for workloads executed across multiple worker processes.

    The lifecycle is orchestrated by the main process:

    * [`init()`][mantisdk.ParallelWorkerBase.init] prepares shared state.
    * Each worker calls [`init_worker()`][mantisdk.ParallelWorkerBase.init_worker] during start-up.
    * [`run()`][mantisdk.ParallelWorkerBase.run] performs the parallel workload.
    * Workers call [`teardown_worker()`][mantisdk.ParallelWorkerBase.teardown_worker] before exiting.
    * The main process finalizes through [`teardown()`][mantisdk.ParallelWorkerBase.teardown].

    Subclasses are expected to override [`run()`][mantisdk.ParallelWorkerBase.run] with the
    actual workload; the base implementation is a no-op that returns `None`. All other
    lifecycle hooks are optional overrides that do nothing by default.
    """

    def __init__(self) -> None:
        """Initialize the base class with no worker id assigned yet.

        Subclasses overriding this should call ``super().__init__()`` so that
        ``worker_id`` exists before ``init_worker`` runs.
        """
        self.worker_id: Optional[int] = None

    def init(self, *args: Any, **kwargs: Any) -> None:
        """Prepare shared state in the main process before the workers are spawned.

        Optional override; the base implementation does nothing.
        """
        pass

    def init_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Initialize a worker and record its ``worker_id``.

        Overrides should call ``super().init_worker(worker_id, ...)`` (or set
        ``self.worker_id`` themselves) so the id remains available.
        """
        self.worker_id = worker_id

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Run the workload.

        Subclasses are expected to override this; the base implementation is a
        no-op that returns ``None``.
        """
        pass

    def teardown_worker(self, worker_id: int, *args: Any, **kwargs: Any) -> None:
        """Clean up a worker before it exits. Optional override; does nothing by default."""
        pass

    def teardown(self, *args: Any, **kwargs: Any) -> None:
        """Finalize in the main process after the workers have exited. Optional override; does nothing by default."""
        pass
374
+
375
+
376
class Dataset(Protocol[T_co]):
    """Structural interface for an indexable, sized collection of samples.

    The interface is similar to `torch.utils.data.Dataset`. Because this is a protocol,
    no inheritance is needed: any object providing `__getitem__` and `__len__` (a plain
    list, for example) can be used.
    """

    def __getitem__(self, index: SupportsIndex, /) -> T_co: ...

    def __len__(self) -> int: ...
386
+
387
+
388
class Hook(ParallelWorkerBase):
    """Callbacks invoked at points of the agent runner's lifecycle.

    Every callback is an async no-op by default; subclasses override only the ones they need.
    """

    async def on_trace_start(
        self, *, agent: LitAgent[Any], runner: Runner[Any], tracer: Tracer, rollout: Rollout
    ) -> None:
        """Called right after the tracer enters the trace context, before the rollout begins.

        Args:
            agent: The [`LitAgent`][mantisdk.LitAgent] attached to the runner.
            runner: The [`Runner`][mantisdk.Runner] driving the rollout.
            tracer: The [`Tracer`][mantisdk.Tracer] attached to the runner.
            rollout: The [`Rollout`][mantisdk.Rollout] about to be processed.

        Override to add logging, metric collection, or resource setup. Does nothing by default.
        """

    async def on_trace_end(
        self, *, agent: LitAgent[Any], runner: Runner[Any], tracer: Tracer, rollout: Rollout
    ) -> None:
        """Called right after the rollout completes, before the tracer exits the trace context.

        Args:
            agent: The [`LitAgent`][mantisdk.LitAgent] attached to the runner.
            runner: The [`Runner`][mantisdk.Runner] driving the rollout.
            tracer: The [`Tracer`][mantisdk.Tracer] attached to the runner.
            rollout: The [`Rollout`][mantisdk.Rollout] that was just processed.

        Override to add logging, metric collection, or resource cleanup. Does nothing by default.
        """

    async def on_rollout_start(self, *, agent: LitAgent[Any], runner: Runner[Any], rollout: Rollout) -> None:
        """Called just before a rollout *attempt* begins.

        Args:
            agent: The [`LitAgent`][mantisdk.LitAgent] attached to the runner.
            runner: The [`Runner`][mantisdk.Runner] driving the rollout.
            rollout: The [`Rollout`][mantisdk.Rollout] about to be processed.

        Override to add logging, metric collection, or resource setup. Does nothing by default.
        """

    async def on_rollout_end(
        self,
        *,
        agent: LitAgent[Any],
        runner: Runner[Any],
        rollout: Rollout,
        spans: Union[List[ReadableSpan], List[Span]],
    ) -> None:
        """Called after a rollout *attempt* completes.

        Args:
            agent: The [`LitAgent`][mantisdk.LitAgent] attached to the runner.
            runner: The [`Runner`][mantisdk.Runner] driving the rollout.
            rollout: The [`Rollout`][mantisdk.Rollout] that was just processed.
            spans: The spans that have been added to the store.

        Override for cleanup or additional logging. Does nothing by default.
        """
453
+
454
+
455
class FilterField(TypedDict, total=False):
    """Operators that can be applied to a single field; every key is optional."""

    exact: Any  # match values equal to this one
    within: Sequence[Any]  # match values that appear in this sequence
    contains: str  # match string values containing this substring
461
+
462
+
463
FilterOptions = Mapping[
    Union[str, Literal["_aggregate", "_must"]],
    Union[FilterField, Literal["and", "or"], Mapping[str, FilterField]],
]
"""A mapping of field name -> operator dict.

Each operator dict can contain:

- "exact": value for exact equality.
- "within": sequence of allowed values.
- "contains": substring to search for in string fields.

The filter can also have a special field called "_aggregate" that specifies how
to combine the results of the individual conditions:

- "and": all conditions must match. This is the default value if not specified.
- "or": at least one condition must match.

All conditions within a field and between different fields are
stored in a unified pool and combined using `_aggregate`.

The filter can also have a special group called "_must", which is a mapping of filters that must all match,
regardless of whether the aggregate logic is "and" or "or".

Example:

```json
{
    "_aggregate": "or",
    "_must": {
        "city": {"exact": "New York"},
        "timezone": {"within": ["America/New_York", "America/Los_Angeles"]}
    },
    "status": {"exact": "active"},
    "id": {"within": [1, 2, 3]},
    "name": {"contains": "foo"}
}
```

Here a record must have city "New York" and one of the two listed timezones
(the `_must` group). It must also satisfy at least one of the `status`, `id`
or `name` conditions, because `_aggregate` is "or".
"""
502
+
503
+
504
class SortOptions(TypedDict):
    """How the items of a collection should be ordered."""

    name: str
    """Field to sort on."""
    order: Literal["asc", "desc"]
    """Sort direction: ascending (`"asc"`) or descending (`"desc"`)."""
511
+
512
+
513
# Element type carried by ``PaginatedResult``.
T_item = TypeVar("T_item")
514
+
515
+
516
class PaginatedResult(BaseModel, Sequence[T_item]):
    """One page of query results.

    Acts as a read-only sequence over ``items`` while also carrying the paging
    parameters (``limit``, ``offset``) and the size of the whole collection (``total``).
    """

    items: Sequence[T_item]
    """The items on this page."""
    limit: int
    """Limit of the result (``-1`` is shown as an open-ended slice in the repr)."""
    offset: int
    """Offset of the first item of this page within the collection."""
    total: int
    """Number of items in the whole collection, not just this page."""

    def __len__(self) -> int:
        """Return the number of items on this page."""
        return len(self.items)

    @overload
    def __getitem__(self, index: int) -> T_item: ...

    @overload
    def __getitem__(self, index: slice) -> Sequence[T_item]: ...

    def __getitem__(self, index: Union[int, slice]) -> Union[T_item, Sequence[T_item]]:
        """Index or slice into the items on this page."""
        page = self.items
        return page[index]

    # Iterating yields the page items (so ``list(result)`` works). This replaces
    # pydantic's default iteration, which would otherwise yield the model's fields.
    def __iter__(self) -> Iterator[T_item]:  # type: ignore
        return iter(self.items)

    def __repr__(self) -> str:
        count = len(self.items)
        if count == 0:
            preview = "empty"
        elif count == 1:
            preview = repr(self.items[0])
        else:
            preview = f"[{self.items[0]!r}, ...]"
        # A limit of -1 renders as an open-ended "offset:" slice.
        stop = "" if self.limit == -1 else str(self.offset + self.limit)
        return f"<PaginatedResult ({self.offset}:{stop} of {self.total}) {preview}>"