mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/store/client_server.py @@ -0,0 +1,2092 @@
# Copyright (c) Microsoft. All rights reserved.

from __future__ import annotations

import asyncio
import logging
import os
import re
import threading
import time
import traceback
from pathlib import Path
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import aiohttp
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException
from fastapi import Query as FastAPIQuery
from fastapi import Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceRequest as PbExportTraceServiceRequest,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse as PbExportTraceServiceResponse,
)
from opentelemetry.sdk.trace import ReadableSpan
from pydantic import BaseModel, Field, TypeAdapter

from mantisdk.types import (
    Attempt,
    AttemptedRollout,
    AttemptStatus,
    EnqueueRolloutRequest,
    NamedResources,
    PaginatedResult,
    ResourcesUpdate,
    Rollout,
    RolloutConfig,
    RolloutStatus,
    Span,
    TaskInput,
    Worker,
    WorkerStatus,
)
from mantisdk.utils.metrics import MetricsBackend, get_prometheus_registry
from mantisdk.utils.otlp import handle_otlp_export, spans_from_proto
from mantisdk.utils.server_launcher import LaunchMode, PythonServerLauncher, PythonServerLauncherArgs

from .base import UNSET, LightningStore, LightningStoreCapabilities, LightningStoreStatistics, Unset
from .collection.base import resolve_error_type
from .utils import LATENCY_BUCKETS

server_logger = logging.getLogger("mantisdk.store.server")
client_logger = logging.getLogger("mantisdk.store.client")

API_V1_PREFIX = "/v1"
API_AGL_PREFIX = "/msk"
API_V1_AGL_PREFIX = API_V1_PREFIX + API_AGL_PREFIX

T = TypeVar("T")
T_model = TypeVar("T_model", bound=BaseModel)


class RolloutRequest(BaseModel):
    input: TaskInput
    mode: Optional[Literal["train", "val", "test"]] = None
    resources_id: Optional[str] = None
    config: Optional[RolloutConfig] = None
    metadata: Optional[Dict[str, Any]] = None
    worker_id: Optional[str] = None


class DequeueRolloutRequest(BaseModel):
    worker_id: Optional[str] = None


class StartAttemptRequest(BaseModel):
    worker_id: Optional[str] = None


class EnqueueManyRolloutsRequest(BaseModel):
    rollouts: List[EnqueueRolloutRequest]


class DequeueManyRolloutsRequest(BaseModel):
    limit: int = 1
    worker_id: Optional[str] = None


class QueryRolloutsRequest(BaseModel):
    status_in: Optional[List[RolloutStatus]] = Field(FastAPIQuery(default=None))
    rollout_id_in: Optional[List[str]] = Field(FastAPIQuery(default=None))
    rollout_id_contains: Optional[str] = None
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = None
    sort_order: Literal["asc", "desc"] = "asc"
    # Filtering logic
    filter_logic: Literal["and", "or"] = "and"


class WaitForRolloutsRequest(BaseModel):
    rollout_ids: List[str]
    timeout: Optional[float] = None


class NextSequenceIdRequest(BaseModel):
    rollout_id: str
    attempt_id: str


class NextSequenceIdResponse(BaseModel):
    sequence_id: int


class UpdateRolloutRequest(BaseModel):
    input: Optional[TaskInput] = None
    mode: Optional[Literal["train", "val", "test"]] = None
    resources_id: Optional[str] = None
    status: Optional[RolloutStatus] = None
    config: Optional[RolloutConfig] = None
    metadata: Optional[Dict[str, Any]] = None


class UpdateAttemptRequest(BaseModel):
    status: Optional[AttemptStatus] = None
    worker_id: Optional[str] = None
    last_heartbeat_time: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None


class UpdateWorkerRequest(BaseModel):
    heartbeat_stats: Optional[Dict[str, Any]] = None


class QueryAttemptsRequest(BaseModel):
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = "sequence_id"
    sort_order: Literal["asc", "desc"] = "asc"


class QueryResourcesRequest(BaseModel):
    # Filtering
    resources_id: Optional[str] = None
    resources_id_contains: Optional[str] = None
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = None
    sort_order: Literal["asc", "desc"] = "asc"


class QuerySpansRequest(BaseModel):
    rollout_id: str
    attempt_id: Optional[str] = None
    # Filtering
    trace_id: Optional[str] = None
    trace_id_contains: Optional[str] = None
    span_id: Optional[str] = None
    span_id_contains: Optional[str] = None
    parent_id: Optional[str] = None
    parent_id_contains: Optional[str] = None
    name: Optional[str] = None
    name_contains: Optional[str] = None
    filter_logic: Literal["and", "or"] = "and"
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = "sequence_id"
    sort_order: Literal["asc", "desc"] = "asc"


class QueryWorkersRequest(BaseModel):
    status_in: Optional[List[WorkerStatus]] = Field(FastAPIQuery(default=None))
    worker_id_contains: Optional[str] = None
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = None
    sort_order: Literal["asc", "desc"] = "asc"
    # Filtering logic
    filter_logic: Literal["and", "or"] = "and"


class CachedStaticFiles(StaticFiles):
    def file_response(self, *args: Any, **kwargs: Any) -> Response:
        resp = super().file_response(*args, **kwargs)
        # Hashed filenames are safe to cache "forever".
        resp.headers.setdefault("Cache-Control", "public, max-age=31536000, immutable")
        return resp


class LightningStoreServer(LightningStore):
    """
    Server wrapper that exposes a LightningStore via an HTTP API.
    Delegates all operations to an underlying store implementation.

    Healthcheck and watchdog rely on the underlying store.

    `msk store` is a convenient CLI to start a store server.

    When the server is executed in a subprocess, the store will discover that it has a different PID
    and automatically delegate to an HTTP client instead of using the local store.
    This ensures a single copy of the store is shared across all processes.

    This server accepts OTLP-compatible trace exports via the `/v1/traces` endpoint.

    Args:
        store: The underlying store to delegate operations to.
        host: The hostname or IP address to bind the server to.
        port: The TCP port to listen on.
        cors_allow_origins: A list of CORS origins to allow. Use '*' to allow all origins.
        launch_mode: The launch mode to use for the server. Defaults to "thread",
            which runs the server in a separate thread.
        launcher_args: The arguments to use for the server launcher.
            `host`, `port`, and `launch_mode` cannot be set together with `launcher_args`.
        n_workers: The number of workers to run in the server. Only applicable for the `mp` launch mode.
        tracker: The metrics tracker to use for the server.
    """

    def __init__(
        self,
        store: LightningStore,
        host: str | None = None,
        port: int | None = None,
        cors_allow_origins: Sequence[str] | str | None = None,
        launch_mode: LaunchMode = "thread",
        launcher_args: PythonServerLauncherArgs | None = None,
        n_workers: int = 1,
        tracker: MetricsBackend | None = None,
    ):
        super().__init__()
        self.store = store

        if launcher_args is not None:
            if host is not None or port is not None or launch_mode != "thread":
                raise ValueError("host, port, and launch_mode cannot be set when launcher_args is provided.")
            self.launcher_args = launcher_args
        else:
            if port is None:
                server_logger.warning("No port provided, using default port 4747.")
                port = 4747
            self.launcher_args = PythonServerLauncherArgs(
                host=host,
                port=port,
                launch_mode=launch_mode,
                healthcheck_url=API_V1_AGL_PREFIX + "/health",
                n_workers=n_workers,
            )

        store_capabilities = self.store.capabilities
        if not store_capabilities.get("async_safe", False):
            raise ValueError("The store is not async-safe. Please use another store for the server.")
        if self.launcher_args.launch_mode == "mp" and not store_capabilities.get("zero_copy", False):
            raise ValueError(
                "The store does not support zero-copy. Please use another store, or use asyncio or thread mode to launch the server."
            )
        if self.launcher_args.launch_mode == "thread" and not store_capabilities.get("thread_safe", False):
            server_logger.warning(
                "The store is not thread-safe. Please be careful when using the store server and the underlying store in different threads."
            )

        self.app: FastAPI | None = FastAPI(title="LightningStore Server")
        self.server_launcher = PythonServerLauncher(
            app=self.app,
            args=self.launcher_args,
        )
        self._tracker = tracker

        self._lock: threading.Lock = threading.Lock()
        self._cors_allow_origins = self._normalize_cors_origins(cors_allow_origins)
        self._apply_cors()
        self._setup_routes()

        # Process-awareness:
        # LightningStoreServer holds a plain Python object (self.store) in one process
        # (the process that runs uvicorn/FastAPI).
        # When you multiprocessing.Process(...) and call methods on a different LightningStore instance
        # (or on a copy inherited via fork), you're mutating another process's memory, not the server's memory.
        # So we need to track the owner process (whoever creates the server),
        # and only mutate the store in that process.
        self._owner_pid = os.getpid()
        self._client: Optional[LightningStoreClient] = None

    @property
    def capabilities(self) -> LightningStoreCapabilities:
        """Return the capabilities of the store."""
        return LightningStoreCapabilities(
            async_safe=True,
            thread_safe=True,
            zero_copy=True,
            otlp_traces=True,
        )

    def otlp_traces_endpoint(self) -> str:
        """Return the OTLP/HTTP traces endpoint of the store."""
        return f"{self.endpoint}/v1/traces"

    def __getstate__(self):
        """
        Control pickling to prevent server state from being sent to subprocesses.

        When LightningStoreServer is pickled (e.g., passed to a subprocess), we only
        serialize the underlying store and connection details. The client instance
        and process-awareness state are excluded, as they should not be transferred between processes.

        The subprocess should create its own server instance if needed.
        """
        # The server launcher is included so that the host/port address is propagated to the subprocess.
        return {
            "launcher_args": self.launcher_args,
            "server_launcher": self.server_launcher,
            "_owner_pid": self._owner_pid,
        }

    def __setstate__(self, state: Dict[str, Any]):
        """
        Restore from pickle by reconstructing only the essential attributes.

        Note: This creates a new server instance without FastAPI/uvicorn initialized.
        Follow the __init__() pattern or create a new LightningStoreServer if you need
        a fully functional server in the subprocess.
        The unpickled server will also have no app and store attributes;
        this makes sure there is only one copy of the server in the whole system.
        """
        self.app = None
        self.store = None
        self.launcher_args = state["launcher_args"]
        self.server_launcher = state["server_launcher"]
        self._tracker = None
        self._owner_pid = state["_owner_pid"]
        self._cors_allow_origins = state.get("_cors_allow_origins")
        self._client = None
        self._lock = threading.Lock()
        self._prometheus_registry = None
        # Do NOT reconstruct app, _uvicorn_config, or _uvicorn_server,
        # to avoid transferring server state to the subprocess.

    @staticmethod
    def _normalize_cors_origins(
        origins: Sequence[str] | str | None,
    ) -> list[str] | None:
        if origins is None:
            return None

        if isinstance(origins, str):
            candidates = [origins]
        else:
            candidates = list(origins)

        cleaned: list[str] = []
        for origin in candidates:
            if not origin or not origin.strip():
                continue
            value = origin.strip()
            if value == "*":
                return ["*"]
            cleaned.append(value)

        return cleaned or None
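    # What _normalize_cors_origins does, by example (illustrative):
    #   None                -> None (CORS middleware is not installed)
    #   "https://a.com"     -> ["https://a.com"]
    #   ["x", " ", "*"]     -> ["*"] (a literal "*" short-circuits to allow-all)
    #   ["", "  "]          -> None (blank entries are dropped)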

    def _apply_cors(self) -> None:
        if self.app is None or not self._cors_allow_origins:
            return

        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=self._cors_allow_origins.copy(),
            allow_methods=["*"],
            allow_headers=["*"],
            allow_credentials=True,
            expose_headers=["*"],
        )

    @property
    def endpoint(self) -> str:
        """Endpoint is the address that the client will use to connect to the server."""
        return self.server_launcher.access_endpoint

    async def start(self):
        """Starts the FastAPI server in the background.

        You need to call this method in the same process as the server was created in.
        """
        server_logger.info(
            f"Serving the lightning store at {self.server_launcher.endpoint}, accessible at {self.server_launcher.access_endpoint}"
        )

        start_time = time.time()
        await self.server_launcher.start()
        end_time = time.time()
        server_logger.info(f"Lightning store server started in {end_time - start_time:.2f} seconds")

    async def run_forever(self):
        """Runs the FastAPI server indefinitely."""
        server_logger.info(
            f"Running the lightning store server at {self.server_launcher.endpoint}, accessible at {self.server_launcher.access_endpoint}"
        )
        await self.server_launcher.run_forever()

    async def stop(self):
        """Gracefully stops the running FastAPI server.

        You need to call this method in the same process as the server was created in.
        """
        server_logger.info("Stopping the lightning store server...")
        await self.server_launcher.stop()
        server_logger.info("Lightning store server stopped.")

    def _setup_routes(self):
        """Set up FastAPI routes for all store operations."""
        assert self.app is not None
        api = APIRouter(prefix=API_V1_PREFIX)

        # The outermost layer of monitoring
        if self._tracker is not None:
            self._setup_metrics(api=api, app=self.app)

        # TODO: This should only be enabled in development mode.
        @self.app.middleware("http")
        async def _app_exception_handler(  # pyright: ignore[reportUnusedFunction]
            request: Request, call_next: Callable[[Request], Awaitable[Response]]
        ) -> Response:
            """
            Convert unhandled application exceptions into 500 responses.

            Only covers /v1/msk requests.

            - The client needs a reliable signal to distinguish "app bug / bad request"
              from transport/session failures.
            - A structured 500 carries the error details; network issues will surface
              as aiohttp exceptions or bare 5xx responses and will be retried by the
              client shield.
            """
            try:
                return await call_next(request)
            except Exception as exc:
                # Decide whether to convert this into a structured 500 JSONResponse.
                if request.url.path.startswith(API_V1_AGL_PREFIX):
                    server_logger.exception("Unhandled application error", exc_info=exc)
                    payload = {
                        "detail": "Internal server error",
                        "error_type": type(exc).__name__,
                        "traceback": traceback.format_exc(),
                    }
                    # 500 so clients can decide to retry
                    return JSONResponse(status_code=500, content=payload)
                # Otherwise re-raise and let FastAPI/Starlette handle it (500 or other handlers).
                raise

        @self.app.middleware("http")
        async def _log_time(  # pyright: ignore[reportUnusedFunction]
            request: Request, call_next: Callable[[Request], Awaitable[Response]]
        ):
            # If this is not an API request, just pass through.
            if not request.url.path.startswith(API_V1_AGL_PREFIX) and not request.url.path.startswith(
                API_V1_PREFIX + "/traces"
            ):
                return await call_next(request)

            start = time.perf_counter()
            response = await call_next(request)
            duration = (time.perf_counter() - start) * 1000
            client = request.client
            if client is None:
                client_address = "unknown"
            else:
                client_address = f"{client.host}:{client.port}"
            server_logger.debug(
                f"{client_address} - "
                f'"{request.method} {request.url.path} HTTP/{request.scope["http_version"]}" '
                f"{response.status_code} in {duration:.2f} ms"
            )
            return response

        def _validate_paginated_request(
            request: Union[
                QueryRolloutsRequest,
                QueryAttemptsRequest,
                QueryResourcesRequest,
                QueryWorkersRequest,
                QuerySpansRequest,
            ],
            target_type: Type[T_model],
        ) -> None:
            """Raise an error early if the request is not a valid paginated request."""
            if request.sort_by is not None and request.sort_by not in target_type.model_fields:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid sort_by: {request.sort_by}, allowed fields are: {', '.join(target_type.model_fields.keys())}",
                )
            if request.sort_order not in ["asc", "desc"]:
                raise HTTPException(
                    status_code=400, detail=f"Invalid sort_order: {request.sort_order}, allowed values are: asc, desc"
                )
            if request.limit == 0 or (request.limit < 0 and request.limit != -1):
                raise HTTPException(status_code=400, detail="Limit must be greater than 0, or -1 for no limit")
            if not request.offset >= 0:
                raise HTTPException(status_code=400, detail="Offset must be greater than or equal to 0")
            if hasattr(request, "filter_logic") and request.filter_logic not in ["and", "or"]:  # type: ignore
                raise HTTPException(
                    status_code=400, detail=f"Invalid filter_logic: {request.filter_logic}, allowed values are: and, or"  # type: ignore
                )
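        # Accepted/rejected by the validation above (illustrative):
        #   limit=-1 -> no limit; limit=25 -> at most 25 items; limit=0 or limit=-2 -> 400
        #   offset must be >= 0; sort_by must be a model field of the target type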

        def _build_paginated_response(items: Sequence[Any], *, limit: int, offset: int) -> PaginatedResult[Any]:
            """FastAPI routes expect PaginatedResult payloads; wrap plain lists accordingly."""
            if isinstance(items, PaginatedResult):
                return items

            # Assuming it's a list.
            server_logger.warning(
                "PaginatedResult expected; got a plain list. Converting to PaginatedResult. "
                "Total items count will be inaccurate: %d",
                len(items),
            )
            return PaginatedResult(items=items, limit=limit, offset=offset, total=len(items))

        @api.get(API_AGL_PREFIX + "/health")
        async def health():  # pyright: ignore[reportUnusedFunction]
            return {"status": "ok"}

        @api.post(API_AGL_PREFIX + "/queues/rollouts/enqueue", status_code=201, response_model=List[Rollout])
        async def enqueue_rollouts(  # pyright: ignore[reportUnusedFunction]
            request: EnqueueManyRolloutsRequest,
        ) -> List[Rollout]:
            enqueue_requests = request.rollouts
            if not enqueue_requests:
                return []
            if len(enqueue_requests) == 1:
                single = enqueue_requests[0]
                rollout = await self.enqueue_rollout(
                    input=single.input,
                    mode=single.mode,
                    resources_id=single.resources_id,
                    config=single.config,
                    metadata=single.metadata,
                )
                return [rollout]
            rollouts = await self.enqueue_many_rollouts(enqueue_requests)
            return list(rollouts)

        @api.post(API_AGL_PREFIX + "/queues/rollouts/dequeue", response_model=List[AttemptedRollout])
        async def dequeue_rollouts(  # pyright: ignore[reportUnusedFunction]
            request: DequeueManyRolloutsRequest | None = Body(None),
        ) -> List[AttemptedRollout]:
            payload = request or DequeueManyRolloutsRequest()
            if payload.limit <= 0:
                return []
            if payload.limit == 1:
                single = await self.dequeue_rollout(worker_id=payload.worker_id)
                return [single] if single else []
            rollouts = await self.dequeue_many_rollouts(limit=payload.limit, worker_id=payload.worker_id)
            return list(rollouts)

        @api.post(API_AGL_PREFIX + "/rollouts", status_code=201, response_model=AttemptedRollout)
        async def start_rollout(request: RolloutRequest):  # pyright: ignore[reportUnusedFunction]
            return await self.start_rollout(
                input=request.input,
                mode=request.mode,
                resources_id=request.resources_id,
                config=request.config,
                metadata=request.metadata,
                worker_id=request.worker_id,
            )

        @api.get(API_AGL_PREFIX + "/rollouts", response_model=PaginatedResult[Union[AttemptedRollout, Rollout]])
        async def query_rollouts(params: QueryRolloutsRequest = Depends()):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(params, Rollout)
            # Get all rollouts from the underlying store
            results = await self.query_rollouts(
                status_in=params.status_in,
                rollout_id_in=params.rollout_id_in,
                rollout_id_contains=params.rollout_id_contains,
                filter_logic=params.filter_logic,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(results, limit=params.limit, offset=params.offset)

        @api.post(API_AGL_PREFIX + "/rollouts/search", response_model=PaginatedResult[Union[AttemptedRollout, Rollout]])
        async def search_rollouts(request: QueryRolloutsRequest):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(request, Rollout)
            status_in = request.status_in if "status_in" in request.model_fields_set else None
            rollout_id_in = request.rollout_id_in if "rollout_id_in" in request.model_fields_set else None
            # Get all rollouts from the underlying store
            results = await self.query_rollouts(
                status_in=status_in,
                rollout_id_in=rollout_id_in,
                rollout_id_contains=request.rollout_id_contains,
                filter_logic=request.filter_logic,
                sort_by=request.sort_by,
                sort_order=request.sort_order,
                limit=request.limit,
                offset=request.offset,
            )
            return _build_paginated_response(results, limit=request.limit, offset=request.offset)

        @api.get(API_AGL_PREFIX + "/rollouts/{rollout_id}", response_model=Union[AttemptedRollout, Rollout])
        async def get_rollout_by_id(rollout_id: str):  # pyright: ignore[reportUnusedFunction]
            return await self.get_rollout_by_id(rollout_id)

        def _get_mandatory_field_or_unset(request: BaseModel | None, field: str) -> Any:
            # Some fields are mandatory in the underlying store but optional in the FastAPI model;
            # this function makes sure each is set to a non-null value or UNSET.
            if request is None:
                return UNSET
            if field in request.model_fields_set:
                value = getattr(request, field)
                if value is None:
                    raise HTTPException(status_code=400, detail=f"{field} is invalid; it cannot be a null value.")
                return value
            else:
                return UNSET
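        # Behavior at a glance (illustrative):
        #   field not supplied by the caller      -> UNSET (leave the stored value untouched)
        #   field explicitly supplied as null     -> 400 (mandatory fields cannot be cleared)
        #   field supplied with a concrete value  -> that value is passed through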

        @api.post(API_AGL_PREFIX + "/rollouts/{rollout_id}", response_model=Rollout)
        async def update_rollout(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, request: UpdateRolloutRequest = Body(...)
        ):
            return await self.update_rollout(
                rollout_id=rollout_id,
                input=request.input if "input" in request.model_fields_set else UNSET,
                mode=request.mode if "mode" in request.model_fields_set else UNSET,
                resources_id=request.resources_id if "resources_id" in request.model_fields_set else UNSET,
                status=_get_mandatory_field_or_unset(request, "status"),
                config=_get_mandatory_field_or_unset(request, "config"),
                metadata=request.metadata if "metadata" in request.model_fields_set else UNSET,
            )

        @api.post(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts", status_code=201, response_model=AttemptedRollout)
        async def start_attempt(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, request: StartAttemptRequest | None = Body(None)
        ):
            worker_id = request.worker_id if request else None
            return await self.start_attempt(rollout_id, worker_id=worker_id)

        @api.post(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts/search", response_model=PaginatedResult[Attempt])
        async def search_attempts(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, request: QueryAttemptsRequest
        ):
            _validate_paginated_request(request, Attempt)
            attempts = await self.query_attempts(
                rollout_id,
                sort_by=request.sort_by,
                sort_order=request.sort_order,
                limit=request.limit,
                offset=request.offset,
            )
            return _build_paginated_response(attempts, limit=request.limit, offset=request.offset)

        @api.post(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts/{attempt_id}", response_model=Attempt)
        async def update_attempt(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, attempt_id: str, request: UpdateAttemptRequest = Body(...)
        ):
            return await self.update_attempt(
                rollout_id=rollout_id,
                attempt_id=attempt_id,
                status=_get_mandatory_field_or_unset(request, "status"),
                worker_id=_get_mandatory_field_or_unset(request, "worker_id"),
                last_heartbeat_time=_get_mandatory_field_or_unset(request, "last_heartbeat_time"),
                metadata=_get_mandatory_field_or_unset(request, "metadata"),
            )

        @api.get(API_AGL_PREFIX + "/workers", response_model=PaginatedResult[Worker])
        async def query_workers(params: QueryWorkersRequest = Depends()):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(params, Worker)
            workers = await self.query_workers(
                status_in=params.status_in,
                worker_id_contains=params.worker_id_contains,
                filter_logic=params.filter_logic,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(workers, limit=params.limit, offset=params.offset)

        @api.post(API_AGL_PREFIX + "/workers/search", response_model=PaginatedResult[Worker])
        async def search_workers(request: QueryWorkersRequest):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(request, Worker)
            status_in = request.status_in if "status_in" in request.model_fields_set else None
            workers = await self.query_workers(
                status_in=status_in,
                worker_id_contains=request.worker_id_contains,
                filter_logic=request.filter_logic,
                sort_by=request.sort_by,
                sort_order=request.sort_order,
                limit=request.limit,
                offset=request.offset,
            )
            return _build_paginated_response(workers, limit=request.limit, offset=request.offset)

        @api.get(API_AGL_PREFIX + "/workers/{worker_id}", response_model=Optional[Worker])
        async def get_worker(worker_id: str):  # pyright: ignore[reportUnusedFunction]
            return await self.get_worker_by_id(worker_id)

        @api.post(API_AGL_PREFIX + "/workers/{worker_id}", response_model=Worker)
        async def update_worker(  # pyright: ignore[reportUnusedFunction]
            worker_id: str, request: UpdateWorkerRequest | None = Body(None)
        ):
            return await self.update_worker(
                worker_id=worker_id,
                heartbeat_stats=_get_mandatory_field_or_unset(request, "heartbeat_stats"),
            )

        @api.get(API_AGL_PREFIX + "/statistics", response_model=Dict[str, Any])
        async def get_statistics():  # pyright: ignore[reportUnusedFunction]
            return await self.statistics()

        @api.get(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts", response_model=PaginatedResult[Attempt])
        async def query_attempts(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, params: QueryAttemptsRequest = Depends()
        ):
            _validate_paginated_request(params, Attempt)
            attempts = await self.query_attempts(
                rollout_id,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(attempts, limit=params.limit, offset=params.offset)

        @api.get(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts/latest", response_model=Optional[Attempt])
        async def get_latest_attempt(rollout_id: str):  # pyright: ignore[reportUnusedFunction]
            return await self.get_latest_attempt(rollout_id)

        @api.get(API_AGL_PREFIX + "/resources", response_model=PaginatedResult[ResourcesUpdate])
        async def query_resources(params: QueryResourcesRequest = Depends()):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(params, ResourcesUpdate)
            resources = await self.query_resources(
                resources_id=params.resources_id,
                resources_id_contains=params.resources_id_contains,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(resources, limit=params.limit, offset=params.offset)

        @api.post(API_AGL_PREFIX + "/resources", status_code=201, response_model=ResourcesUpdate)
        async def add_resources(resources: NamedResources):  # pyright: ignore[reportUnusedFunction]
            return await self.add_resources(resources)

        @api.get(API_AGL_PREFIX + "/resources/latest", response_model=Optional[ResourcesUpdate])
        async def get_latest_resources():  # pyright: ignore[reportUnusedFunction]
            return await self.get_latest_resources()

        @api.post(API_AGL_PREFIX + "/resources/{resources_id}", response_model=ResourcesUpdate)
        async def update_resources(  # pyright: ignore[reportUnusedFunction]
            resources_id: str, resources: NamedResources
        ):
            return await self.update_resources(resources_id, resources)

        @api.get(API_AGL_PREFIX + "/resources/{resources_id}", response_model=Optional[ResourcesUpdate])
        async def get_resources_by_id(resources_id: str):  # pyright: ignore[reportUnusedFunction]
            return await self.get_resources_by_id(resources_id)

        @api.post(API_AGL_PREFIX + "/spans", status_code=201, response_model=Optional[Span])
        async def add_span(span: Span):  # pyright: ignore[reportUnusedFunction]
            return await self.add_span(span)

        @api.get(API_AGL_PREFIX + "/spans", response_model=PaginatedResult[Span])
        async def query_spans(params: QuerySpansRequest = Depends()):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(params, Span)
            spans = await self.query_spans(
                params.rollout_id,
                params.attempt_id,
                trace_id=params.trace_id,
                trace_id_contains=params.trace_id_contains,
                span_id=params.span_id,
                span_id_contains=params.span_id_contains,
                parent_id=params.parent_id,
                parent_id_contains=params.parent_id_contains,
                name=params.name,
                name_contains=params.name_contains,
                filter_logic=params.filter_logic,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(spans, limit=params.limit, offset=params.offset)

        @api.post(API_AGL_PREFIX + "/spans/search", response_model=PaginatedResult[Span])
        async def search_spans(request: QuerySpansRequest):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(request, Span)
            spans = await self.query_spans(
                request.rollout_id,
                request.attempt_id,
                trace_id=request.trace_id,
                trace_id_contains=request.trace_id_contains,
                span_id=request.span_id,
                span_id_contains=request.span_id_contains,
                parent_id=request.parent_id,
                parent_id_contains=request.parent_id_contains,
                name=request.name,
                name_contains=request.name_contains,
                filter_logic=request.filter_logic,
                sort_by=request.sort_by,
                sort_order=request.sort_order,
                limit=request.limit,
                offset=request.offset,
            )
            return _build_paginated_response(spans, limit=request.limit, offset=request.offset)

        @api.post(API_AGL_PREFIX + "/spans/next", response_model=NextSequenceIdResponse)
        async def get_next_span_sequence_id(request: NextSequenceIdRequest):  # pyright: ignore[reportUnusedFunction]
            sequence_id = await self.get_next_span_sequence_id(request.rollout_id, request.attempt_id)
            return NextSequenceIdResponse(sequence_id=sequence_id)

        @api.post(API_AGL_PREFIX + "/waits/rollouts", response_model=List[Rollout])
        async def wait_for_rollouts(request: WaitForRolloutsRequest):  # pyright: ignore[reportUnusedFunction]
            return await self.wait_for_rollouts(rollout_ids=request.rollout_ids, timeout=request.timeout)

        # Set up OTLP endpoints
        self._setup_otlp(api)

        # Mount the API router at /v1/...
        self.app.include_router(api)

        # Finally, mount the dashboard assets
        self._setup_dashboard()

    def _setup_metrics(self, api: APIRouter, app: FastAPI):
        """Set up Prometheus metrics endpoints."""
        if self._tracker is None:
            return

        self._tracker.register_counter(
            "msk.http.total",
            ["path", "method", "status"],
            group_level=2,
        )
        self._tracker.register_histogram(
            "msk.http.latency",
            ["path", "method", "status"],
            buckets=LATENCY_BUCKETS,
            group_level=2,
        )

        def get_template_path(path: str) -> str:
            # Handle "latest" keywords BEFORE generic IDs
            if path.endswith("/attempts/latest") and "/rollouts/" in path:
                return re.sub(r"rollouts/[^/]+/attempts/latest$", "rollouts/{rollout_id}/attempts/latest", path)
            if path.endswith("/attempts/search") and "/rollouts/" in path:
                return re.sub(r"rollouts/[^/]+/attempts/search$", "rollouts/{rollout_id}/attempts/search", path)
            if path.endswith("/resources/latest"):
                return path
            if path.endswith("/search"):
                return path
            if "enqueue" in path or "dequeue" in path:
                return path

            # Handle generic IDs
            # (Order matters: longest paths first, or use lookaheads)
            path = re.sub(r"/attempts/[^/]+$", "/attempts/{attempt_id}", path)
            path = re.sub(r"/rollouts/[^/]+", "/rollouts/{rollout_id}", path)  # Handles root and middle
            path = re.sub(r"/resources/[^/]+$", "/resources/{resources_id}", path)
            path = re.sub(r"/workers/[^/]+$", "/workers/{worker_id}", path)

            return path
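        # Illustrative mappings produced by get_template_path (with API_V1_AGL_PREFIX == "/v1/msk"):
        #   /v1/msk/rollouts/ro-1                 -> /v1/msk/rollouts/{rollout_id}
        #   /v1/msk/rollouts/ro-1/attempts/at-2   -> /v1/msk/rollouts/{rollout_id}/attempts/{attempt_id}
        #   /v1/msk/rollouts/ro-1/attempts/latest -> /v1/msk/rollouts/{rollout_id}/attempts/latest
        #   /v1/msk/queues/rollouts/enqueue       -> unchanged (kept verbatim)
        # This keeps the {path} label low-cardinality for the metrics registered above.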

        @app.middleware("http")
        async def tracking_middleware(  # pyright: ignore[reportUnusedFunction]
            request: Request, call_next: Callable[[Request], Awaitable[Response]]
        ) -> Response:
            if self._tracker is None:
                return await call_next(request)

            start = time.perf_counter()
            status = 520  # Default to 520 if things crash hard

            try:
                response = await call_next(request)
                status = response.status_code
                return response
            except asyncio.CancelledError:
                # Client disconnected (timeout)
                status = 499  # Standard Nginx code for "Client Closed Request"
                server_logger.debug(f"Client disconnected (Timeout): {request.url.path}", exc_info=True)
                raise  # Re-raise to let Uvicorn handle the cleanup
            except Exception as exc:
                status = resolve_error_type(exc)
                server_logger.debug(f"Server error: {request.url.path}", exc_info=True)
                raise
            finally:
                # This block executes NO MATTER WHAT happens above
                elapsed = time.perf_counter() - start

                # Strip the ID-specific URL parts
                path = get_template_path(request.url.path)
                method = request.method

                await self._tracker.inc_counter(
                    "msk.http.total",
                    labels={"method": method, "path": path, "status": str(status)},
                )
                await self._tracker.observe_histogram(
                    "msk.http.latency",
                    value=elapsed,
                    labels={"method": method, "path": path, "status": str(status)},
                )

        if self._tracker.has_prometheus():
            from prometheus_client import make_asgi_app  # pyright: ignore[reportUnknownVariableType]

            metrics_app = make_asgi_app(  # pyright: ignore[reportUnknownVariableType]
                registry=get_prometheus_registry()
            )

            # This app needs to be accessed via /v1/prometheus/ (note the trailing slash).
            app.mount(api.prefix + "/prometheus", metrics_app)  # pyright: ignore[reportUnknownArgumentType]

    def _setup_otlp(self, api: APIRouter):
        """Set up OTLP endpoints."""

        async def _trace_handler(request: PbExportTraceServiceRequest) -> None:
            spans = await spans_from_proto(request, self.get_many_span_sequence_ids)
            server_logger.debug(f"Received {len(spans)} OTLP spans: {', '.join([span.name for span in spans])}")
            await self.add_many_spans(spans)

        # Reserved methods for OTel traces
        # https://opentelemetry.io/docs/specs/otlp/#otlphttp-request
        # This is currently the recommended path for OTel compatibility and bulk-insertion support.
        @api.post("/traces")
        async def otlp_traces(request: Request):  # pyright: ignore[reportUnusedFunction]
            return await handle_otlp_export(
                request, PbExportTraceServiceRequest, PbExportTraceServiceResponse, _trace_handler, "traces"
            )

        # Other API endpoints are not supported yet
        @api.post("/metrics")
        async def otlp_metrics():  # pyright: ignore[reportUnusedFunction]
            return Response(status_code=501)

        @api.post("/logs")
        async def otlp_logs():  # pyright: ignore[reportUnusedFunction]
            return Response(status_code=501)

        @api.post("/development/profiles")
        async def otlp_development_profiles():  # pyright: ignore[reportUnusedFunction]
            return Response(status_code=501)

    def _setup_dashboard(self):
        """Set up the dashboard static files and SPA."""
        assert self.app is not None

        dashboard_dir = (Path(__file__).parent.parent / "dashboard").resolve()
        if not dashboard_dir.exists():
            server_logger.error("Dashboard directory not found at %s. Please build the dashboard first.", dashboard_dir)
            return

        dashboard_assets_dir = dashboard_dir / "assets"
        if not dashboard_assets_dir.exists():
            server_logger.error(
                "Dashboard assets directory not found at %s. Please build the dashboard first.", dashboard_assets_dir
            )
            return

        index_file = dashboard_dir / "index.html"
        if not index_file.exists():
            server_logger.error("Dashboard index file not found at %s. Please build the dashboard first.", index_file)
            return

        # Mount the static files in dashboard/assets
        self.app.mount("/assets", CachedStaticFiles(directory=dashboard_assets_dir), name="assets")

        # SPA fallback (client-side routing):
        # anything that's not /v1/* or a real file in /assets will serve index.html.
        @self.app.get("/", include_in_schema=False)
        def root():  # pyright: ignore[reportUnusedFunction]
            return FileResponse(index_file)

        @self.app.get("/{full_path:path}", include_in_schema=False)
        def spa_fallback(full_path: str):  # pyright: ignore[reportUnusedFunction]
            if full_path.startswith("v1/"):
                raise HTTPException(status_code=404, detail="Not Found")
            # Let the frontend router handle it
            return FileResponse(index_file)

        server_logger.info("Mantisdk dashboard will be available at %s", self.endpoint)

    # Delegate methods
    async def _call_store_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
        """First decide which store to delegate to in *this* process, and then call the method on it.

        - In the owner process: delegate to the in-process store.
        - In a different process: delegate to an HTTP client talking to the server.
        """
        # If the store is zero-copy, we can just call the method directly.
        if self.store is not None and self.store.capabilities.get("zero_copy", False):
            return await getattr(self.store, method_name)(*args, **kwargs)

        if os.getpid() == self._owner_pid:
            if method_name == "wait_for_rollouts":
                # wait_for_rollouts can block for a long time; avoid holding the lock
                # so other requests can make progress while we wait.
                return await getattr(self.store, method_name)(*args, **kwargs)

            # If the store is already thread-safe, we can just call the method directly.
            # Acquiring the threading lock directly would block the event loop if it's
            # already held by another thread (for example, the HTTP server thread).
            # A potential fix would be needed to make locking work here, for example:
            # ```
            # acquired = self._lock.acquire(blocking=False)
            # if not acquired:
            #     await asyncio.to_thread(self._lock.acquire)
            # try:
            #     return await getattr(self.store, method_name)(*args, **kwargs)
            # finally:
            #     self._lock.release()
            # ```
            # Or we can just bypass the lock for thread-safe stores.
            if self.store is not None and self.store.capabilities.get("thread_safe", False):
                return await getattr(self.store, method_name)(*args, **kwargs)
            else:
                with self._lock:
                    return await getattr(self.store, method_name)(*args, **kwargs)
        if self._client is None:
            self._client = LightningStoreClient(self.endpoint)
        return await getattr(self._client, method_name)(*args, **kwargs)
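    # Dispatch summary for _call_store_method (mirrors the branches above):
    #   zero-copy store                  -> call self.store directly
    #   owner process + thread-safe      -> call self.store directly (no lock)
    #   owner process + not thread-safe  -> call self.store under self._lock
    #   any other process                -> lazily create LightningStoreClient and go over HTTP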
1048
+
+    async def statistics(self) -> LightningStoreStatistics:
+        return await self._call_store_method("statistics")
+
+    async def start_rollout(
+        self,
+        input: TaskInput,
+        mode: Literal["train", "val", "test"] | None = None,
+        resources_id: str | None = None,
+        config: RolloutConfig | None = None,
+        metadata: Dict[str, Any] | None = None,
+        worker_id: Optional[str] = None,
+    ) -> AttemptedRollout:
+        return await self._call_store_method(
+            "start_rollout",
+            input,
+            mode,
+            resources_id,
+            config,
+            metadata,
+            worker_id,
+        )
+
+    async def enqueue_rollout(
+        self,
+        input: TaskInput,
+        mode: Literal["train", "val", "test"] | None = None,
+        resources_id: str | None = None,
+        config: RolloutConfig | None = None,
+        metadata: Dict[str, Any] | None = None,
+    ) -> Rollout:
+        return await self._call_store_method(
+            "enqueue_rollout",
+            input,
+            mode,
+            resources_id,
+            config,
+            metadata,
+        )
+
+    async def enqueue_many_rollouts(self, rollouts: Sequence[EnqueueRolloutRequest]) -> Sequence[Rollout]:
+        return await self._call_store_method("enqueue_many_rollouts", rollouts)
+
+    async def dequeue_rollout(self, worker_id: Optional[str] = None) -> Optional[AttemptedRollout]:
+        return await self._call_store_method("dequeue_rollout", worker_id)
+
+    async def dequeue_many_rollouts(
+        self,
+        *,
+        limit: int = 1,
+        worker_id: Optional[str] = None,
+    ) -> Sequence[AttemptedRollout]:
+        return await self._call_store_method("dequeue_many_rollouts", limit=limit, worker_id=worker_id)
+
+    async def start_attempt(self, rollout_id: str, worker_id: Optional[str] = None) -> AttemptedRollout:
+        return await self._call_store_method("start_attempt", rollout_id, worker_id)
+
+    async def query_rollouts(
+        self,
+        *,
+        status_in: Optional[Sequence[RolloutStatus]] = None,
+        rollout_id_in: Optional[Sequence[str]] = None,
+        rollout_id_contains: Optional[str] = None,
+        filter_logic: Literal["and", "or"] = "and",
+        sort_by: Optional[str] = None,
+        sort_order: Literal["asc", "desc"] = "asc",
+        limit: int = -1,
+        offset: int = 0,
+        status: Optional[Sequence[RolloutStatus]] = None,
+        rollout_ids: Optional[Sequence[str]] = None,
+    ) -> PaginatedResult[Union[AttemptedRollout, Rollout]]:
+        return await self._call_store_method(
+            "query_rollouts",
+            status_in=status_in,
+            rollout_id_in=rollout_id_in,
+            rollout_id_contains=rollout_id_contains,
+            filter_logic=filter_logic,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            limit=limit,
+            offset=offset,
+            status=status,
+            rollout_ids=rollout_ids,
+        )
+
+    async def query_attempts(
+        self,
+        rollout_id: str,
+        *,
+        sort_by: Optional[str] = "sequence_id",
+        sort_order: Literal["asc", "desc"] = "asc",
+        limit: int = -1,
+        offset: int = 0,
+    ) -> PaginatedResult[Attempt]:
+        return await self._call_store_method(
+            "query_attempts",
+            rollout_id,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            limit=limit,
+            offset=offset,
+        )
+
+    async def get_latest_attempt(self, rollout_id: str) -> Optional[Attempt]:
+        return await self._call_store_method("get_latest_attempt", rollout_id)
+
+    async def query_resources(
+        self,
+        *,
+        resources_id: Optional[str] = None,
+        resources_id_contains: Optional[str] = None,
+        sort_by: Optional[str] = None,
+        sort_order: Literal["asc", "desc"] = "asc",
+        limit: int = -1,
+        offset: int = 0,
+    ) -> PaginatedResult[ResourcesUpdate]:
+        return await self._call_store_method(
+            "query_resources",
+            resources_id=resources_id,
+            resources_id_contains=resources_id_contains,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            limit=limit,
+            offset=offset,
+        )
+
+    async def get_rollout_by_id(self, rollout_id: str) -> Optional[Rollout]:
+        return await self._call_store_method("get_rollout_by_id", rollout_id)
+
+    async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
+        return await self._call_store_method("add_resources", resources)
+
+    async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate:
+        return await self._call_store_method("update_resources", resources_id, resources)
+
+    async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]:
+        return await self._call_store_method("get_resources_by_id", resources_id)
+
+    async def get_latest_resources(self) -> Optional[ResourcesUpdate]:
+        return await self._call_store_method("get_latest_resources")
+
+    async def add_span(self, span: Span) -> Optional[Span]:
+        return await self._call_store_method("add_span", span)
+
+    async def add_many_spans(self, spans: Sequence[Span]) -> Sequence[Span]:
+        return await self._call_store_method("add_many_spans", spans)
+
+    async def get_next_span_sequence_id(self, rollout_id: str, attempt_id: str) -> int:
+        return await self._call_store_method("get_next_span_sequence_id", rollout_id, attempt_id)
+
+    async def get_many_span_sequence_ids(self, rollout_attempt_ids: Sequence[Tuple[str, str]]) -> Sequence[int]:
+        return await self._call_store_method("get_many_span_sequence_ids", rollout_attempt_ids)
+
+    async def add_otel_span(
+        self,
+        rollout_id: str,
+        attempt_id: str,
+        readable_span: ReadableSpan,
+        sequence_id: int | None = None,
+    ) -> Optional[Span]:
+        return await self._call_store_method(
+            "add_otel_span",
+            rollout_id,
+            attempt_id,
+            readable_span,
+            sequence_id,
+        )
+
+    async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
+        return await self._call_store_method("wait_for_rollouts", rollout_ids=rollout_ids, timeout=timeout)
+
+    async def query_spans(
+        self,
+        rollout_id: str,
+        attempt_id: str | Literal["latest"] | None = None,
+        *,
+        trace_id: Optional[str] = None,
+        trace_id_contains: Optional[str] = None,
+        span_id: Optional[str] = None,
+        span_id_contains: Optional[str] = None,
+        parent_id: Optional[str] = None,
+        parent_id_contains: Optional[str] = None,
+        name: Optional[str] = None,
+        name_contains: Optional[str] = None,
+        filter_logic: Literal["and", "or"] = "and",
+        limit: int = -1,
+        offset: int = 0,
+        sort_by: Optional[str] = "sequence_id",
+        sort_order: Literal["asc", "desc"] = "asc",
+    ) -> PaginatedResult[Span]:
+        return await self._call_store_method(
+            "query_spans",
+            rollout_id,
+            attempt_id,
+            trace_id=trace_id,
+            trace_id_contains=trace_id_contains,
+            span_id=span_id,
+            span_id_contains=span_id_contains,
+            parent_id=parent_id,
+            parent_id_contains=parent_id_contains,
+            name=name,
+            name_contains=name_contains,
+            filter_logic=filter_logic,
+            limit=limit,
+            offset=offset,
+            sort_by=sort_by,
+            sort_order=sort_order,
+        )
+
+    async def update_rollout(
+        self,
+        rollout_id: str,
+        input: TaskInput | Unset = UNSET,
+        mode: Optional[Literal["train", "val", "test"]] | Unset = UNSET,
+        resources_id: Optional[str] | Unset = UNSET,
+        status: RolloutStatus | Unset = UNSET,
+        config: RolloutConfig | Unset = UNSET,
+        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
+    ) -> Rollout:
+        return await self._call_store_method(
+            "update_rollout",
+            rollout_id,
+            input,
+            mode,
+            resources_id,
+            status,
+            config,
+            metadata,
+        )
+
+    async def update_attempt(
+        self,
+        rollout_id: str,
+        attempt_id: str | Literal["latest"],
+        status: AttemptStatus | Unset = UNSET,
+        worker_id: str | Unset = UNSET,
+        last_heartbeat_time: float | Unset = UNSET,
+        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
+    ) -> Attempt:
+        return await self._call_store_method(
+            "update_attempt",
+            rollout_id,
+            attempt_id,
+            status,
+            worker_id,
+            last_heartbeat_time,
+            metadata,
+        )
+
+    async def query_workers(
+        self,
+        *,
+        status_in: Optional[Sequence[WorkerStatus]] = None,
+        worker_id_contains: Optional[str] = None,
+        filter_logic: Literal["and", "or"] = "and",
+        sort_by: Optional[str] = None,
+        sort_order: Literal["asc", "desc"] = "asc",
+        limit: int = -1,
+        offset: int = 0,
+    ) -> PaginatedResult[Worker]:
+        return await self._call_store_method(
+            "query_workers",
+            status_in=status_in,
+            worker_id_contains=worker_id_contains,
+            filter_logic=filter_logic,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            limit=limit,
+            offset=offset,
+        )
+
+    async def get_worker_by_id(self, worker_id: str) -> Optional[Worker]:
+        return await self._call_store_method("get_worker_by_id", worker_id)
+
+    async def update_worker(
+        self,
+        worker_id: str,
+        heartbeat_stats: Dict[str, Any] | Unset = UNSET,
+    ) -> Worker:
+        return await self._call_store_method(
+            "update_worker",
+            worker_id,
+            heartbeat_stats,
+        )
+
+
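Since every method above simply forwards to `_call_store_method`, callers can treat this façade like any other `LightningStore`. A minimal worker sketch under that assumption; the `rollout_id` attribute access and the metadata payload are illustrative, not a documented contract:

```python
async def drain_queue(store, worker_id: str) -> int:
    """Process queued rollouts until the queue is empty; returns the count."""
    handled = 0
    while True:
        rollout = await store.dequeue_rollout(worker_id=worker_id)
        if rollout is None:
            break  # queue is empty
        # Application-specific work would happen here; we only tag metadata.
        await store.update_rollout(rollout.rollout_id, metadata={"handled_by": worker_id})
        handled += 1
    return handled
```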
+class LightningStoreClient(LightningStore):
+    """HTTP client that talks to a remote LightningStoreServer.
+
+    Args:
+        server_address: The address of the LightningStoreServer to connect to.
+        retry_delays:
+            Backoff schedule (seconds) used when the initial request fails for a
+            non-application reason. Each entry is one retry attempt.
+            Set to an empty sequence to disable retries.
+        health_retry_delays:
+            Delays between /health probes while waiting for the server to come back.
+            Set to an empty sequence to disable health checks.
+        request_timeout: Timeout (seconds) for each request.
+        connection_timeout: Timeout (seconds) for establishing a connection.
+    """
+
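A usage sketch for the constructor; the address and tuning values are illustrative, not defaults beyond those shown in the signature:

```python
client = LightningStoreClient(
    "http://127.0.0.1:8000",           # illustrative server address
    retry_delays=(0.5, 1.0, 2.0),      # three retries with increasing backoff
    health_retry_delays=(),            # skip /health probing between retries
    request_timeout=10.0,
    connection_timeout=2.0,
)
```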
+    def __init__(
+        self,
+        server_address: str,
+        *,
+        retry_delays: Sequence[float] = (1.0, 2.0, 5.0),
+        health_retry_delays: Sequence[float] = (0.1, 0.2, 0.5),
+        request_timeout: float = 30.0,
+        connection_timeout: float = 5.0,
+    ):
+        self.server_address_root = server_address.rstrip("/")
+        self.server_address = self.server_address_root + API_V1_AGL_PREFIX
+        self._sessions: Dict[int, aiohttp.ClientSession] = {}  # id(loop) -> ClientSession
+        self._lock = threading.Lock()
+
+        # Retry configuration
+        self._retry_delays = tuple(float(d) for d in retry_delays)
+        self._health_retry_delays = tuple(float(d) for d in health_retry_delays)
+
+        # Timeouts
+        self._request_timeout = request_timeout
+        self._connection_timeout = connection_timeout
+
+        # Track whether any dequeue has ever succeeded, so that failures before
+        # the server comes up can be logged quietly.
+        self._dequeue_was_successful: bool = False
+        self._dequeue_first_unsuccessful: bool = True
+
+    @property
+    def capabilities(self) -> LightningStoreCapabilities:
+        """Return the capabilities of the store."""
+        return LightningStoreCapabilities(
+            thread_safe=True,
+            async_safe=True,
+            zero_copy=True,
+            otlp_traces=True,
+        )
+
+    def otlp_traces_endpoint(self) -> str:
+        """Return the OTLP/HTTP traces endpoint of the store."""
+        return f"{self.server_address_root}/v1/traces"
+
+    async def statistics(self) -> LightningStoreStatistics:
+        payload = await self._request_json("get", "/statistics")
+        return cast(LightningStoreStatistics, payload)
+
+    def __getstate__(self):
+        """
+        When LightningStoreClient is pickled (e.g., passed to a subprocess), we only
+        serialize the server address and retry configurations. The ClientSessions
+        are excluded because they must not be transferred between processes.
+        """
+        return {
+            "server_address_root": self.server_address_root,
+            "server_address": self.server_address,
+            "_retry_delays": self._retry_delays,
+            "_health_retry_delays": self._health_retry_delays,
+            "_request_timeout": self._request_timeout,
+            "_connection_timeout": self._connection_timeout,
+        }
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """
+        Restore from pickle by reconstructing only the essential attributes,
+        replicating `__init__` to create a fresh client instance in the subprocess.
+        """
+        self.server_address = state["server_address"]
+        self.server_address_root = state["server_address_root"]
+        self._sessions = {}
+        self._lock = threading.Lock()
+        self._retry_delays = state["_retry_delays"]
+        self._health_retry_delays = state["_health_retry_delays"]
+        self._request_timeout = state["_request_timeout"]
+        self._connection_timeout = state["_connection_timeout"]
+        self._dequeue_was_successful = False
+        self._dequeue_first_unsuccessful = True
+
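Because `__getstate__` drops the live sessions and the lock, an instance can cross a process boundary via pickle. A minimal sketch of that round trip (the address is illustrative):

```python
import pickle

client = LightningStoreClient("http://127.0.0.1:8000")
blob = pickle.dumps(client)    # sessions and the threading lock are not serialized
clone = pickle.loads(blob)     # fresh client with an empty per-loop session cache
assert clone.server_address == client.server_address
```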
+    async def _get_session(self) -> aiohttp.ClientSession:
+        # In the proxy process, FastAPI middleware calls
+        # client_store.get_next_span_sequence_id(...). With
+        # reuse_session=True, _get_session() creates and caches a
+        # single ClientSession bound to the uvicorn event loop.
+        #
+        # Later, the OpenTelemetry exporter (LightningSpanExporter)
+        # runs its flush on its own private event loop (in a different
+        # thread) and calls client_store.add_otel_span(...) ->
+        # client_store.add_span(...).
+        #
+        # If we reused one session everywhere, the exporter would try to use
+        # the cached ClientSession created on the uvicorn loop.
+        # aiohttp.ClientSession is neither loop-agnostic nor thread-safe;
+        # using it from another loop can hang on the first request. That is
+        # why we keep a map from event loop to session.
+
+        loop = asyncio.get_running_loop()
+        key = id(loop)
+        with self._lock:
+            sess = self._sessions.get(key)
+            if sess is None or sess.closed:
+                timeout = aiohttp.ClientTimeout(
+                    total=self._request_timeout,
+                    connect=self._connection_timeout,
+                    sock_connect=self._connection_timeout,
+                    sock_read=self._request_timeout,
+                )
+                sess = aiohttp.ClientSession(timeout=timeout)
+                self._sessions[key] = sess
+        return sess
+
+    async def _wait_until_healthy(self, session: aiohttp.ClientSession) -> bool:
+        """
+        Probe the server's /health until it responds 200 or retries are exhausted.
+        Returns True if healthy, False otherwise.
+        """
+        if not self._health_retry_delays:
+            client_logger.info("No health retry delays configured; skipping health checks.")
+            return True
+
+        client_logger.info(f"Waiting for server to be healthy at {self.server_address}/health")
+        for delay in [*self._health_retry_delays, 0.0]:
+            try:
+                async with session.get(f"{self.server_address}/health") as r:
+                    if r.status == 200:
+                        client_logger.info(f"Server is healthy at {self.server_address}/health")
+                        return True
+            except Exception:
+                # swallow and retry
+                if delay > 0.0:
+                    client_logger.warning(f"Server is not healthy yet. Retrying in {delay} seconds.")
+            if delay > 0.0:
+                await asyncio.sleep(delay)
+        client_logger.error(
+            f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts"
+        )
+        return False
+
+    async def _request_json(
+        self,
+        method: Literal["get", "post"],
+        path: str,
+        *,
+        json: Any | None = None,
+        params: Mapping[str, Any] | Sequence[Tuple[str, Any]] | None = None,
+    ) -> Any:
+        """
+        Make an HTTP request with:
+
+        1) A first, immediate attempt.
+        2) On network/session failures: probe /health until the server is back,
+           then retry according to self._retry_delays.
+        3) On 4xx (e.g., 400 set by the server's exception handler): do not retry.
+
+        Returns parsed JSON (or a raw JSON scalar such as an int).
+        Raises the last exception if all retries fail.
+        """
+        session = await self._get_session()
+        url = f"{self.server_address}{path if path.startswith('/') else '/' + path}"
+
+        # Attempt 0 is immediate; subsequent attempts follow the retry schedule.
+        attempts = (0.0,) + self._retry_delays
+        last_exc: Exception | None = None
+
+        for delay in attempts:
+            if delay:
+                client_logger.info(f"Waiting {delay} seconds before retrying {method}: {path}")
+                await asyncio.sleep(delay)
+            try:
+                http_call = getattr(session, method)
+                async with http_call(url, json=json, params=params) as resp:
+                    resp.raise_for_status()
+                    return await resp.json()
+            except aiohttp.ClientResponseError as cre:
+                # 4xx is an application-level error and is final; do not retry
+                # (except 408, which is transient).
+                client_logger.debug(f"ClientResponseError ({method} {path}): {cre.status} {cre.message}", exc_info=True)
+                if 400 <= cre.status < 500 and cre.status != 408:
+                    raise
+                # 5xx and other statuses are retried below.
+                last_exc = cre
+                client_logger.info(f"Status {cre.status} is retryable; retrying the request {method}: {path}")
+                # Before the next retry, ensure the server is healthy.
+                if not await self._wait_until_healthy(session):
+                    break  # server is not healthy, do not retry
+            except (
+                aiohttp.ServerDisconnectedError,
+                aiohttp.ClientConnectorError,
+                aiohttp.ClientOSError,
+                asyncio.TimeoutError,
+            ) as net_exc:
+                # Network/session issue: probe health before retrying.
+                client_logger.debug(f"Network/session issue ({method} {path}): {net_exc}", exc_info=True)
+                last_exc = net_exc
+                client_logger.info(f"Network/session issue: {net_exc} - will retry the request {method}: {path}")
+                if not await self._wait_until_healthy(session):
+                    break  # server is not healthy, do not retry
+
+        # Exhausted retries.
+        assert last_exc is not None
+        raise last_exc
+
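The retry arithmetic follows from the schedule: the prepended `0.0` makes the first attempt immediate, so a schedule of length n yields at most n + 1 requests. A small worked check for the default schedule (health-probe delays are extra and not counted here):

```python
retry_delays = (1.0, 2.0, 5.0)      # default schedule
attempts = (0.0,) + retry_delays    # as built inside _request_json
assert len(attempts) == 4           # 1 immediate try + 3 retries
assert sum(attempts) == 8.0         # worst-case backoff sleep before giving up
```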
+    async def close(self):
+        """Close all cached HTTP sessions."""
+        with self._lock:
+            sessions = list(self._sessions.values())
+            self._sessions.clear()
+
+        # Close sessions on their own loops where possible to avoid warnings.
+        async def _close(sess: aiohttp.ClientSession):
+            if not sess.closed:
+                await sess.close()
+
+        # Best-effort close from the current loop.
+        for s in sessions:
+            try:
+                await _close(s)
+            except RuntimeError:
+                # The session was created on a different loop/thread. A full fix
+                # would schedule a thread-safe close on that loop, or keep a
+                # per-loop shutdown hook where each session was created.
+                pass
+
+    async def start_rollout(
+        self,
+        input: TaskInput,
+        mode: Literal["train", "val", "test"] | None = None,
+        resources_id: str | None = None,
+        config: RolloutConfig | None = None,
+        metadata: Dict[str, Any] | None = None,
+        worker_id: Optional[str] = None,
+    ) -> AttemptedRollout:
+        data = await self._request_json(
+            "post",
+            "/rollouts",
+            json=RolloutRequest(
+                input=input,
+                mode=mode,
+                resources_id=resources_id,
+                config=config,
+                metadata=metadata,
+                worker_id=worker_id,
+            ).model_dump(exclude_none=False),
+        )
+        return AttemptedRollout.model_validate(data)
+
+    async def enqueue_rollout(
+        self,
+        input: TaskInput,
+        mode: Literal["train", "val", "test"] | None = None,
+        resources_id: str | None = None,
+        config: RolloutConfig | None = None,
+        metadata: Dict[str, Any] | None = None,
+    ) -> Rollout:
+        request_body = EnqueueManyRolloutsRequest(
+            rollouts=[
+                EnqueueRolloutRequest(
+                    input=input,
+                    mode=mode,
+                    resources_id=resources_id,
+                    config=config,
+                    metadata=metadata,
+                )
+            ]
+        ).model_dump(exclude_none=False)
+        data = await self._request_json(
+            "post",
+            "/queues/rollouts/enqueue",
+            json=request_body,
+        )
+        if not data:
+            raise RuntimeError("enqueue_rollout returned no rollouts")
+        return Rollout.model_validate(data[0])
+
+    async def enqueue_many_rollouts(self, rollouts: Sequence[EnqueueRolloutRequest]) -> Sequence[Rollout]:
+        if not rollouts:
+            return []
+        request_body = EnqueueManyRolloutsRequest(rollouts=list(rollouts)).model_dump(exclude_none=False)
+        data = await self._request_json(
+            "post",
+            "/queues/rollouts/enqueue",
+            json=request_body,
+        )
+        return [Rollout.model_validate(entry) for entry in data]
+
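A usage sketch, assuming `TaskInput` accepts a plain JSON-serializable dict and that `Rollout` exposes a `rollout_id` field; both are assumptions for illustration:

```python
rollout = await client.enqueue_rollout(
    {"prompt": "What is 2 + 2?"},   # illustrative task input
    mode="train",
    metadata={"source": "demo"},
)
print(rollout.rollout_id)
```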
+    async def _dequeue_batch(
+        self,
+        *,
+        limit: int,
+        worker_id: Optional[str],
+    ) -> List[AttemptedRollout]:
+        if limit <= 0:
+            return []
+        session = await self._get_session()
+        url = f"{self.server_address}/queues/rollouts/dequeue"
+        payload: Dict[str, Any] = {"limit": limit}
+        if worker_id is not None:
+            payload["worker_id"] = worker_id
+        try:
+            async with session.post(url, json=payload) as resp:
+                resp.raise_for_status()
+                data = await resp.json()
+                self._dequeue_was_successful = True
+                return [AttemptedRollout.model_validate(item) for item in data]
+        except Exception as e:
+            if self._dequeue_was_successful:
+                if self._dequeue_first_unsuccessful:
+                    client_logger.warning(f"dequeue_rollout failed with exception: {e}")
+                    self._dequeue_first_unsuccessful = False
+                client_logger.debug("dequeue_rollout failed with exception. Details:", exc_info=True)
+            # Otherwise ignore the exception: the server is likely not ready yet.
+            return []
+
+    async def dequeue_rollout(self, worker_id: Optional[str] = None) -> Optional[AttemptedRollout]:
+        """
+        Dequeue a rollout from the server queue.
+
+        Returns:
+            AttemptedRollout if a rollout is available, None if the queue is empty.
+
+        Note:
+            This method does NOT retry on failures. If any exception occurs (network error,
+            server error, etc.), it logs the error and returns None immediately.
+        """
+        attempts = await self._dequeue_batch(limit=1, worker_id=worker_id)
+        return attempts[0] if attempts else None
+
+    async def dequeue_many_rollouts(
+        self,
+        *,
+        limit: int = 1,
+        worker_id: Optional[str] = None,
+    ) -> Sequence[AttemptedRollout]:
+        return await self._dequeue_batch(limit=limit, worker_id=worker_id)
+
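Since dequeue calls deliberately return an empty result instead of retrying, a long-lived worker typically wraps them in its own polling loop. A minimal sketch; the batch size and back-off interval are arbitrary choices:

```python
import asyncio

async def poll_forever(client: "LightningStoreClient", worker_id: str) -> None:
    """Keep pulling work; an empty batch means the queue is empty *or* the request failed."""
    while True:
        batch = await client.dequeue_many_rollouts(limit=8, worker_id=worker_id)
        if not batch:
            await asyncio.sleep(0.5)  # back off instead of hammering the server
            continue
        for rollout in batch:
            ...  # application-specific handling
```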
+    async def start_attempt(self, rollout_id: str, worker_id: Optional[str] = None) -> AttemptedRollout:
+        payload = {"worker_id": worker_id} if worker_id is not None else None
+        data = await self._request_json(
+            "post",
+            f"/rollouts/{rollout_id}/attempts",
+            json=payload,
+        )
+        return AttemptedRollout.model_validate(data)
+
+    async def query_rollouts(
+        self,
+        *,
+        status_in: Optional[Sequence[RolloutStatus]] = None,
+        rollout_id_in: Optional[Sequence[str]] = None,
+        rollout_id_contains: Optional[str] = None,
+        filter_logic: Literal["and", "or"] = "and",
+        sort_by: Optional[str] = None,
+        sort_order: Literal["asc", "desc"] = "asc",
+        limit: int = -1,
+        offset: int = 0,
+        status: Optional[Sequence[RolloutStatus]] = None,
+        rollout_ids: Optional[Sequence[str]] = None,
+    ) -> PaginatedResult[Union[AttemptedRollout, Rollout]]:
+        resolved_status = status_in if status_in is not None else status
+        resolved_rollout_ids = rollout_id_in if rollout_id_in is not None else rollout_ids
+
+        payload: Dict[str, Any] = {
+            "limit": limit,
+            "offset": offset,
+        }
+        if resolved_status is not None:
+            payload["status_in"] = resolved_status
+        if resolved_rollout_ids is not None:
+            payload["rollout_id_in"] = resolved_rollout_ids
+        if rollout_id_contains is not None:
+            payload["rollout_id_contains"] = rollout_id_contains
+        payload["filter_logic"] = filter_logic
+        if sort_by is not None:
+            payload["sort_by"] = sort_by
+        payload["sort_order"] = sort_order
+
+        data = await self._request_json("post", "/rollouts/search", json=payload)
+        items = [
+            (
+                AttemptedRollout.model_validate(item)
+                if isinstance(item, dict) and "attempt" in item
+                else Rollout.model_validate(item)
+            )
+            for item in data["items"]
+        ]
+        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])
+
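Because the search endpoints return a `PaginatedResult` with `items` and `total`, a full listing is a simple offset walk. A minimal sketch:

```python
async def iter_all_rollouts(client: "LightningStoreClient"):
    """Yield every rollout by walking the paginated search endpoint."""
    offset = 0
    while True:
        page = await client.query_rollouts(limit=100, offset=offset)
        for rollout in page.items:
            yield rollout
        offset += len(page.items)
        # Stop on an empty page or once the reported total is covered.
        if not page.items or offset >= page.total:
            break
```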
+    async def query_attempts(
+        self,
+        rollout_id: str,
+        *,
+        sort_by: Optional[str] = "sequence_id",
+        sort_order: Literal["asc", "desc"] = "asc",
+        limit: int = -1,
+        offset: int = 0,
+    ) -> PaginatedResult[Attempt]:
+        payload: Dict[str, Any] = {
+            "limit": limit,
+            "offset": offset,
+        }
+        if sort_by is not None:
+            payload["sort_by"] = sort_by
+        payload["sort_order"] = sort_order
+        data = await self._request_json("post", f"/rollouts/{rollout_id}/attempts/search", json=payload)
+        items = [Attempt.model_validate(item) for item in data["items"]]
+        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])
+
+    async def get_latest_attempt(self, rollout_id: str) -> Optional[Attempt]:
+        """
+        Get the latest attempt for a rollout.
+
+        Args:
+            rollout_id: ID of the rollout to query.
+
+        Returns:
+            Attempt if found, None if not found or if all retries are exhausted.
+
+        Note:
+            This method retries on transient failures (network errors, 5xx status codes).
+            If all retries fail, it logs the error and returns None instead of raising an exception.
+        """
+        try:
+            data = await self._request_json("get", f"/rollouts/{rollout_id}/attempts/latest")
+            return Attempt.model_validate(data) if data else None
+        except Exception as e:
+            client_logger.error(
+                f"get_latest_attempt failed after all retries for rollout_id={rollout_id}: {e}", exc_info=True
+            )
+            return None
+
+    async def get_rollout_by_id(self, rollout_id: str) -> Optional[Rollout]:
+        """
+        Get a rollout by its ID.
+
+        Args:
+            rollout_id: ID of the rollout to retrieve.
+
+        Returns:
+            Rollout if found, None if not found or if all retries are exhausted.
+
+        Note:
+            This method retries on transient failures (network errors, 5xx status codes).
+            If all retries fail, it logs the error and returns None instead of raising an exception.
+        """
+        try:
+            data = await self._request_json("get", f"/rollouts/{rollout_id}")
+            if data is None:
+                return None
+            elif isinstance(data, dict) and "attempt" in data:
+                return AttemptedRollout.model_validate(data)
+            else:
+                return Rollout.model_validate(data)
+        except Exception as e:
+            client_logger.error(
+                f"get_rollout_by_id failed after all retries for rollout_id={rollout_id}: {e}", exc_info=True
+            )
+            return None
+
+    async def query_resources(
+        self,
+        *,
+        resources_id: Optional[str] = None,
+        resources_id_contains: Optional[str] = None,
+        sort_by: Optional[str] = None,
+        sort_order: Literal["asc", "desc"] = "asc",
+        limit: int = -1,
+        offset: int = 0,
+    ) -> PaginatedResult[ResourcesUpdate]:
+        """
+        List all resource snapshots stored on the server.
+        """
+        params: List[Tuple[str, Any]] = [
+            ("limit", limit),
+            ("offset", offset),
+        ]
+        if sort_by is not None:
+            params.append(("sort_by", sort_by))
+        params.append(("sort_order", sort_order))
+        if resources_id is not None:
+            params.append(("resources_id", resources_id))
+        if resources_id_contains is not None:
+            params.append(("resources_id_contains", resources_id_contains))
+
+        data = await self._request_json("get", "/resources", params=params)
+        items = [ResourcesUpdate.model_validate(item) for item in data["items"]]
+        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])
+
+    async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
+        data = await self._request_json("post", "/resources", json=TypeAdapter(NamedResources).dump_python(resources))
+        return ResourcesUpdate.model_validate(data)
+
+    async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate:
+        data = await self._request_json(
+            "post", f"/resources/{resources_id}", json=TypeAdapter(NamedResources).dump_python(resources)
+        )
+        return ResourcesUpdate.model_validate(data)
+
+    async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]:
+        """
+        Get resources by their ID.
+
+        Args:
+            resources_id: ID of the resources to retrieve.
+
+        Returns:
+            ResourcesUpdate if found, None if not found or if all retries are exhausted.
+
+        Note:
+            This method retries on transient failures (network errors, 5xx status codes).
+            If all retries fail, it logs the error and returns None instead of raising an exception.
+        """
+        try:
+            data = await self._request_json("get", f"/resources/{resources_id}")
+            return ResourcesUpdate.model_validate(data) if data else None
+        except Exception as e:
+            client_logger.error(
+                f"get_resources_by_id failed after all retries for resources_id={resources_id}: {e}", exc_info=True
+            )
+            return None
+
+    async def get_latest_resources(self) -> Optional[ResourcesUpdate]:
+        """
+        Get the latest resources.
+
+        Returns:
+            ResourcesUpdate if found, None if not found or if all retries are exhausted.
+
+        Note:
+            This method retries on transient failures (network errors, 5xx status codes).
+            If all retries fail, it logs the error and returns None instead of raising an exception.
+        """
+        try:
+            data = await self._request_json("get", "/resources/latest")
+            return ResourcesUpdate.model_validate(data) if data else None
+        except Exception as e:
+            client_logger.error(f"get_latest_resources failed after all retries: {e}", exc_info=True)
+            return None
+
+    async def add_span(self, span: Span) -> Optional[Span]:
+        data = await self._request_json("post", "/spans", json=span.model_dump(mode="json"))
+        return Span.model_validate(data) if data is not None else None
+
+    async def add_many_spans(self, spans: Sequence[Span]) -> Sequence[Span]:
+        result: List[Span] = []
+        for span in spans:
+            ret = await self.add_span(span)
+            if ret is not None:
+                result.append(ret)
+        return result
+
+    async def get_next_span_sequence_id(self, rollout_id: str, attempt_id: str) -> int:
+        data = await self._request_json(
+            "post",
+            "/spans/next",
+            json=NextSequenceIdRequest(rollout_id=rollout_id, attempt_id=attempt_id).model_dump(),
+        )
+        response = NextSequenceIdResponse.model_validate(data)
+        return response.sequence_id
+
+    async def get_many_span_sequence_ids(self, rollout_attempt_ids: Sequence[Tuple[str, str]]) -> Sequence[int]:
+        return [
+            await self.get_next_span_sequence_id(rollout_id, attempt_id)
+            for rollout_id, attempt_id in rollout_attempt_ids
+        ]
+
+    async def add_otel_span(
+        self,
+        rollout_id: str,
+        attempt_id: str,
+        readable_span: ReadableSpan,
+        sequence_id: int | None = None,
+    ) -> Optional[Span]:
+        # Unchanged logic; it now benefits from the retries inside
+        # add_span/get_next_span_sequence_id.
+        if sequence_id is None:
+            sequence_id = await self.get_next_span_sequence_id(rollout_id, attempt_id)
+        span = Span.from_opentelemetry(
+            readable_span,
+            rollout_id=rollout_id,
+            attempt_id=attempt_id,
+            sequence_id=sequence_id,
+        )
+        return await self.add_span(span)
+
+    async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
+        """Wait for rollouts to complete.
+
+        Args:
+            rollout_ids: List of rollout IDs to wait for.
+            timeout: Timeout in seconds. Values greater than 0.1 raise a ValueError;
+                the cap keeps a single HTTP request from blocking the event loop.
+
+        Returns:
+            List of the rollouts that have completed.
+        """
+        if timeout is not None and timeout > 0.1:
+            raise ValueError(
+                "Timeout must be at most 0.1 seconds in LightningStoreClient to avoid blocking the event loop"
+            )
+        data = await self._request_json(
+            "post",
+            "/waits/rollouts",
+            json=WaitForRolloutsRequest(rollout_ids=rollout_ids, timeout=timeout).model_dump(),
+        )
+        return [Rollout.model_validate(item) for item in data]
+
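Given the 0.1 s server-side cap, longer waits are driven from the client by polling. A minimal sketch, assuming each call returns the subset of requested rollouts that have completed so far (as the docstring suggests); the poll interval and deadline are arbitrary:

```python
import asyncio
from typing import List

async def wait_until_done(
    client: "LightningStoreClient", rollout_ids: List[str], deadline_s: float = 60.0
):
    """Poll /waits/rollouts until all requested rollouts complete or the deadline passes."""
    elapsed = 0.0
    done: list = []
    while elapsed < deadline_s:
        done = await client.wait_for_rollouts(rollout_ids=rollout_ids, timeout=0.1)
        if len(done) == len(rollout_ids):
            return done
        await asyncio.sleep(1.0)
        elapsed += 1.1  # approximate: 0.1 s server wait + 1.0 s sleep per cycle
    return done
```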
+    async def query_spans(
+        self,
+        rollout_id: str,
+        attempt_id: str | Literal["latest"] | None = None,
+        *,
+        trace_id: Optional[str] = None,
+        trace_id_contains: Optional[str] = None,
+        span_id: Optional[str] = None,
+        span_id_contains: Optional[str] = None,
+        parent_id: Optional[str] = None,
+        parent_id_contains: Optional[str] = None,
+        name: Optional[str] = None,
+        name_contains: Optional[str] = None,
+        filter_logic: Literal["and", "or"] = "and",
+        limit: int = -1,
+        offset: int = 0,
+        sort_by: Optional[str] = "sequence_id",
+        sort_order: Literal["asc", "desc"] = "asc",
+    ) -> PaginatedResult[Span]:
+        payload: Dict[str, Any] = {"rollout_id": rollout_id, "limit": limit, "offset": offset}
+        if attempt_id is not None:
+            payload["attempt_id"] = attempt_id
+        if trace_id is not None:
+            payload["trace_id"] = trace_id
+        if trace_id_contains is not None:
+            payload["trace_id_contains"] = trace_id_contains
+        if span_id is not None:
+            payload["span_id"] = span_id
+        if span_id_contains is not None:
+            payload["span_id_contains"] = span_id_contains
+        if parent_id is not None:
+            payload["parent_id"] = parent_id
+        if parent_id_contains is not None:
+            payload["parent_id_contains"] = parent_id_contains
+        if name is not None:
+            payload["name"] = name
+        if name_contains is not None:
+            payload["name_contains"] = name_contains
+        payload["filter_logic"] = filter_logic
+        if sort_by is not None:
+            payload["sort_by"] = sort_by
+        payload["sort_order"] = sort_order
+        data = await self._request_json("post", "/spans/search", json=payload)
+        items = [Span.model_validate(item) for item in data["items"]]
+        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])
+
+    async def update_rollout(
+        self,
+        rollout_id: str,
+        input: TaskInput | Unset = UNSET,
+        mode: Optional[Literal["train", "val", "test"]] | Unset = UNSET,
+        resources_id: Optional[str] | Unset = UNSET,
+        status: RolloutStatus | Unset = UNSET,
+        config: RolloutConfig | Unset = UNSET,
+        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
+    ) -> Rollout:
+        payload: Dict[str, Any] = {}
+        if not isinstance(input, Unset):
+            payload["input"] = input
+        if not isinstance(mode, Unset):
+            payload["mode"] = mode
+        if not isinstance(resources_id, Unset):
+            payload["resources_id"] = resources_id
+        if not isinstance(status, Unset):
+            payload["status"] = status
+        if not isinstance(config, Unset):
+            payload["config"] = config.model_dump()
+        if not isinstance(metadata, Unset):
+            payload["metadata"] = metadata
+
+        data = await self._request_json("post", f"/rollouts/{rollout_id}", json=payload)
+        return Rollout.model_validate(data)
+
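The `UNSET` sentinel lets callers distinguish "leave this field unchanged" from "set it to None": only explicitly passed arguments enter the payload. A usage sketch; the status literal is an assumption about `RolloutStatus`, not a documented value:

```python
# Clears metadata (explicit None) but leaves status, mode, etc. untouched.
await client.update_rollout("rollout-123", metadata=None)

# Touches only the status field.
await client.update_rollout("rollout-123", status="failed")  # assumed status literal
```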
+    async def update_attempt(
+        self,
+        rollout_id: str,
+        attempt_id: str | Literal["latest"],
+        status: AttemptStatus | Unset = UNSET,
+        worker_id: str | Unset = UNSET,
+        last_heartbeat_time: float | Unset = UNSET,
+        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
+    ) -> Attempt:
+        payload: Dict[str, Any] = {}
+        if not isinstance(status, Unset):
+            payload["status"] = status
+        if not isinstance(worker_id, Unset):
+            payload["worker_id"] = worker_id
+        if not isinstance(last_heartbeat_time, Unset):
+            payload["last_heartbeat_time"] = last_heartbeat_time
+        if not isinstance(metadata, Unset):
+            payload["metadata"] = metadata
+
+        data = await self._request_json(
+            "post",
+            f"/rollouts/{rollout_id}/attempts/{attempt_id}",
+            json=payload,
+        )
+        return Attempt.model_validate(data)
+
+    async def query_workers(
+        self,
+        *,
+        status_in: Optional[Sequence[WorkerStatus]] = None,
+        worker_id_contains: Optional[str] = None,
+        filter_logic: Literal["and", "or"] = "and",
+        sort_by: Optional[str] = None,
+        sort_order: Literal["asc", "desc"] = "asc",
+        limit: int = -1,
+        offset: int = 0,
+    ) -> PaginatedResult[Worker]:
+        payload: Dict[str, Any] = {"limit": limit, "offset": offset}
+        if status_in is not None:
+            payload["status_in"] = status_in
+        if worker_id_contains is not None:
+            payload["worker_id_contains"] = worker_id_contains
+        payload["filter_logic"] = filter_logic
+        if sort_by is not None:
+            payload["sort_by"] = sort_by
+        payload["sort_order"] = sort_order
+
+        data = await self._request_json("post", "/workers/search", json=payload)
+        items = [Worker.model_validate(item) for item in data.get("items", [])]
+        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])
+
+    async def get_worker_by_id(self, worker_id: str) -> Optional[Worker]:
+        data = await self._request_json("get", f"/workers/{worker_id}")
+        if data is None:
+            return None
+        return Worker.model_validate(data)
+
+    async def update_worker(
+        self,
+        worker_id: str,
+        heartbeat_stats: Dict[str, Any] | Unset = UNSET,
+    ) -> Worker:
+        payload: Dict[str, Any] = {}
+        if not isinstance(heartbeat_stats, Unset):
+            payload["heartbeat_stats"] = heartbeat_stats
+        json_payload = payload if payload else None
+
+        data = await self._request_json("post", f"/workers/{worker_id}", json=json_payload)
+        return Worker.model_validate(data)