mantisdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/store/client_server.py

@@ -0,0 +1,2092 @@

```python
# Copyright (c) Microsoft. All rights reserved.

from __future__ import annotations

import asyncio
import logging
import os
import re
import threading
import time
import traceback
from pathlib import Path
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import aiohttp
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException
from fastapi import Query as FastAPIQuery
from fastapi import Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceRequest as PbExportTraceServiceRequest,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse as PbExportTraceServiceResponse,
)
from opentelemetry.sdk.trace import ReadableSpan
from pydantic import BaseModel, Field, TypeAdapter

from mantisdk.types import (
    Attempt,
    AttemptedRollout,
    AttemptStatus,
    EnqueueRolloutRequest,
    NamedResources,
    PaginatedResult,
    ResourcesUpdate,
    Rollout,
    RolloutConfig,
    RolloutStatus,
    Span,
    TaskInput,
    Worker,
    WorkerStatus,
)
from mantisdk.utils.metrics import MetricsBackend, get_prometheus_registry
from mantisdk.utils.otlp import handle_otlp_export, spans_from_proto
from mantisdk.utils.server_launcher import LaunchMode, PythonServerLauncher, PythonServerLauncherArgs

from .base import UNSET, LightningStore, LightningStoreCapabilities, LightningStoreStatistics, Unset
from .collection.base import resolve_error_type
from .utils import LATENCY_BUCKETS

server_logger = logging.getLogger("mantisdk.store.server")
client_logger = logging.getLogger("mantisdk.store.client")

API_V1_PREFIX = "/v1"
API_AGL_PREFIX = "/msk"
API_V1_AGL_PREFIX = API_V1_PREFIX + API_AGL_PREFIX

T = TypeVar("T")
T_model = TypeVar("T_model", bound=BaseModel)


class RolloutRequest(BaseModel):
    input: TaskInput
    mode: Optional[Literal["train", "val", "test"]] = None
    resources_id: Optional[str] = None
    config: Optional[RolloutConfig] = None
    metadata: Optional[Dict[str, Any]] = None
    worker_id: Optional[str] = None


class DequeueRolloutRequest(BaseModel):
    worker_id: Optional[str] = None


class StartAttemptRequest(BaseModel):
    worker_id: Optional[str] = None


class EnqueueManyRolloutsRequest(BaseModel):
    rollouts: List[EnqueueRolloutRequest]


class DequeueManyRolloutsRequest(BaseModel):
    limit: int = 1
    worker_id: Optional[str] = None


class QueryRolloutsRequest(BaseModel):
    status_in: Optional[List[RolloutStatus]] = Field(FastAPIQuery(default=None))
    rollout_id_in: Optional[List[str]] = Field(FastAPIQuery(default=None))
    rollout_id_contains: Optional[str] = None
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = None
    sort_order: Literal["asc", "desc"] = "asc"
    # Filtering logic
    filter_logic: Literal["and", "or"] = "and"


class WaitForRolloutsRequest(BaseModel):
    rollout_ids: List[str]
    timeout: Optional[float] = None


class NextSequenceIdRequest(BaseModel):
    rollout_id: str
    attempt_id: str


class NextSequenceIdResponse(BaseModel):
    sequence_id: int


class UpdateRolloutRequest(BaseModel):
    input: Optional[TaskInput] = None
    mode: Optional[Literal["train", "val", "test"]] = None
    resources_id: Optional[str] = None
    status: Optional[RolloutStatus] = None
    config: Optional[RolloutConfig] = None
    metadata: Optional[Dict[str, Any]] = None


class UpdateAttemptRequest(BaseModel):
    status: Optional[AttemptStatus] = None
    worker_id: Optional[str] = None
    last_heartbeat_time: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None


class UpdateWorkerRequest(BaseModel):
    heartbeat_stats: Optional[Dict[str, Any]] = None


class QueryAttemptsRequest(BaseModel):
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = "sequence_id"
    sort_order: Literal["asc", "desc"] = "asc"


class QueryResourcesRequest(BaseModel):
    # Filtering
    resources_id: Optional[str] = None
    resources_id_contains: Optional[str] = None
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = None
    sort_order: Literal["asc", "desc"] = "asc"


class QuerySpansRequest(BaseModel):
    rollout_id: str
    attempt_id: Optional[str] = None
    # Filtering
    trace_id: Optional[str] = None
    trace_id_contains: Optional[str] = None
    span_id: Optional[str] = None
    span_id_contains: Optional[str] = None
    parent_id: Optional[str] = None
    parent_id_contains: Optional[str] = None
    name: Optional[str] = None
    name_contains: Optional[str] = None
    filter_logic: Literal["and", "or"] = "and"
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = "sequence_id"
    sort_order: Literal["asc", "desc"] = "asc"


class QueryWorkersRequest(BaseModel):
    status_in: Optional[List[WorkerStatus]] = Field(FastAPIQuery(default=None))
    worker_id_contains: Optional[str] = None
    # Pagination
    limit: int = -1
    offset: int = 0
    # Sorting
    sort_by: Optional[str] = None
    sort_order: Literal["asc", "desc"] = "asc"
    # Filtering logic
    filter_logic: Literal["and", "or"] = "and"
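
# --- Illustrative note (not part of the released file) ---
# The Query*Request models above double as GET query parameters through FastAPI's
# Depends(); wrapping FastAPIQuery(default=None) in pydantic's Field lets
# list-valued filters such as status_in arrive as repeated query parameters:
#
#     GET /v1/msk/rollouts?status_in=running&status_in=failed&limit=10
#
# (the status values here are hypothetical). A client-side sketch with aiohttp:
#
#     params = [("status_in", "running"), ("status_in", "failed"), ("limit", "10")]
#     async with aiohttp.ClientSession() as session:
#         async with session.get(f"{endpoint}/v1/msk/rollouts", params=params) as resp:
#             page = await resp.json()  # serialized PaginatedResult
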
class CachedStaticFiles(StaticFiles):
    def file_response(self, *args: Any, **kwargs: Any) -> Response:
        resp = super().file_response(*args, **kwargs)
        # hashed filenames are safe to cache "forever"
        resp.headers.setdefault("Cache-Control", "public, max-age=31536000, immutable")
        return resp


class LightningStoreServer(LightningStore):
    """
    Server wrapper that exposes a LightningStore via HTTP API.
    Delegates all operations to an underlying store implementation.

    Healthcheck and watchdog rely on the underlying store.

    `msk store` is a convenient CLI to start a store server.

    When the server is executed in a subprocess, the store will discover that it has a different PID
    and automatically delegate to an HTTP client instead of using the local store.
    This ensures a single copy of the store is shared across all processes.

    This server exports OTLP-compatible traces via the `/v1/traces` endpoint.

    Args:
        store: The underlying store to delegate operations to.
        host: The hostname or IP address to bind the server to.
        port: The TCP port to listen on.
        cors_allow_origins: A list of CORS origins to allow. Use '*' to allow all origins.
        launch_mode: The launch mode to use for the server. Defaults to "thread",
            which runs the server in a separate thread.
        launcher_args: The arguments to use for the server launcher.
            It's not allowed to set `host`, `port`, or `launch_mode` together with `launcher_args`.
        n_workers: The number of workers to run in the server. Only applicable for `mp` launch mode.
        tracker: The metrics tracker to use for the server.
    """

    def __init__(
        self,
        store: LightningStore,
        host: str | None = None,
        port: int | None = None,
        cors_allow_origins: Sequence[str] | str | None = None,
        launch_mode: LaunchMode = "thread",
        launcher_args: PythonServerLauncherArgs | None = None,
        n_workers: int = 1,
        tracker: MetricsBackend | None = None,
    ):
        super().__init__()
        self.store = store

        if launcher_args is not None:
            if host is not None or port is not None or launch_mode != "thread":
                raise ValueError("host, port, and launch_mode cannot be set when launcher_args is provided.")
            self.launcher_args = launcher_args
        else:
            if port is None:
                server_logger.warning("No port provided, using default port 4747.")
                port = 4747
            self.launcher_args = PythonServerLauncherArgs(
                host=host,
                port=port,
                launch_mode=launch_mode,
                healthcheck_url=API_V1_AGL_PREFIX + "/health",
                n_workers=n_workers,
            )

        store_capabilities = self.store.capabilities
        if not store_capabilities.get("async_safe", False):
            raise ValueError("The store is not async-safe. Please use another store for the server.")
        if self.launcher_args.launch_mode == "mp" and not store_capabilities.get("zero_copy", False):
            raise ValueError(
                "The store does not support zero-copy. Please use another store, or use asyncio or thread mode to launch the server."
            )
        if self.launcher_args.launch_mode == "thread" and not store_capabilities.get("thread_safe", False):
            server_logger.warning(
                "The store is not thread-safe. Please be careful when using the store server and the underlying store in different threads."
            )

        self.app: FastAPI | None = FastAPI(title="LightningStore Server")
        self.server_launcher = PythonServerLauncher(
            app=self.app,
            args=self.launcher_args,
        )
        self._tracker = tracker

        self._lock: threading.Lock = threading.Lock()
        self._cors_allow_origins = self._normalize_cors_origins(cors_allow_origins)
        self._apply_cors()
        self._setup_routes()

        # Process-awareness:
        # LightningStoreServer holds a plain Python object (self.store) in one process
        # (the process that runs uvicorn/FastAPI).
        # When you multiprocessing.Process(...) and call methods on a different LightningStore instance
        # (or on a copy inherited via fork), you're mutating another process's memory, not the server's memory.
        # So we need to track the owner process (whoever creates the server),
        # and only mutate the store in that process.
        self._owner_pid = os.getpid()
        self._client: Optional[LightningStoreClient] = None

    @property
    def capabilities(self) -> LightningStoreCapabilities:
        """Return the capabilities of the store."""
        return LightningStoreCapabilities(
            async_safe=True,
            thread_safe=True,
            zero_copy=True,
            otlp_traces=True,
        )

    def otlp_traces_endpoint(self) -> str:
        """Return the OTLP/HTTP traces endpoint of the store."""
        return f"{self.endpoint}/v1/traces"

    def __getstate__(self):
        """
        Control pickling to prevent server state from being sent to subprocesses.

        When LightningStoreServer is pickled (e.g., passed to a subprocess), we only
        serialize the underlying store and connection details. The client instance
        and process-awareness state are excluded as they should not be transferred between processes.

        The subprocess should create its own server instance if needed.
        """
        # server_launcher is needed so the host/port address is propagated to the subprocess
        return {
            "launcher_args": self.launcher_args,
            "server_launcher": self.server_launcher,
            "_owner_pid": self._owner_pid,
        }

    def __setstate__(self, state: Dict[str, Any]):
        """
        Restore from pickle by reconstructing only the essential attributes.

        Note: This creates a new server instance without FastAPI/uvicorn initialized.
        Follow the __init__() pattern or create a new LightningStoreServer if you need
        a fully functional server in the subprocess.
        The unpickled server will also have no app and store attributes;
        this makes sure there is only one copy of the server in the whole system.
        """
        self.app = None
        self.store = None
        self.launcher_args = state["launcher_args"]
        self.server_launcher = state["server_launcher"]
        self._tracker = None
        self._owner_pid = state["_owner_pid"]
        self._cors_allow_origins = state.get("_cors_allow_origins")
        self._client = None
        self._lock = threading.Lock()
        self._prometheus_registry = None
        # Do NOT reconstruct app, _uvicorn_config, _uvicorn_server
        # to avoid transferring server state to subprocess
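
    # --- Illustrative note (not part of the released file) ---
    # A consequence of __getstate__/__setstate__ above: a pickled-and-restored
    # server is a stub that keeps only the launcher args and the owner PID, so
    # calls made from another process fall through to the HTTP-client path in
    # _call_store_method() further below. Sketch:
    #
    #     import pickle
    #     clone = pickle.loads(pickle.dumps(server))
    #     assert clone.app is None and clone.store is None  # stub, not a second server
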
    @staticmethod
    def _normalize_cors_origins(
        origins: Sequence[str] | str | None,
    ) -> list[str] | None:
        if origins is None:
            return None

        if isinstance(origins, str):
            candidates = [origins]
        else:
            candidates = list(origins)

        cleaned: list[str] = []
        for origin in candidates:
            if not origin or not origin.strip():
                continue
            value = origin.strip()
            if value == "*":
                return ["*"]
            cleaned.append(value)

        return cleaned or None

    def _apply_cors(self) -> None:
        if self.app is None or not self._cors_allow_origins:
            return

        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=self._cors_allow_origins.copy(),
            allow_methods=["*"],
            allow_headers=["*"],
            allow_credentials=True,
            expose_headers=["*"],
        )

    @property
    def endpoint(self) -> str:
        """Endpoint is the address that the client will use to connect to the server."""
        return self.server_launcher.access_endpoint

    async def start(self):
        """Starts the FastAPI server in the background.

        You need to call this method in the same process as the server was created in.
        """
        server_logger.info(
            f"Serving the lightning store at {self.server_launcher.endpoint}, accessible at {self.server_launcher.access_endpoint}"
        )

        start_time = time.time()
        await self.server_launcher.start()
        end_time = time.time()
        server_logger.info(f"Lightning store server started in {end_time - start_time:.2f} seconds")

    async def run_forever(self):
        """Runs the FastAPI server indefinitely."""
        server_logger.info(
            f"Running the lightning store server at {self.server_launcher.endpoint}, accessible at {self.server_launcher.access_endpoint}"
        )
        await self.server_launcher.run_forever()

    async def stop(self):
        """Gracefully stops the running FastAPI server.

        You need to call this method in the same process as the server was created in.
        """
        server_logger.info("Stopping the lightning store server...")
        await self.server_launcher.stop()
        server_logger.info("Lightning store server stopped.")

    def _setup_routes(self):
        """Set up FastAPI routes for all store operations."""
        assert self.app is not None
        api = APIRouter(prefix=API_V1_PREFIX)

        # The outermost layer of monitoring
        if self._tracker is not None:
            self._setup_metrics(api=api, app=self.app)

        # TODO: This should only be enabled in development mode.
        @self.app.middleware("http")
        async def _app_exception_handler(  # pyright: ignore[reportUnusedFunction]
            request: Request, call_next: Callable[[Request], Awaitable[Response]]
        ) -> Response:
            """
            Convert unhandled application exceptions into 500 responses.

            Only covers /v1/msk requests.

            - Client needs a reliable signal to distinguish "app bug / bad request"
              from transport/session failures.
            - 400 means "do not retry"; network issues will surface as aiohttp
              exceptions or 5xx and will be retried by the client shield.
            """
            try:
                return await call_next(request)
            except Exception as exc:
                # decide whether to convert this into a JSON error response
                if request.url.path.startswith(API_V1_AGL_PREFIX):
                    server_logger.exception("Unhandled application error", exc_info=exc)
                    payload = {
                        "detail": "Internal server error",
                        "error_type": type(exc).__name__,
                        "traceback": traceback.format_exc(),
                    }
                    # 500 so clients can decide to retry
                    return JSONResponse(status_code=500, content=payload)
                # otherwise re-raise and let FastAPI/Starlette handle it (500 or other handlers)
                raise

        @self.app.middleware("http")
        async def _log_time(  # pyright: ignore[reportUnusedFunction]
            request: Request, call_next: Callable[[Request], Awaitable[Response]]
        ):
            # If not an API request, just pass through
            if not request.url.path.startswith(API_V1_AGL_PREFIX) and not request.url.path.startswith(
                API_V1_PREFIX + "/traces"
            ):
                return await call_next(request)

            start = time.perf_counter()
            response = await call_next(request)
            duration = (time.perf_counter() - start) * 1000
            client = request.client
            if client is None:
                client_address = "unknown"
            else:
                client_address = f"{client.host}:{client.port}"
            server_logger.debug(
                f"{client_address} - "
                f'"{request.method} {request.url.path} HTTP/{request.scope["http_version"]}" '
                f"{response.status_code} in {duration:.2f} ms"
            )
            return response

        def _validate_paginated_request(
            request: Union[
                QueryRolloutsRequest,
                QueryAttemptsRequest,
                QueryResourcesRequest,
                QueryWorkersRequest,
                QuerySpansRequest,
            ],
            target_type: Type[T_model],
        ) -> None:
            """Raise an error early if the request is not a valid paginated request."""
            if request.sort_by is not None and request.sort_by not in target_type.model_fields:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid sort_by: {request.sort_by}, allowed fields are: {', '.join(target_type.model_fields.keys())}",
                )
            if request.sort_order not in ["asc", "desc"]:
                raise HTTPException(
                    status_code=400, detail=f"Invalid sort_order: {request.sort_order}, allowed values are: asc, desc"
                )
            if request.limit == 0 or (request.limit < 0 and request.limit != -1):
                raise HTTPException(status_code=400, detail="Limit must be greater than 0 or -1 for no limit")
            if not request.offset >= 0:
                raise HTTPException(status_code=400, detail="Offset must be greater than or equal to 0")
            if hasattr(request, "filter_logic") and request.filter_logic not in ["and", "or"]:  # type: ignore
                raise HTTPException(
                    status_code=400, detail=f"Invalid filter_logic: {request.filter_logic}, allowed values are: and, or"  # type: ignore
                )

        def _build_paginated_response(items: Sequence[Any], *, limit: int, offset: int) -> PaginatedResult[Any]:
            """FastAPI routes expect PaginatedResult payloads; wrap plain lists accordingly."""
            if isinstance(items, PaginatedResult):
                return items

            # Assuming it's a list.
            server_logger.warning(
                "PaginatedResult expected; got a plain list. Converting to PaginatedResult. "
                "Total items count will be inaccurate: %d",
                len(items),
            )
            return PaginatedResult(items=items, limit=limit, offset=offset, total=len(items))

        @api.get(API_AGL_PREFIX + "/health")
        async def health():  # pyright: ignore[reportUnusedFunction]
            return {"status": "ok"}

        @api.post(API_AGL_PREFIX + "/queues/rollouts/enqueue", status_code=201, response_model=List[Rollout])
        async def enqueue_rollouts(  # pyright: ignore[reportUnusedFunction]
            request: EnqueueManyRolloutsRequest,
        ) -> List[Rollout]:
            enqueue_requests = request.rollouts
            if not enqueue_requests:
                return []
            if len(enqueue_requests) == 1:
                single = enqueue_requests[0]
                rollout = await self.enqueue_rollout(
                    input=single.input,
                    mode=single.mode,
                    resources_id=single.resources_id,
                    config=single.config,
                    metadata=single.metadata,
                )
                return [rollout]
            rollouts = await self.enqueue_many_rollouts(enqueue_requests)
            return list(rollouts)

        @api.post(API_AGL_PREFIX + "/queues/rollouts/dequeue", response_model=List[AttemptedRollout])
        async def dequeue_rollouts(  # pyright: ignore[reportUnusedFunction]
            request: DequeueManyRolloutsRequest | None = Body(None),
        ) -> List[AttemptedRollout]:
            payload = request or DequeueManyRolloutsRequest()
            if payload.limit <= 0:
                return []
            if payload.limit == 1:
                single = await self.dequeue_rollout(worker_id=payload.worker_id)
                return [single] if single else []
            rollouts = await self.dequeue_many_rollouts(limit=payload.limit, worker_id=payload.worker_id)
            return list(rollouts)

        @api.post(API_AGL_PREFIX + "/rollouts", status_code=201, response_model=AttemptedRollout)
        async def start_rollout(request: RolloutRequest):  # pyright: ignore[reportUnusedFunction]
            return await self.start_rollout(
                input=request.input,
                mode=request.mode,
                resources_id=request.resources_id,
                config=request.config,
                metadata=request.metadata,
                worker_id=request.worker_id,
            )

        @api.get(API_AGL_PREFIX + "/rollouts", response_model=PaginatedResult[Union[AttemptedRollout, Rollout]])
        async def query_rollouts(params: QueryRolloutsRequest = Depends()):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(params, Rollout)
            # Get all rollouts from the underlying store
            results = await self.query_rollouts(
                status_in=params.status_in,
                rollout_id_in=params.rollout_id_in,
                rollout_id_contains=params.rollout_id_contains,
                filter_logic=params.filter_logic,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(results, limit=params.limit, offset=params.offset)

        @api.post(API_AGL_PREFIX + "/rollouts/search", response_model=PaginatedResult[Union[AttemptedRollout, Rollout]])
        async def search_rollouts(request: QueryRolloutsRequest):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(request, Rollout)
            status_in = request.status_in if "status_in" in request.model_fields_set else None
            rollout_id_in = request.rollout_id_in if "rollout_id_in" in request.model_fields_set else None
            # Get all rollouts from the underlying store
            results = await self.query_rollouts(
                status_in=status_in,
                rollout_id_in=rollout_id_in,
                rollout_id_contains=request.rollout_id_contains,
                filter_logic=request.filter_logic,
                sort_by=request.sort_by,
                sort_order=request.sort_order,
                limit=request.limit,
                offset=request.offset,
            )
            return _build_paginated_response(results, limit=request.limit, offset=request.offset)

        @api.get(API_AGL_PREFIX + "/rollouts/{rollout_id}", response_model=Union[AttemptedRollout, Rollout])
        async def get_rollout_by_id(rollout_id: str):  # pyright: ignore[reportUnusedFunction]
            return await self.get_rollout_by_id(rollout_id)

        def _get_mandatory_field_or_unset(request: BaseModel | None, field: str) -> Any:
            # If some fields are mandatory in the underlying store but optional in the FastAPI model,
            # this function makes sure each is set to a non-null value or UNSET.
            if request is None:
                return UNSET
            if field in request.model_fields_set:
                value = getattr(request, field)
                if value is None:
                    raise HTTPException(status_code=400, detail=f"{field} is invalid; it cannot be a null value.")
                return value
            else:
                return UNSET

        @api.post(API_AGL_PREFIX + "/rollouts/{rollout_id}", response_model=Rollout)
        async def update_rollout(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, request: UpdateRolloutRequest = Body(...)
        ):
            return await self.update_rollout(
                rollout_id=rollout_id,
                input=request.input if "input" in request.model_fields_set else UNSET,
                mode=request.mode if "mode" in request.model_fields_set else UNSET,
                resources_id=request.resources_id if "resources_id" in request.model_fields_set else UNSET,
                status=_get_mandatory_field_or_unset(request, "status"),
                config=_get_mandatory_field_or_unset(request, "config"),
                metadata=request.metadata if "metadata" in request.model_fields_set else UNSET,
            )

        @api.post(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts", status_code=201, response_model=AttemptedRollout)
        async def start_attempt(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, request: StartAttemptRequest | None = Body(None)
        ):
            worker_id = request.worker_id if request else None
            return await self.start_attempt(rollout_id, worker_id=worker_id)

        @api.post(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts/search", response_model=PaginatedResult[Attempt])
        async def search_attempts(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, request: QueryAttemptsRequest
        ):
            _validate_paginated_request(request, Attempt)
            attempts = await self.query_attempts(
                rollout_id,
                sort_by=request.sort_by,
                sort_order=request.sort_order,
                limit=request.limit,
                offset=request.offset,
            )
            return _build_paginated_response(attempts, limit=request.limit, offset=request.offset)

        @api.post(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts/{attempt_id}", response_model=Attempt)
        async def update_attempt(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, attempt_id: str, request: UpdateAttemptRequest = Body(...)
        ):
            return await self.update_attempt(
                rollout_id=rollout_id,
                attempt_id=attempt_id,
                status=_get_mandatory_field_or_unset(request, "status"),
                worker_id=_get_mandatory_field_or_unset(request, "worker_id"),
                last_heartbeat_time=_get_mandatory_field_or_unset(request, "last_heartbeat_time"),
                metadata=_get_mandatory_field_or_unset(request, "metadata"),
            )

        @api.get(API_AGL_PREFIX + "/workers", response_model=PaginatedResult[Worker])
        async def query_workers(params: QueryWorkersRequest = Depends()):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(params, Worker)
            workers = await self.query_workers(
                status_in=params.status_in,
                worker_id_contains=params.worker_id_contains,
                filter_logic=params.filter_logic,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(workers, limit=params.limit, offset=params.offset)

        @api.post(API_AGL_PREFIX + "/workers/search", response_model=PaginatedResult[Worker])
        async def search_workers(request: QueryWorkersRequest):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(request, Worker)
            status_in = request.status_in if "status_in" in request.model_fields_set else None
            workers = await self.query_workers(
                status_in=status_in,
                worker_id_contains=request.worker_id_contains,
                filter_logic=request.filter_logic,
                sort_by=request.sort_by,
                sort_order=request.sort_order,
                limit=request.limit,
                offset=request.offset,
            )
            return _build_paginated_response(workers, limit=request.limit, offset=request.offset)

        @api.get(API_AGL_PREFIX + "/workers/{worker_id}", response_model=Optional[Worker])
        async def get_worker(worker_id: str):  # pyright: ignore[reportUnusedFunction]
            return await self.get_worker_by_id(worker_id)

        @api.post(API_AGL_PREFIX + "/workers/{worker_id}", response_model=Worker)
        async def update_worker(  # pyright: ignore[reportUnusedFunction]
            worker_id: str, request: UpdateWorkerRequest | None = Body(None)
        ):
            return await self.update_worker(
                worker_id=worker_id,
                heartbeat_stats=_get_mandatory_field_or_unset(request, "heartbeat_stats"),
            )

        @api.get(API_AGL_PREFIX + "/statistics", response_model=Dict[str, Any])
        async def get_statistics():  # pyright: ignore[reportUnusedFunction]
            return await self.statistics()

        @api.get(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts", response_model=PaginatedResult[Attempt])
        async def query_attempts(  # pyright: ignore[reportUnusedFunction]
            rollout_id: str, params: QueryAttemptsRequest = Depends()
        ):
            _validate_paginated_request(params, Attempt)
            attempts = await self.query_attempts(
                rollout_id,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(attempts, limit=params.limit, offset=params.offset)

        @api.get(API_AGL_PREFIX + "/rollouts/{rollout_id}/attempts/latest", response_model=Optional[Attempt])
        async def get_latest_attempt(rollout_id: str):  # pyright: ignore[reportUnusedFunction]
            return await self.get_latest_attempt(rollout_id)

        @api.get(API_AGL_PREFIX + "/resources", response_model=PaginatedResult[ResourcesUpdate])
        async def query_resources(params: QueryResourcesRequest = Depends()):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(params, ResourcesUpdate)
            resources = await self.query_resources(
                resources_id=params.resources_id,
                resources_id_contains=params.resources_id_contains,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(resources, limit=params.limit, offset=params.offset)

        @api.post(API_AGL_PREFIX + "/resources", status_code=201, response_model=ResourcesUpdate)
        async def add_resources(resources: NamedResources):  # pyright: ignore[reportUnusedFunction]
            return await self.add_resources(resources)

        @api.get(API_AGL_PREFIX + "/resources/latest", response_model=Optional[ResourcesUpdate])
        async def get_latest_resources():  # pyright: ignore[reportUnusedFunction]
            return await self.get_latest_resources()

        @api.post(API_AGL_PREFIX + "/resources/{resources_id}", response_model=ResourcesUpdate)
        async def update_resources(  # pyright: ignore[reportUnusedFunction]
            resources_id: str, resources: NamedResources
        ):
            return await self.update_resources(resources_id, resources)

        @api.get(API_AGL_PREFIX + "/resources/{resources_id}", response_model=Optional[ResourcesUpdate])
        async def get_resources_by_id(resources_id: str):  # pyright: ignore[reportUnusedFunction]
            return await self.get_resources_by_id(resources_id)

        @api.post(API_AGL_PREFIX + "/spans", status_code=201, response_model=Optional[Span])
        async def add_span(span: Span):  # pyright: ignore[reportUnusedFunction]
            return await self.add_span(span)

        @api.get(API_AGL_PREFIX + "/spans", response_model=PaginatedResult[Span])
        async def query_spans(params: QuerySpansRequest = Depends()):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(params, Span)
            spans = await self.query_spans(
                params.rollout_id,
                params.attempt_id,
                trace_id=params.trace_id,
                trace_id_contains=params.trace_id_contains,
                span_id=params.span_id,
                span_id_contains=params.span_id_contains,
                parent_id=params.parent_id,
                parent_id_contains=params.parent_id_contains,
                name=params.name,
                name_contains=params.name_contains,
                filter_logic=params.filter_logic,
                sort_by=params.sort_by,
                sort_order=params.sort_order,
                limit=params.limit,
                offset=params.offset,
            )
            return _build_paginated_response(spans, limit=params.limit, offset=params.offset)

        @api.post(API_AGL_PREFIX + "/spans/search", response_model=PaginatedResult[Span])
        async def search_spans(request: QuerySpansRequest):  # pyright: ignore[reportUnusedFunction]
            _validate_paginated_request(request, Span)
            spans = await self.query_spans(
                request.rollout_id,
                request.attempt_id,
                trace_id=request.trace_id,
                trace_id_contains=request.trace_id_contains,
                span_id=request.span_id,
                span_id_contains=request.span_id_contains,
                parent_id=request.parent_id,
                parent_id_contains=request.parent_id_contains,
                name=request.name,
                name_contains=request.name_contains,
                filter_logic=request.filter_logic,
                sort_by=request.sort_by,
                sort_order=request.sort_order,
                limit=request.limit,
                offset=request.offset,
            )
            return _build_paginated_response(spans, limit=request.limit, offset=request.offset)

        @api.post(API_AGL_PREFIX + "/spans/next", response_model=NextSequenceIdResponse)
        async def get_next_span_sequence_id(request: NextSequenceIdRequest):  # pyright: ignore[reportUnusedFunction]
            sequence_id = await self.get_next_span_sequence_id(request.rollout_id, request.attempt_id)
            return NextSequenceIdResponse(sequence_id=sequence_id)

        @api.post(API_AGL_PREFIX + "/waits/rollouts", response_model=List[Rollout])
        async def wait_for_rollouts(request: WaitForRolloutsRequest):  # pyright: ignore[reportUnusedFunction]
            return await self.wait_for_rollouts(rollout_ids=request.rollout_ids, timeout=request.timeout)

        # Setup OTLP endpoints
        self._setup_otlp(api)

        # Mount the API router of /v1/...
        self.app.include_router(api)

        # Finally, mount the dashboard assets
        self._setup_dashboard()
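
    # --- Illustrative note (not part of the released file) ---
    # With the routes above mounted under /v1/msk, a batch enqueue looks like
    # this (assuming the default port 4747; the task-input payload is hypothetical):
    #
    #     POST http://127.0.0.1:4747/v1/msk/queues/rollouts/enqueue
    #     {"rollouts": [{"input": {"prompt": "hello"}}]}
    #
    # which returns 201 with the created Rollout objects. Workers then poll
    # POST /v1/msk/queues/rollouts/dequeue with {"limit": N, "worker_id": "..."}.
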
    def _setup_metrics(self, api: APIRouter, app: FastAPI):
        """Setup Prometheus metrics endpoints."""
        if self._tracker is None:
            return

        self._tracker.register_counter(
            "msk.http.total",
            ["path", "method", "status"],
            group_level=2,
        )
        self._tracker.register_histogram(
            "msk.http.latency",
            ["path", "method", "status"],
            buckets=LATENCY_BUCKETS,
            group_level=2,
        )

        def get_template_path(path: str) -> str:
            # Handle "latest" keywords BEFORE generic IDs
            if path.endswith("/attempts/latest") and "/rollouts/" in path:
                return re.sub(r"rollouts/[^/]+/attempts/latest$", "rollouts/{rollout_id}/attempts/latest", path)
            if path.endswith("/attempts/search") and "/rollouts/" in path:
                return re.sub(r"rollouts/[^/]+/attempts/search$", "rollouts/{rollout_id}/attempts/search", path)
            if path.endswith("/resources/latest"):
                return path
            if path.endswith("/search"):
                return path
            if "enqueue" in path or "dequeue" in path:
                return path

            # Handle generic IDs
            # (Order matters: longest paths first or lookaheads)
            path = re.sub(r"/attempts/[^/]+$", "/attempts/{attempt_id}", path)
            path = re.sub(r"/rollouts/[^/]+", "/rollouts/{rollout_id}", path)  # Handles root and middle
            path = re.sub(r"/resources/[^/]+$", "/resources/{resources_id}", path)
            path = re.sub(r"/workers/[^/]+$", "/workers/{worker_id}", path)

            return path

        @app.middleware("http")
        async def tracking_middleware(  # pyright: ignore[reportUnusedFunction]
            request: Request, call_next: Callable[[Request], Awaitable[Response]]
        ) -> Response:
            if self._tracker is None:
                return await call_next(request)

            start = time.perf_counter()
            status = 520  # Default to 520 if things crash hard

            try:
                response = await call_next(request)
                status = response.status_code
                return response
            except asyncio.CancelledError:
                # Client disconnected (Timeout)
                status = 499  # Standard Nginx code for "Client Closed Request"
                server_logger.debug(f"Client disconnected (Timeout): {request.url.path}", exc_info=True)
                raise  # Re-raise to let Uvicorn handle the cleanup
            except Exception as exc:
                status = resolve_error_type(exc)
                server_logger.debug(f"Server error: {request.url.path}", exc_info=True)
                raise
            finally:
                # This block executes NO MATTER WHAT happens above
                elapsed = time.perf_counter() - start

                # Strip the ID-specific URL parts
                path = get_template_path(request.url.path)
                method = request.method

                await self._tracker.inc_counter(
                    "msk.http.total",
                    labels={"method": method, "path": path, "status": str(status)},
                )
                await self._tracker.observe_histogram(
                    "msk.http.latency",
                    value=elapsed,
                    labels={"method": method, "path": path, "status": str(status)},
                )

        if self._tracker.has_prometheus():
            from prometheus_client import make_asgi_app  # pyright: ignore[reportUnknownVariableType]

            metrics_app = make_asgi_app(  # pyright: ignore[reportUnknownVariableType]
                registry=get_prometheus_registry()
            )

            # This app would need to be accessed via /v1/prometheus/ (note the trailing slash)
            app.mount(api.prefix + "/prometheus", metrics_app)  # pyright: ignore[reportUnknownArgumentType]

    def _setup_otlp(self, api: APIRouter):
        """Setup OTLP endpoints."""

        async def _trace_handler(request: PbExportTraceServiceRequest) -> None:
            spans = await spans_from_proto(request, self.get_many_span_sequence_ids)
            server_logger.debug(f"Received {len(spans)} OTLP spans: {', '.join([span.name for span in spans])}")
            await self.add_many_spans(spans)

        # Reserved methods for OTEL traces
        # https://opentelemetry.io/docs/specs/otlp/#otlphttp-request
        # This is currently the recommended path for OTel compatibility and bulk-insertion support.
        @api.post("/traces")
        async def otlp_traces(request: Request):  # pyright: ignore[reportUnusedFunction]
            return await handle_otlp_export(
                request, PbExportTraceServiceRequest, PbExportTraceServiceResponse, _trace_handler, "traces"
            )

        # Other API endpoints are not supported yet
        @api.post("/metrics")
        async def otlp_metrics():  # pyright: ignore[reportUnusedFunction]
            return Response(status_code=501)

        @api.post("/logs")
        async def otlp_logs():  # pyright: ignore[reportUnusedFunction]
            return Response(status_code=501)

        @api.post("/development/profiles")
        async def otlp_development_profiles():  # pyright: ignore[reportUnusedFunction]
            return Response(status_code=501)
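
    # --- Illustrative note (not part of the released file) ---
    # The /v1/traces route above speaks standard OTLP/HTTP, so a stock
    # OpenTelemetry exporter can point at it; a minimal sketch:
    #
    #     from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
    #     exporter = OTLPSpanExporter(endpoint=server.otlp_traces_endpoint())
    #
    # where otlp_traces_endpoint() resolves to f"{endpoint}/v1/traces" as defined earlier.
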
970
|
+
    def _setup_dashboard(self):
        """Set up the dashboard static files and the SPA."""
        assert self.app is not None

        dashboard_dir = (Path(__file__).parent.parent / "dashboard").resolve()
        if not dashboard_dir.exists():
            server_logger.error("Dashboard directory not found at %s. Please build the dashboard first.", dashboard_dir)
            return

        dashboard_assets_dir = dashboard_dir / "assets"
        if not dashboard_assets_dir.exists():
            server_logger.error(
                "Dashboard assets directory not found at %s. Please build the dashboard first.", dashboard_assets_dir
            )
            return

        index_file = dashboard_dir / "index.html"
        if not index_file.exists():
            server_logger.error("Dashboard index file not found at %s. Please build the dashboard first.", index_file)
            return

        # Mount the static files in dashboard/assets
        self.app.mount("/assets", CachedStaticFiles(directory=dashboard_assets_dir), name="assets")

        # SPA fallback (client-side routing)
        # Anything that's not /v1/* or a real file in /assets will serve index.html
        @self.app.get("/", include_in_schema=False)
        def root():  # pyright: ignore[reportUnusedFunction]
            return FileResponse(index_file)

        @self.app.get("/{full_path:path}", include_in_schema=False)
        def spa_fallback(full_path: str):  # pyright: ignore[reportUnusedFunction]
            if full_path.startswith("v1/"):
                raise HTTPException(status_code=404, detail="Not Found")
            # Let the frontend router handle it
            return FileResponse(index_file)

        server_logger.info("Mantisdk dashboard will be available at %s", self.endpoint)

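    # Routing behavior of the SPA fallback above, sketched with illustrative paths:
    #
    #     GET /                  -> index.html
    #     GET /rollouts/abc123   -> index.html (handled by the frontend router)
    #     GET /assets/app.js     -> static file from dashboard/assets
    #     GET /v1/nonexistent    -> 404 (the API namespace is never swallowed)
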
    # Delegate methods
    async def _call_store_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
        """First decide which store to delegate to in *this* process, and then call the method on it.

        - In the owner process: delegate to the in-process store.
        - In a different process: delegate to an HTTP client talking to the server.
        """
        # If the store is zero-copy, we can just call the method directly.
        if self.store is not None and self.store.capabilities.get("zero_copy", False):
            return await getattr(self.store, method_name)(*args, **kwargs)

        if os.getpid() == self._owner_pid:
            if method_name == "wait_for_rollouts":
                # wait_for_rollouts can block for a long time; avoid holding the lock
                # so other requests can make progress while we wait.
                return await getattr(self.store, method_name)(*args, **kwargs)

            # If it's already thread-safe, we can just call the method directly.
            # Acquiring the threading lock directly would block the event loop if it's
            # already held by another thread (for example, the HTTP server thread).
            # A potential fix would be needed to make that work. For example:
            # ```
            # acquired = self._lock.acquire(blocking=False)
            # if not acquired:
            #     await asyncio.to_thread(self._lock.acquire)
            # try:
            #     return await getattr(self.store, method_name)(*args, **kwargs)
            # finally:
            #     self._lock.release()
            # ```
            # Or we can just bypass the lock for thread-safe stores.
            if self.store is not None and self.store.capabilities.get("thread_safe", False):
                return await getattr(self.store, method_name)(*args, **kwargs)
            else:
                with self._lock:
                    return await getattr(self.store, method_name)(*args, **kwargs)

        if self._client is None:
            self._client = LightningStoreClient(self.endpoint)
        return await getattr(self._client, method_name)(*args, **kwargs)

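    # The delegation order above, at a glance:
    #
    #     zero-copy store                  -> call directly from any process
    #     owner pid + wait_for_rollouts    -> call without holding the lock
    #     owner pid + thread-safe store    -> call without holding the lock
    #     owner pid, otherwise             -> call under self._lock
    #     any other process                -> LightningStoreClient over HTTP
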
    async def statistics(self) -> LightningStoreStatistics:
        return await self._call_store_method("statistics")

    async def start_rollout(
        self,
        input: TaskInput,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        config: RolloutConfig | None = None,
        metadata: Dict[str, Any] | None = None,
        worker_id: Optional[str] = None,
    ) -> AttemptedRollout:
        return await self._call_store_method(
            "start_rollout",
            input,
            mode,
            resources_id,
            config,
            metadata,
            worker_id,
        )

    async def enqueue_rollout(
        self,
        input: TaskInput,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        config: RolloutConfig | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> Rollout:
        return await self._call_store_method(
            "enqueue_rollout",
            input,
            mode,
            resources_id,
            config,
            metadata,
        )

    async def enqueue_many_rollouts(self, rollouts: Sequence[EnqueueRolloutRequest]) -> Sequence[Rollout]:
        return await self._call_store_method("enqueue_many_rollouts", rollouts)

    async def dequeue_rollout(self, worker_id: Optional[str] = None) -> Optional[AttemptedRollout]:
        return await self._call_store_method("dequeue_rollout", worker_id)

    async def dequeue_many_rollouts(
        self,
        *,
        limit: int = 1,
        worker_id: Optional[str] = None,
    ) -> Sequence[AttemptedRollout]:
        return await self._call_store_method("dequeue_many_rollouts", limit=limit, worker_id=worker_id)

    async def start_attempt(self, rollout_id: str, worker_id: Optional[str] = None) -> AttemptedRollout:
        return await self._call_store_method("start_attempt", rollout_id, worker_id)

    async def query_rollouts(
        self,
        *,
        status_in: Optional[Sequence[RolloutStatus]] = None,
        rollout_id_in: Optional[Sequence[str]] = None,
        rollout_id_contains: Optional[str] = None,
        filter_logic: Literal["and", "or"] = "and",
        sort_by: Optional[str] = None,
        sort_order: Literal["asc", "desc"] = "asc",
        limit: int = -1,
        offset: int = 0,
        status: Optional[Sequence[RolloutStatus]] = None,
        rollout_ids: Optional[Sequence[str]] = None,
    ) -> PaginatedResult[Union[AttemptedRollout, Rollout]]:
        return await self._call_store_method(
            "query_rollouts",
            status_in=status_in,
            rollout_id_in=rollout_id_in,
            rollout_id_contains=rollout_id_contains,
            filter_logic=filter_logic,
            sort_by=sort_by,
            sort_order=sort_order,
            limit=limit,
            offset=offset,
            status=status,
            rollout_ids=rollout_ids,
        )

    async def query_attempts(
        self,
        rollout_id: str,
        *,
        sort_by: Optional[str] = "sequence_id",
        sort_order: Literal["asc", "desc"] = "asc",
        limit: int = -1,
        offset: int = 0,
    ) -> PaginatedResult[Attempt]:
        return await self._call_store_method(
            "query_attempts",
            rollout_id,
            sort_by=sort_by,
            sort_order=sort_order,
            limit=limit,
            offset=offset,
        )

    async def get_latest_attempt(self, rollout_id: str) -> Optional[Attempt]:
        return await self._call_store_method("get_latest_attempt", rollout_id)

    async def query_resources(
        self,
        *,
        resources_id: Optional[str] = None,
        resources_id_contains: Optional[str] = None,
        sort_by: Optional[str] = None,
        sort_order: Literal["asc", "desc"] = "asc",
        limit: int = -1,
        offset: int = 0,
    ) -> PaginatedResult[ResourcesUpdate]:
        return await self._call_store_method(
            "query_resources",
            resources_id=resources_id,
            resources_id_contains=resources_id_contains,
            sort_by=sort_by,
            sort_order=sort_order,
            limit=limit,
            offset=offset,
        )

    async def get_rollout_by_id(self, rollout_id: str) -> Optional[Rollout]:
        return await self._call_store_method("get_rollout_by_id", rollout_id)

    async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
        return await self._call_store_method("add_resources", resources)

    async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate:
        return await self._call_store_method("update_resources", resources_id, resources)

    async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]:
        return await self._call_store_method("get_resources_by_id", resources_id)

    async def get_latest_resources(self) -> Optional[ResourcesUpdate]:
        return await self._call_store_method("get_latest_resources")

    async def add_span(self, span: Span) -> Optional[Span]:
        return await self._call_store_method("add_span", span)

    async def add_many_spans(self, spans: Sequence[Span]) -> Sequence[Span]:
        return await self._call_store_method("add_many_spans", spans)

    async def get_next_span_sequence_id(self, rollout_id: str, attempt_id: str) -> int:
        return await self._call_store_method("get_next_span_sequence_id", rollout_id, attempt_id)

    async def get_many_span_sequence_ids(self, rollout_attempt_ids: Sequence[Tuple[str, str]]) -> Sequence[int]:
        return await self._call_store_method("get_many_span_sequence_ids", rollout_attempt_ids)

    async def add_otel_span(
        self,
        rollout_id: str,
        attempt_id: str,
        readable_span: ReadableSpan,
        sequence_id: int | None = None,
    ) -> Optional[Span]:
        return await self._call_store_method(
            "add_otel_span",
            rollout_id,
            attempt_id,
            readable_span,
            sequence_id,
        )

    async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
        return await self._call_store_method("wait_for_rollouts", rollout_ids=rollout_ids, timeout=timeout)

    async def query_spans(
        self,
        rollout_id: str,
        attempt_id: str | Literal["latest"] | None = None,
        *,
        trace_id: Optional[str] = None,
        trace_id_contains: Optional[str] = None,
        span_id: Optional[str] = None,
        span_id_contains: Optional[str] = None,
        parent_id: Optional[str] = None,
        parent_id_contains: Optional[str] = None,
        name: Optional[str] = None,
        name_contains: Optional[str] = None,
        filter_logic: Literal["and", "or"] = "and",
        limit: int = -1,
        offset: int = 0,
        sort_by: Optional[str] = "sequence_id",
        sort_order: Literal["asc", "desc"] = "asc",
    ) -> PaginatedResult[Span]:
        return await self._call_store_method(
            "query_spans",
            rollout_id,
            attempt_id,
            trace_id=trace_id,
            trace_id_contains=trace_id_contains,
            span_id=span_id,
            span_id_contains=span_id_contains,
            parent_id=parent_id,
            parent_id_contains=parent_id_contains,
            name=name,
            name_contains=name_contains,
            filter_logic=filter_logic,
            limit=limit,
            offset=offset,
            sort_by=sort_by,
            sort_order=sort_order,
        )

    async def update_rollout(
        self,
        rollout_id: str,
        input: TaskInput | Unset = UNSET,
        mode: Optional[Literal["train", "val", "test"]] | Unset = UNSET,
        resources_id: Optional[str] | Unset = UNSET,
        status: RolloutStatus | Unset = UNSET,
        config: RolloutConfig | Unset = UNSET,
        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
    ) -> Rollout:
        return await self._call_store_method(
            "update_rollout",
            rollout_id,
            input,
            mode,
            resources_id,
            status,
            config,
            metadata,
        )

    async def update_attempt(
        self,
        rollout_id: str,
        attempt_id: str | Literal["latest"],
        status: AttemptStatus | Unset = UNSET,
        worker_id: str | Unset = UNSET,
        last_heartbeat_time: float | Unset = UNSET,
        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
    ) -> Attempt:
        return await self._call_store_method(
            "update_attempt",
            rollout_id,
            attempt_id,
            status,
            worker_id,
            last_heartbeat_time,
            metadata,
        )

    async def query_workers(
        self,
        *,
        status_in: Optional[Sequence[WorkerStatus]] = None,
        worker_id_contains: Optional[str] = None,
        filter_logic: Literal["and", "or"] = "and",
        sort_by: Optional[str] = None,
        sort_order: Literal["asc", "desc"] = "asc",
        limit: int = -1,
        offset: int = 0,
    ) -> PaginatedResult[Worker]:
        return await self._call_store_method(
            "query_workers",
            status_in=status_in,
            worker_id_contains=worker_id_contains,
            filter_logic=filter_logic,
            sort_by=sort_by,
            sort_order=sort_order,
            limit=limit,
            offset=offset,
        )

    async def get_worker_by_id(self, worker_id: str) -> Optional[Worker]:
        return await self._call_store_method("get_worker_by_id", worker_id)

    async def update_worker(
        self,
        worker_id: str,
        heartbeat_stats: Dict[str, Any] | Unset = UNSET,
    ) -> Worker:
        return await self._call_store_method(
            "update_worker",
            worker_id,
            heartbeat_stats,
        )


class LightningStoreClient(LightningStore):
    """HTTP client that talks to a remote LightningStoreServer.

    Args:
        server_address: The address of the LightningStoreServer to connect to.
        retry_delays:
            Backoff schedule (seconds) used when the initial request fails for a
            non-application reason. Each entry is a retry attempt.
            Set to an empty sequence to disable retries.
        health_retry_delays:
            Delays between /health probes while waiting for the server to come back.
            Set to an empty sequence to disable health checks.
        request_timeout: Timeout (seconds) for each request.
        connection_timeout: Timeout (seconds) for establishing a connection.
    """

    def __init__(
        self,
        server_address: str,
        *,
        retry_delays: Sequence[float] = (1.0, 2.0, 5.0),
        health_retry_delays: Sequence[float] = (0.1, 0.2, 0.5),
        request_timeout: float = 30.0,
        connection_timeout: float = 5.0,
    ):
        self.server_address_root = server_address.rstrip("/")
        self.server_address = self.server_address_root + API_V1_AGL_PREFIX
        self._sessions: Dict[int, aiohttp.ClientSession] = {}  # id(loop) -> ClientSession
        self._lock = threading.Lock()

        # retry config
        self._retry_delays = tuple(float(d) for d in retry_delays)
        self._health_retry_delays = tuple(float(d) for d in health_retry_delays)

        # Timeouts
        self._request_timeout = request_timeout
        self._connection_timeout = connection_timeout

        # Store whether the dequeue has ever been successful
        self._dequeue_was_successful: bool = False
        self._dequeue_first_unsuccessful: bool = True

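    # A minimal construction sketch (the address and overrides are illustrative):
    #
    #     client = LightningStoreClient(
    #         "http://127.0.0.1:4747",
    #         retry_delays=(),       # fail fast instead of backing off
    #         request_timeout=10.0,
    #     )
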
    @property
    def capabilities(self) -> LightningStoreCapabilities:
        """Return the capabilities of the store."""
        return LightningStoreCapabilities(
            thread_safe=True,
            async_safe=True,
            zero_copy=True,
            otlp_traces=True,
        )

    def otlp_traces_endpoint(self) -> str:
        """Return the OTLP/HTTP traces endpoint of the store."""
        return f"{self.server_address_root}/v1/traces"

    async def statistics(self) -> LightningStoreStatistics:
        payload = await self._request_json("get", "/statistics")
        return cast(LightningStoreStatistics, payload)

    def __getstate__(self):
        """
        When LightningStoreClient is pickled (e.g., passed to a subprocess), we only
        serialize the server address and retry configurations. The ClientSessions
        are excluded as they should not be transferred between processes.
        """
        return {
            "server_address_root": self.server_address_root,
            "server_address": self.server_address,
            "_retry_delays": self._retry_delays,
            "_health_retry_delays": self._health_retry_delays,
            "_request_timeout": self._request_timeout,
            "_connection_timeout": self._connection_timeout,
        }

    def __setstate__(self, state: Dict[str, Any]):
        """
        Restore from pickle by reconstructing only the essential attributes.

        Replicates the `__init__` logic to create another client instance in the subprocess.
        """
        self.server_address = state["server_address"]
        self.server_address_root = state["server_address_root"]
        self._sessions = {}
        self._lock = threading.Lock()
        self._retry_delays = state["_retry_delays"]
        self._health_retry_delays = state["_health_retry_delays"]
        self._request_timeout = state["_request_timeout"]
        self._connection_timeout = state["_connection_timeout"]
        self._dequeue_was_successful = False
        self._dequeue_first_unsuccessful = True

    async def _get_session(self) -> aiohttp.ClientSession:
        # In the proxy process, FastAPI middleware calls
        # client_store.get_next_span_sequence_id(...). With
        # reuse_session=True, _get_session() creates and caches a
        # single ClientSession bound to the uvicorn event loop.
        #
        # Later, the OpenTelemetry exporter (LightningSpanExporter)
        # runs its flush on its own private event loop (in a different
        # thread) and calls client_store.add_otel_span(...) ->
        # client_store.add_span(...).
        #
        # If we reused one session across all loops, the exporter would try
        # to reuse the same cached ClientSession that was created on the
        # uvicorn loop. aiohttp.ClientSession is not loop-agnostic or
        # thread-safe; using it from another loop can hang on the
        # first request. That's why we need a map from loop to session.

        loop = asyncio.get_running_loop()
        key = id(loop)
        with self._lock:
            sess = self._sessions.get(key)
            if sess is None or sess.closed:
                timeout = aiohttp.ClientTimeout(
                    total=self._request_timeout,
                    connect=self._connection_timeout,
                    sock_connect=self._connection_timeout,
                    sock_read=self._request_timeout,
                )
                sess = aiohttp.ClientSession(timeout=timeout)
                self._sessions[key] = sess
            return sess

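    # The failure mode guarded against above, sketched (loop names illustrative):
    #
    #     uvicorn loop         -> _get_session() caches session A under id(loop_a)
    #     exporter's own loop  -> id(loop_b) misses the cache, so a fresh session B
    #                             is created instead of reusing A across loops,
    #                             which could hang on the first request
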
    async def _wait_until_healthy(self, session: aiohttp.ClientSession) -> bool:
        """
        Probe the server's /health until it responds 200 or retries are exhausted.
        Returns True if healthy, False otherwise.
        """
        if not self._health_retry_delays:
            client_logger.info("No health retry delays configured; skipping health checks.")
            return True

        client_logger.info(f"Waiting for server to be healthy at {self.server_address}/health")
        for delay in [*self._health_retry_delays, 0.0]:
            try:
                async with session.get(f"{self.server_address}/health") as r:
                    if r.status == 200:
                        client_logger.info(f"Server is healthy at {self.server_address}/health")
                        return True
            except Exception:
                # swallow and retry
                if delay > 0.0:
                    client_logger.warning(f"Server is not healthy yet. Retrying in {delay} seconds.")
            if delay > 0.0:
                await asyncio.sleep(delay)
        client_logger.error(
            f"Server is not healthy at {self.server_address}/health after {len(self._health_retry_delays)} retry attempts"
        )
        return False

    async def _request_json(
        self,
        method: Literal["get", "post"],
        path: str,
        *,
        json: Any | None = None,
        params: Mapping[str, Any] | Sequence[Tuple[str, Any]] | None = None,
    ) -> Any:
        """
        Make an HTTP request with:

        1) A first attempt.
        2) On network/session failures: probe /health until the server is back, then retry
           according to self._retry_delays.
        3) On 4xx (e.g., 400 set by the server exception handler): do not retry.

        Returns parsed JSON (or a raw JSON scalar like int).
        Raises the last exception if all retries fail.
        """
        session = await self._get_session()
        url = f"{self.server_address}{path if path.startswith('/') else '/' + path}"

        # attempt 0 is immediate, then follow the retry schedule
        attempts = (0.0,) + self._retry_delays
        last_exc: Exception | None = None

        for delay in attempts:
            if delay:
                client_logger.info(f"Waiting {delay} seconds before retrying {method}: {path}")
                await asyncio.sleep(delay)
            try:
                http_call = getattr(session, method)
                async with http_call(url, json=json, params=params) as resp:
                    resp.raise_for_status()
                    return await resp.json()
            except aiohttp.ClientResponseError as cre:
                # Respect app-level 4xx as final.
                # 4xx => application issue; do not retry (except 408, which is transient).
                client_logger.debug(f"ClientResponseError ({method} {path}): {cre.status} {cre.message}", exc_info=True)
                if 400 <= cre.status < 500 and cre.status != 408:
                    raise
                # 5xx and others will be retried below if they raise
                last_exc = cre
                client_logger.info(f"5xx and other status codes will be retried. Retrying the request {method}: {path}")
                # before the next retry, ensure the server is healthy
                if not await self._wait_until_healthy(session):
                    break  # server is not healthy, do not retry
            except (
                aiohttp.ServerDisconnectedError,
                aiohttp.ClientConnectorError,
                aiohttp.ClientOSError,
                asyncio.TimeoutError,
            ) as net_exc:
                # Network/session issue: probe health before retrying
                client_logger.debug(f"Network/session issue ({method} {path}): {net_exc}", exc_info=True)
                last_exc = net_exc
                client_logger.info(f"Network/session issue: {net_exc} - will retry the request {method}: {path}")
                if not await self._wait_until_healthy(session):
                    break  # server is not healthy, do not retry

        # exhausted retries
        assert last_exc is not None
        raise last_exc

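    # Retry timeline sketch for the default retry_delays=(1.0, 2.0, 5.0),
    # ignoring time spent in /health probes:
    #
    #     t=0s   attempt 1 (immediate)
    #     t=1s   attempt 2
    #     t=3s   attempt 3
    #     t=8s   attempt 4 (last); afterwards the final exception is re-raised
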
    async def close(self):
        """Close the HTTP sessions."""
        with self._lock:
            sessions = list(self._sessions.values())
            self._sessions.clear()

        # close them on their own loops to avoid warnings
        async def _close(sess: aiohttp.ClientSession):
            if not sess.closed:
                await sess.close()

        # If called from one loop, best-effort close here.
        for s in sessions:
            try:
                await _close(s)
            except RuntimeError:
                # If the session was created on a different loop/thread, schedule a
                # thread-safe close instead. Fallback: close without awaiting (the
                # library tolerates it in practice), or keep a per-loop shutdown
                # hook where the sessions were created.
                pass

    async def start_rollout(
        self,
        input: TaskInput,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        config: RolloutConfig | None = None,
        metadata: Dict[str, Any] | None = None,
        worker_id: Optional[str] = None,
    ) -> AttemptedRollout:
        data = await self._request_json(
            "post",
            "/rollouts",
            json=RolloutRequest(
                input=input,
                mode=mode,
                resources_id=resources_id,
                config=config,
                metadata=metadata,
                worker_id=worker_id,
            ).model_dump(exclude_none=False),
        )
        return AttemptedRollout.model_validate(data)

    async def enqueue_rollout(
        self,
        input: TaskInput,
        mode: Literal["train", "val", "test"] | None = None,
        resources_id: str | None = None,
        config: RolloutConfig | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> Rollout:
        request_body = EnqueueManyRolloutsRequest(
            rollouts=[
                EnqueueRolloutRequest(
                    input=input,
                    mode=mode,
                    resources_id=resources_id,
                    config=config,
                    metadata=metadata,
                )
            ]
        ).model_dump(exclude_none=False)
        data = await self._request_json(
            "post",
            "/queues/rollouts/enqueue",
            json=request_body,
        )
        if not data:
            raise RuntimeError("enqueue_rollout returned no rollouts")
        return Rollout.model_validate(data[0])

    async def enqueue_many_rollouts(self, rollouts: Sequence[EnqueueRolloutRequest]) -> Sequence[Rollout]:
        if not rollouts:
            return []
        request_body = EnqueueManyRolloutsRequest(rollouts=list(rollouts)).model_dump(exclude_none=False)
        data = await self._request_json(
            "post",
            "/queues/rollouts/enqueue",
            json=request_body,
        )
        return [Rollout.model_validate(entry) for entry in data]

    async def _dequeue_batch(
        self,
        *,
        limit: int,
        worker_id: Optional[str],
    ) -> List[AttemptedRollout]:
        if limit <= 0:
            return []
        session = await self._get_session()
        url = f"{self.server_address}/queues/rollouts/dequeue"
        payload: Dict[str, Any] = {"limit": limit}
        if worker_id is not None:
            payload["worker_id"] = worker_id
        try:
            async with session.post(url, json=payload) as resp:
                resp.raise_for_status()
                data = await resp.json()
                self._dequeue_was_successful = True
                return [AttemptedRollout.model_validate(item) for item in data]
        except Exception as e:
            if self._dequeue_was_successful:
                if self._dequeue_first_unsuccessful:
                    client_logger.warning(f"dequeue_rollout failed with exception: {e}")
                    self._dequeue_first_unsuccessful = False
                client_logger.debug("dequeue_rollout failed with exception. Details:", exc_info=True)
            # Otherwise, ignore the exception because the server is not ready yet
            return []

    async def dequeue_rollout(self, worker_id: Optional[str] = None) -> Optional[AttemptedRollout]:
        """
        Dequeue a rollout from the server queue.

        Returns:
            AttemptedRollout if a rollout is available, None if the queue is empty.

        Note:
            This method does NOT retry on failures. If any exception occurs (network error,
            server error, etc.), it logs the error and returns None immediately.
        """
        attempts = await self._dequeue_batch(limit=1, worker_id=worker_id)
        return attempts[0] if attempts else None

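    # A minimal worker-loop sketch built on dequeue_rollout; the worker id, the
    # sleep interval, and the `handle` callback are all illustrative:
    #
    #     async def work(client: "LightningStoreClient") -> None:
    #         while True:
    #             rollout = await client.dequeue_rollout(worker_id="worker-0")
    #             if rollout is None:
    #                 await asyncio.sleep(1.0)  # empty queue or transient failure
    #                 continue
    #             await handle(rollout)
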
    async def dequeue_many_rollouts(
        self,
        *,
        limit: int = 1,
        worker_id: Optional[str] = None,
    ) -> Sequence[AttemptedRollout]:
        return await self._dequeue_batch(limit=limit, worker_id=worker_id)

    async def start_attempt(self, rollout_id: str, worker_id: Optional[str] = None) -> AttemptedRollout:
        payload = {"worker_id": worker_id} if worker_id is not None else None
        data = await self._request_json(
            "post",
            f"/rollouts/{rollout_id}/attempts",
            json=payload,
        )
        return AttemptedRollout.model_validate(data)

    async def query_rollouts(
        self,
        *,
        status_in: Optional[Sequence[RolloutStatus]] = None,
        rollout_id_in: Optional[Sequence[str]] = None,
        rollout_id_contains: Optional[str] = None,
        filter_logic: Literal["and", "or"] = "and",
        sort_by: Optional[str] = None,
        sort_order: Literal["asc", "desc"] = "asc",
        limit: int = -1,
        offset: int = 0,
        status: Optional[Sequence[RolloutStatus]] = None,
        rollout_ids: Optional[Sequence[str]] = None,
    ) -> PaginatedResult[Union[AttemptedRollout, Rollout]]:
        resolved_status = status_in if status_in is not None else status
        resolved_rollout_ids = rollout_id_in if rollout_id_in is not None else rollout_ids

        payload: Dict[str, Any] = {
            "limit": limit,
            "offset": offset,
        }
        if resolved_status is not None:
            payload["status_in"] = resolved_status
        if resolved_rollout_ids is not None:
            payload["rollout_id_in"] = resolved_rollout_ids
        if rollout_id_contains is not None:
            payload["rollout_id_contains"] = rollout_id_contains
        payload["filter_logic"] = filter_logic
        if sort_by is not None:
            payload["sort_by"] = sort_by
            payload["sort_order"] = sort_order

        data = await self._request_json("post", "/rollouts/search", json=payload)
        items = [
            (
                AttemptedRollout.model_validate(item)
                if isinstance(item, dict) and "attempt" in item
                else Rollout.model_validate(item)
            )
            for item in data["items"]
        ]
        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])

    async def query_attempts(
        self,
        rollout_id: str,
        *,
        sort_by: Optional[str] = "sequence_id",
        sort_order: Literal["asc", "desc"] = "asc",
        limit: int = -1,
        offset: int = 0,
    ) -> PaginatedResult[Attempt]:
        payload: Dict[str, Any] = {
            "limit": limit,
            "offset": offset,
        }
        if sort_by is not None:
            payload["sort_by"] = sort_by
            payload["sort_order"] = sort_order
        data = await self._request_json("post", f"/rollouts/{rollout_id}/attempts/search", json=payload)
        items = [Attempt.model_validate(item) for item in data["items"]]
        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])

    async def get_latest_attempt(self, rollout_id: str) -> Optional[Attempt]:
        """
        Get the latest attempt for a rollout.

        Args:
            rollout_id: ID of the rollout to query.

        Returns:
            Attempt if found, None if not found or if all retries are exhausted.

        Note:
            This method retries on transient failures (network errors, 5xx status codes).
            If all retries fail, it logs the error and returns None instead of raising an exception.
        """
        try:
            data = await self._request_json("get", f"/rollouts/{rollout_id}/attempts/latest")
            return Attempt.model_validate(data) if data else None
        except Exception as e:
            client_logger.error(
                f"get_latest_attempt failed after all retries for rollout_id={rollout_id}: {e}", exc_info=True
            )
            return None

    async def get_rollout_by_id(self, rollout_id: str) -> Optional[Rollout]:
        """
        Get a rollout by its ID.

        Args:
            rollout_id: ID of the rollout to retrieve.

        Returns:
            Rollout if found, None if not found or if all retries are exhausted.

        Note:
            This method retries on transient failures (network errors, 5xx status codes).
            If all retries fail, it logs the error and returns None instead of raising an exception.
        """
        try:
            data = await self._request_json("get", f"/rollouts/{rollout_id}")
            if data is None:
                return None
            elif isinstance(data, dict) and "attempt" in data:
                return AttemptedRollout.model_validate(data)
            else:
                return Rollout.model_validate(data)
        except Exception as e:
            client_logger.error(
                f"get_rollout_by_id failed after all retries for rollout_id={rollout_id}: {e}", exc_info=True
            )
            return None

    async def query_resources(
        self,
        *,
        resources_id: Optional[str] = None,
        resources_id_contains: Optional[str] = None,
        sort_by: Optional[str] = None,
        sort_order: Literal["asc", "desc"] = "asc",
        limit: int = -1,
        offset: int = 0,
    ) -> PaginatedResult[ResourcesUpdate]:
        """
        List all resource snapshots stored on the server.
        """
        params: List[Tuple[str, Any]] = [
            ("limit", limit),
            ("offset", offset),
        ]
        if sort_by is not None:
            params.append(("sort_by", sort_by))
            params.append(("sort_order", sort_order))
        if resources_id is not None:
            params.append(("resources_id", resources_id))
        if resources_id_contains is not None:
            params.append(("resources_id_contains", resources_id_contains))

        data = await self._request_json("get", "/resources", params=params)
        items = [ResourcesUpdate.model_validate(item) for item in data["items"]]
        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])

    async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
        data = await self._request_json("post", "/resources", json=TypeAdapter(NamedResources).dump_python(resources))
        return ResourcesUpdate.model_validate(data)

    async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate:
        data = await self._request_json(
            "post", f"/resources/{resources_id}", json=TypeAdapter(NamedResources).dump_python(resources)
        )
        return ResourcesUpdate.model_validate(data)

    async def get_resources_by_id(self, resources_id: str) -> Optional[ResourcesUpdate]:
        """
        Get resources by their ID.

        Args:
            resources_id: ID of the resources to retrieve.

        Returns:
            ResourcesUpdate if found, None if not found or if all retries are exhausted.

        Note:
            This method retries on transient failures (network errors, 5xx status codes).
            If all retries fail, it logs the error and returns None instead of raising an exception.
        """
        try:
            data = await self._request_json("get", f"/resources/{resources_id}")
            return ResourcesUpdate.model_validate(data) if data else None
        except Exception as e:
            client_logger.error(
                f"get_resources_by_id failed after all retries for resources_id={resources_id}: {e}", exc_info=True
            )
            return None

    async def get_latest_resources(self) -> Optional[ResourcesUpdate]:
        """
        Get the latest resources.

        Returns:
            ResourcesUpdate if found, None if not found or if all retries are exhausted.

        Note:
            This method retries on transient failures (network errors, 5xx status codes).
            If all retries fail, it logs the error and returns None instead of raising an exception.
        """
        try:
            data = await self._request_json("get", "/resources/latest")
            return ResourcesUpdate.model_validate(data) if data else None
        except Exception as e:
            client_logger.error(f"get_latest_resources failed after all retries: {e}", exc_info=True)
            return None

    async def add_span(self, span: Span) -> Optional[Span]:
        data = await self._request_json("post", "/spans", json=span.model_dump(mode="json"))
        return Span.model_validate(data) if data is not None else None

    async def add_many_spans(self, spans: Sequence[Span]) -> Sequence[Span]:
        result: List[Span] = []
        for span in spans:
            ret = await self.add_span(span)
            if ret is not None:
                result.append(ret)
        return result

    async def get_next_span_sequence_id(self, rollout_id: str, attempt_id: str) -> int:
        data = await self._request_json(
            "post",
            "/spans/next",
            json=NextSequenceIdRequest(rollout_id=rollout_id, attempt_id=attempt_id).model_dump(),
        )
        response = NextSequenceIdResponse.model_validate(data)
        return response.sequence_id

    async def get_many_span_sequence_ids(self, rollout_attempt_ids: Sequence[Tuple[str, str]]) -> Sequence[int]:
        return [
            await self.get_next_span_sequence_id(rollout_id, attempt_id)
            for rollout_id, attempt_id in rollout_attempt_ids
        ]

    async def add_otel_span(
        self,
        rollout_id: str,
        attempt_id: str,
        readable_span: ReadableSpan,
        sequence_id: int | None = None,
    ) -> Optional[Span]:
        # unchanged logic; it now benefits from the retries inside add_span/get_next_span_sequence_id
        if sequence_id is None:
            sequence_id = await self.get_next_span_sequence_id(rollout_id, attempt_id)
        span = Span.from_opentelemetry(
            readable_span,
            rollout_id=rollout_id,
            attempt_id=attempt_id,
            sequence_id=sequence_id,
        )
        return await self.add_span(span)

    async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
        """Wait for rollouts to complete.

        Args:
            rollout_ids: List of rollout IDs to wait for.
            timeout: Timeout in seconds. If not None, it must be at most 0.1 seconds;
                larger values raise a ValueError.

        Returns:
            List of the rollouts that have completed.
        """
        if timeout is not None and timeout > 0.1:
            raise ValueError(
                "Timeout must be less than 0.1 seconds in LightningStoreClient to avoid blocking the event loop"
            )
        data = await self._request_json(
            "post",
            "/waits/rollouts",
            json=WaitForRolloutsRequest(rollout_ids=rollout_ids, timeout=timeout).model_dump(),
        )
        return [Rollout.model_validate(item) for item in data]

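    # Emulating a long wait under the 0.1-second cap above, sketched with an
    # illustrative 1-second polling interval:
    #
    #     async def wait_until_done(client: "LightningStoreClient", ids: List[str]) -> List[Rollout]:
    #         while True:
    #             done = await client.wait_for_rollouts(rollout_ids=ids, timeout=0.1)
    #             if len(done) == len(ids):
    #                 return done
    #             await asyncio.sleep(1.0)
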
    async def query_spans(
        self,
        rollout_id: str,
        attempt_id: str | Literal["latest"] | None = None,
        *,
        trace_id: Optional[str] = None,
        trace_id_contains: Optional[str] = None,
        span_id: Optional[str] = None,
        span_id_contains: Optional[str] = None,
        parent_id: Optional[str] = None,
        parent_id_contains: Optional[str] = None,
        name: Optional[str] = None,
        name_contains: Optional[str] = None,
        filter_logic: Literal["and", "or"] = "and",
        limit: int = -1,
        offset: int = 0,
        sort_by: Optional[str] = "sequence_id",
        sort_order: Literal["asc", "desc"] = "asc",
    ) -> PaginatedResult[Span]:
        payload: Dict[str, Any] = {"rollout_id": rollout_id, "limit": limit, "offset": offset}
        if attempt_id is not None:
            payload["attempt_id"] = attempt_id
        if trace_id is not None:
            payload["trace_id"] = trace_id
        if trace_id_contains is not None:
            payload["trace_id_contains"] = trace_id_contains
        if span_id is not None:
            payload["span_id"] = span_id
        if span_id_contains is not None:
            payload["span_id_contains"] = span_id_contains
        if parent_id is not None:
            payload["parent_id"] = parent_id
        if parent_id_contains is not None:
            payload["parent_id_contains"] = parent_id_contains
        if name is not None:
            payload["name"] = name
        if name_contains is not None:
            payload["name_contains"] = name_contains
        payload["filter_logic"] = filter_logic
        if sort_by is not None:
            payload["sort_by"] = sort_by
            payload["sort_order"] = sort_order
        data = await self._request_json("post", "/spans/search", json=payload)
        items = [Span.model_validate(item) for item in data["items"]]
        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])

    async def update_rollout(
        self,
        rollout_id: str,
        input: TaskInput | Unset = UNSET,
        mode: Optional[Literal["train", "val", "test"]] | Unset = UNSET,
        resources_id: Optional[str] | Unset = UNSET,
        status: RolloutStatus | Unset = UNSET,
        config: RolloutConfig | Unset = UNSET,
        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
    ) -> Rollout:
        payload: Dict[str, Any] = {}
        if not isinstance(input, Unset):
            payload["input"] = input
        if not isinstance(mode, Unset):
            payload["mode"] = mode
        if not isinstance(resources_id, Unset):
            payload["resources_id"] = resources_id
        if not isinstance(status, Unset):
            payload["status"] = status
        if not isinstance(config, Unset):
            payload["config"] = config.model_dump()
        if not isinstance(metadata, Unset):
            payload["metadata"] = metadata

        data = await self._request_json("post", f"/rollouts/{rollout_id}", json=payload)
        return Rollout.model_validate(data)

    async def update_attempt(
        self,
        rollout_id: str,
        attempt_id: str | Literal["latest"],
        status: AttemptStatus | Unset = UNSET,
        worker_id: str | Unset = UNSET,
        last_heartbeat_time: float | Unset = UNSET,
        metadata: Optional[Dict[str, Any]] | Unset = UNSET,
    ) -> Attempt:
        payload: Dict[str, Any] = {}
        if not isinstance(status, Unset):
            payload["status"] = status
        if not isinstance(worker_id, Unset):
            payload["worker_id"] = worker_id
        if not isinstance(last_heartbeat_time, Unset):
            payload["last_heartbeat_time"] = last_heartbeat_time
        if not isinstance(metadata, Unset):
            payload["metadata"] = metadata

        data = await self._request_json(
            "post",
            f"/rollouts/{rollout_id}/attempts/{attempt_id}",
            json=payload,
        )
        return Attempt.model_validate(data)

    async def query_workers(
        self,
        *,
        status_in: Optional[Sequence[WorkerStatus]] = None,
        worker_id_contains: Optional[str] = None,
        filter_logic: Literal["and", "or"] = "and",
        sort_by: Optional[str] = None,
        sort_order: Literal["asc", "desc"] = "asc",
        limit: int = -1,
        offset: int = 0,
    ) -> PaginatedResult[Worker]:
        # limit/offset were accepted but never sent; include them like the other search endpoints
        payload: Dict[str, Any] = {
            "limit": limit,
            "offset": offset,
        }
        if status_in is not None:
            payload["status_in"] = status_in
        if worker_id_contains is not None:
            payload["worker_id_contains"] = worker_id_contains
        payload["filter_logic"] = filter_logic
        if sort_by is not None:
            payload["sort_by"] = sort_by
            payload["sort_order"] = sort_order

        data = await self._request_json("post", "/workers/search", json=payload)
        items = [Worker.model_validate(item) for item in data.get("items", [])]
        return PaginatedResult(items=items, limit=data["limit"], offset=data["offset"], total=data["total"])

    async def get_worker_by_id(self, worker_id: str) -> Optional[Worker]:
        data = await self._request_json("get", f"/workers/{worker_id}")
        if data is None:
            return None
        return Worker.model_validate(data)

    async def update_worker(
        self,
        worker_id: str,
        heartbeat_stats: Dict[str, Any] | Unset = UNSET,
    ) -> Worker:
        payload: Dict[str, Any] = {}
        if not isinstance(heartbeat_stats, Unset):
            payload["heartbeat_stats"] = heartbeat_stats
        json_payload = payload if payload else None

        data = await self._request_json("post", f"/workers/{worker_id}", json=json_payload)
        return Worker.model_validate(data)