tuft 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__init__.py +5 -2
- tuft/auth.py +35 -0
- tuft/backend.py +254 -0
- tuft/backends/__init__.py +10 -0
- tuft/backends/base_backend.py +112 -0
- tuft/backends/hf_training_model.py +404 -0
- tuft/backends/sampling_backend.py +253 -0
- tuft/backends/training_backend.py +327 -0
- tuft/checkpoints.py +193 -0
- tuft/cli.py +91 -0
- tuft/config.py +121 -0
- tuft/exceptions.py +138 -0
- tuft/futures.py +431 -0
- tuft/loss_fn/__init__.py +48 -0
- tuft/loss_fn/cispo.py +40 -0
- tuft/loss_fn/cross_entropy.py +26 -0
- tuft/loss_fn/dro.py +37 -0
- tuft/loss_fn/importance_sampling.py +33 -0
- tuft/loss_fn/ppo.py +43 -0
- tuft/persistence/__init__.py +32 -0
- tuft/persistence/file_redis.py +268 -0
- tuft/persistence/redis_store.py +488 -0
- tuft/sampling_controller.py +366 -0
- tuft/server.py +720 -0
- tuft/state.py +352 -0
- tuft/telemetry/__init__.py +17 -0
- tuft/telemetry/metrics.py +335 -0
- tuft/telemetry/provider.py +198 -0
- tuft/telemetry/tracing.py +43 -0
- tuft/training_controller.py +723 -0
- tuft-0.1.1.dist-info/METADATA +633 -0
- tuft-0.1.1.dist-info/RECORD +35 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/WHEEL +1 -2
- tuft-0.1.1.dist-info/entry_points.txt +2 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/licenses/LICENSE +2 -2
- tuft-0.1.0.dist-info/METADATA +0 -77
- tuft-0.1.0.dist-info/RECORD +0 -6
- tuft-0.1.0.dist-info/top_level.txt +0 -1
tuft/server.py
ADDED
|
@@ -0,0 +1,720 @@
|
|
|
1
|
+
"""FastAPI application exposing a local-compatible Tinker API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
from contextlib import asynccontextmanager
from datetime import timezone
from email.utils import format_datetime
from functools import partial
from typing import Any, Callable

from fastapi import Depends, FastAPI, HTTPException, Query, Request, status
from fastapi.responses import Response
from fastapi.security import APIKeyHeader
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from pydantic import BaseModel
from tinker import types

from .auth import User
from .config import AppConfig
from .exceptions import TuFTException
from .persistence import get_redis_store
from .state import ServerState
from .telemetry import shutdown_telemetry
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)

# Clients authenticate with their key in the ``X-API-Key`` request header.
# ``auto_error=False`` makes a missing header resolve to a falsy value instead of
# FastAPI's built-in error, so ``_get_user`` can reply with its own 401 message.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
async def _get_user(
    request: Request,
    api_key: str = Depends(api_key_header),
) -> User:
    """FastAPI dependency resolving the caller from the ``X-API-Key`` header.

    Args:
        request: Incoming request; its app's ``state.server_state`` performs the lookup.
        api_key: Header value injected by ``api_key_header`` (falsy when absent).

    Returns:
        The ``User`` that ``server_state.get_user`` associates with the key.

    Raises:
        HTTPException: 401 when no key was sent, 403 when the key is unknown.
    """
    if api_key:
        server_state = request.app.state.server_state
        user = server_state.get_user(api_key)
        if user is not None:
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API key",
        )
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Missing API key",
    )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class WeightsInfoBody(BaseModel):
    """JSON body accepted by ``POST /api/v1/weights_info``."""

    # Checkpoint reference to describe; handed unchanged to
    # ``ServerState.get_weights_info`` by the route.
    tinker_path: str
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _normalize_checkpoint_id(raw: str) -> str:
    """Reduce a checkpoint reference to its bare checkpoint id.

    A value without ``/`` is already a bare id and is returned unchanged.
    Otherwise the text before the first ``/`` must be ``weights`` or
    ``sampler_weights`` and the text after it must be non-empty; that trailing
    part (which may itself contain ``/``) is returned.

    Raises:
        HTTPException: 400 ``"Invalid checkpoint reference"`` for any other
            ``/``-containing value.
    """
    head, separator, checkpoint_id = raw.partition("/")
    if not separator:
        return raw
    if checkpoint_id and head in ("weights", "sampler_weights"):
        return checkpoint_id
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Invalid checkpoint reference",
    )
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _get_state(request: Request) -> ServerState:
    """FastAPI dependency returning the ``ServerState`` attached to the app.

    Raises:
        RuntimeError: if ``app.state.server_state`` is missing or ``None``.
    """
    server_state = getattr(request.app.state, "server_state", None)
    if server_state is not None:
        return server_state
    raise RuntimeError("Server state has not been initialized")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _instrument_fastapi(app: FastAPI) -> None:
    """Attach OpenTelemetry's FastAPI instrumentation to ``app``.

    Thin wrapper over ``FastAPIInstrumentor.instrument_app``; emits a debug log
    line once the instrumentation has been installed.
    """
    instrumentor = FastAPIInstrumentor
    instrumentor.instrument_app(app)
    logger.debug("FastAPI instrumentation enabled")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _request_failure(action: str, exc: Exception) -> HTTPException:
    """Translate an unexpected failure of a state operation into an HTTP 500.

    Args:
        action: Human-readable operation name, e.g. ``"create model"``.
        exc: The exception raised by the operation.

    Returns:
        An ``HTTPException`` whose detail is ``"Failed to <action>: <message>"``,
        where ``<message>`` is ``exc.detail`` for a ``TuFTException`` and
        ``str(exc)`` otherwise. Callers ``raise ... from exc``.
    """
    message = exc.detail if isinstance(exc, TuFTException) else str(exc)
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=f"Failed to {action}: {message}",
    )


def _paginate_checkpoints(
    checkpoints: list[Any],
    limit: int,
    offset: int,
) -> types.CheckpointsListResponse:
    """Return one page of ``checkpoints`` together with its pagination cursor.

    An ``offset`` at or past the end yields an empty page. The cursor echoes the
    requested ``offset``/``limit`` and reports the full ``total_count``.
    """
    total = len(checkpoints)
    start = min(offset, total)
    end = min(start + limit, total)
    return types.CheckpointsListResponse(
        checkpoints=checkpoints[start:end],
        cursor=types.Cursor(offset=offset, limit=limit, total_count=total),
    )


def create_root_app(config: AppConfig | None = None) -> FastAPI:
    """Build the TuFT FastAPI application.

    Args:
        config: Server configuration; ``AppConfig()`` is used when omitted.

    Returns:
        The configured app. ``app.state.server_state`` holds a ``ServerState``
        built from the resolved config. The lifespan runs its ``async_init`` on
        startup. On shutdown it stops the future store, closes the Redis store
        when enabled, and shuts down telemetry.

    Every route except ``/api/v1/healthz`` requires a valid ``X-API-Key`` header
    (checked by the ``_get_user`` dependency).
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        try:
            await app.state.server_state.async_init()
            logger.info("Server initialized successfully")
            yield
        finally:
            logger.info("Server shutting down")
            # Teardown steps are isolated so that one failing (and being logged)
            # does not skip the remaining ones.
            try:
                await app.state.server_state.future_store.shutdown()
            except Exception:  # pylint: disable=broad-except
                logger.exception("Failed to shut down the future store")
            try:
                store = get_redis_store()
                if store.is_enabled:
                    store.close()
            except Exception:  # pylint: disable=broad-except
                logger.exception("Failed to close the Redis store")
            shutdown_telemetry()

    resolved_config = config or AppConfig()
    if resolved_config.persistence.enabled:
        get_redis_store().configure(resolved_config.persistence)

    app = FastAPI(
        title="TuFT",
        # NOTE(review): hard-coded; confirm it matches the released package version.
        version="0.1.0",
        lifespan=lifespan,
    )
    app.state.server_state = ServerState(resolved_config)

    # Instrument FastAPI with OpenTelemetry if enabled
    if resolved_config.telemetry.enabled:
        _instrument_fastapi(app)

    # Route-level authentication for handlers that take no ``user`` parameter.
    # It has to be declared when the route is registered: FastAPI builds each
    # route's dependency graph at registration time, so appending to
    # ``route.dependencies`` afterwards is never applied to requests.
    require_user = [Depends(_get_user)]

    @app.get("/api/v1/healthz", response_model=types.HealthResponse)
    async def healthz() -> types.HealthResponse:
        """Liveness probe; intentionally unauthenticated."""
        return types.HealthResponse(status="ok")

    @app.get(
        "/api/v1/get_server_capabilities",
        response_model=types.GetServerCapabilitiesResponse,
        dependencies=require_user,
    )
    async def get_server_capabilities(
        state: ServerState = Depends(_get_state),
    ) -> types.GetServerCapabilitiesResponse:
        """List the models this server supports."""
        return types.GetServerCapabilitiesResponse(supported_models=state.build_supported_models())

    @app.post(
        "/api/v1/create_session",
        response_model=types.CreateSessionResponse,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_session(
        request: types.CreateSessionRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.CreateSessionResponse:
        """Create a session owned by the caller."""
        record = state.create_session(request, user)
        return types.CreateSessionResponse(session_id=record.session_id)

    @app.post(
        "/api/v1/session_heartbeat",
        response_model=types.SessionHeartbeatResponse,
    )
    async def session_heartbeat(
        request: types.SessionHeartbeatRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.SessionHeartbeatResponse:
        """Record a heartbeat for one of the caller's sessions."""
        state.heartbeat(request.session_id, user_id=user.user_id)
        return types.SessionHeartbeatResponse()

    @app.post(
        "/api/v1/create_sampling_session",
        response_model=types.CreateSamplingSessionResponse,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_sampling_session(
        request: types.CreateSamplingSessionRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.CreateSamplingSessionResponse:
        """Create a sampling session; unexpected failures are reported as 500."""
        try:
            sampling_session_id = await state.create_sampling_session(
                session_id=request.session_id,
                user_id=user.user_id,
                base_model=request.base_model,
                model_path=request.model_path,
                session_seq_id=request.sampling_session_seq_id,
            )
        except TuFTException as exc:
            raise _request_failure("create sampling session", exc) from exc
        except HTTPException:
            # Already carries a deliberate status code; don't mask it as a 500.
            raise
        except Exception as exc:  # pylint: disable=broad-except
            raise _request_failure("create sampling session", exc) from exc
        return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id)

    @app.post(
        "/api/v1/create_model",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def create_model(
        request: types.CreateModelRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Create a LoRA training run and return an already-resolved future for it."""
        if request.lora_config is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail="Missing LoRA config"
            )
        try:
            training_record = await state.create_model(
                session_id=request.session_id,
                base_model=request.base_model,
                lora_config=request.lora_config,
                model_owner=user.user_id,
                user_metadata=request.user_metadata,
            )
        except TuFTException as exc:
            raise _request_failure("create model", exc) from exc
        except HTTPException:
            raise
        except Exception as exc:  # pylint: disable=broad-except
            raise _request_failure("create model", exc) from exc
        response = types.CreateModelResponse(model_id=training_record.training_run_id)
        return await state.future_store.create_ready_future(
            response,
            model_id=training_record.training_run_id,
            user_id=user.user_id,
        )

    @app.post(
        "/api/v1/get_info",
        response_model=types.GetInfoResponse,
    )
    async def get_info(
        request: types.GetInfoRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.GetInfoResponse:
        """Describe one of the caller's models."""
        return state.get_model_info(request.model_id, user_id=user.user_id)

    @app.post(
        "/api/v1/unload_model",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def unload_model(
        request: types.UnloadModelRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Unload a model and return an already-resolved future."""
        try:
            await state.unload_model(request.model_id, user_id=user.user_id)
        except TuFTException as exc:
            raise _request_failure("unload model", exc) from exc
        except HTTPException:
            raise
        except Exception as exc:  # pylint: disable=broad-except
            raise _request_failure("unload model", exc) from exc
        response = types.UnloadModelResponse(model_id=request.model_id)
        return await state.future_store.create_ready_future(
            response, model_id=request.model_id, user_id=user.user_id
        )

    async def _queue_future(
        operation: Callable[[], Any],
        state: ServerState,
        user_id: str,
        *,
        model_id: str | None = None,
        operation_type: str | None = None,
        operation_args: dict[str, Any] | None = None,
    ) -> types.UntypedAPIFuture:
        """Enqueue ``operation`` on the future store and return its pending future.

        ``operation_type``/``operation_args`` are forwarded verbatim to
        ``future_store.enqueue``.
        """
        return await state.future_store.enqueue(
            operation,
            model_id=model_id,
            user_id=user_id,
            operation_type=operation_type,  # type: ignore[arg-type]
            operation_args=operation_args,
        )

    @app.post(
        "/api/v1/forward",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def forward(
        request: types.ForwardRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue a forward pass (no backward) and return its future."""
        inp = request.forward_input

        async def _operation() -> types.ForwardBackwardOutput:
            return await state.run_forward(
                request.model_id,
                user.user_id,
                inp.data,
                inp.loss_fn,
                inp.loss_fn_config,
                request.seq_id,
                backward=False,
            )

        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="forward",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "data": inp.data,
                "loss_fn": inp.loss_fn,
                "loss_fn_config": inp.loss_fn_config,
                "seq_id": request.seq_id,
                "backward": False,
            },
        )

    @app.post(
        "/api/v1/forward_backward",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def forward_backward(
        request: types.ForwardBackwardRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue a forward+backward pass and return its future."""
        inp = request.forward_backward_input

        async def _operation() -> types.ForwardBackwardOutput:
            return await state.run_forward(
                request.model_id,
                user.user_id,
                inp.data,
                inp.loss_fn,
                inp.loss_fn_config,
                request.seq_id,
                backward=True,
            )

        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="forward_backward",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "data": inp.data,
                "loss_fn": inp.loss_fn,
                "loss_fn_config": inp.loss_fn_config,
                "seq_id": request.seq_id,
                "backward": True,
            },
        )

    @app.post(
        "/api/v1/optim_step",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def optim_step(
        request: types.OptimStepRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue an optimizer step with the given Adam parameters."""

        async def _operation() -> types.OptimStepResponse:
            return await state.run_optim_step(
                request.model_id,
                user.user_id,
                request.adam_params,
                request.seq_id,
            )

        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="optim_step",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "params": request.adam_params,
                "seq_id": request.seq_id,
            },
        )

    @app.post(
        "/api/v1/save_weights",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def save_weights(
        request: types.SaveWeightsRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue saving a "training" checkpoint; the future resolves to its path."""

        async def _operation() -> types.SaveWeightsResponse:
            checkpoint = await state.save_checkpoint(
                request.model_id,
                user.user_id,
                request.path,
                "training",
                seq_id=request.seq_id,
            )
            return types.SaveWeightsResponse(path=checkpoint.tinker_checkpoint.tinker_path)

        # NOTE(review): unlike forward/optim_step, seq_id is not recorded in
        # operation_args here; confirm that is intended.
        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="save_weights",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "name": request.path,
                "checkpoint_type": "training",
            },
        )

    @app.post(
        "/api/v1/save_weights_for_sampler",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def save_weights_for_sampler(
        request: types.SaveWeightsForSamplerRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue saving a "sampler" checkpoint; the future resolves to its path."""

        async def _operation() -> types.SaveWeightsForSamplerResponse:
            checkpoint = await state.save_checkpoint(
                request.model_id,
                user.user_id,
                request.path,
                "sampler",
                seq_id=request.seq_id,
            )
            return types.SaveWeightsForSamplerResponse(
                path=checkpoint.tinker_checkpoint.tinker_path
            )

        # NOTE(review): seq_id is not recorded in operation_args; confirm intended.
        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="save_weights_for_sampler",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "name": request.path,
                "checkpoint_type": "sampler",
            },
        )

    @app.post(
        "/api/v1/load_weights",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def load_weights(
        request: types.LoadWeightsRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue loading a checkpoint (optionally with optimizer state) into a model."""

        async def _operation() -> types.LoadWeightsResponse:
            await state.load_checkpoint(
                model_id=request.model_id,
                user_id=user.user_id,
                path=request.path,
                optimizer=request.optimizer,
                seq_id=request.seq_id,
            )
            return types.LoadWeightsResponse(path=request.path)

        # NOTE(review): seq_id is not recorded in operation_args; confirm intended.
        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="load_weights",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "path": request.path,
                "optimizer": request.optimizer,
            },
        )

    @app.post(
        "/api/v1/asample",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def asample(
        request: types.SampleRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue a sampling request and return its future."""
        return await _queue_future(
            partial(state.run_sample, request=request, user_id=user.user_id),
            state=state,
            user_id=user.user_id,
            operation_type="sample",
            operation_args={
                "request": request,
                "user_id": user.user_id,
            },
        )

    @app.post("/api/v1/retrieve_future")
    async def retrieve_future(
        request: types.FutureRetrieveRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> Any:
        """Return the stored payload of a future; 404 for unknown request ids."""
        try:
            payload = await state.future_store.retrieve(
                request_id=request.request_id, user_id=user.user_id
            )
        except KeyError as exc:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Unknown request_id"
            ) from exc
        return payload  # FastAPI will serialize the stored Tinker type

    @app.get(
        "/api/v1/training_runs",
        response_model=types.TrainingRunsResponse,
    )
    async def list_training_runs(
        limit: int = Query(20, ge=1, le=500),
        offset: int = Query(0, ge=0),
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.TrainingRunsResponse:
        """Page through the caller's training runs."""
        return state.list_training_runs(user_id=user.user_id, limit=limit, offset=offset)

    @app.get(
        "/api/v1/training_runs/{model_id}",
        response_model=types.TrainingRun,
    )
    async def get_training_run(
        model_id: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.TrainingRun:
        """Describe a single training run."""
        return state.get_training_run_view(model_id=model_id, user_id=user.user_id)

    @app.get(
        "/api/v1/training_runs/{model_id}/checkpoints",
        response_model=types.CheckpointsListResponse,
    )
    async def list_training_run_checkpoints(
        model_id: str,
        limit: int = Query(100, ge=1, le=1000),
        offset: int = Query(0, ge=0),
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.CheckpointsListResponse:
        """Page through the checkpoints of one training run."""
        checkpoints = state.list_checkpoints(model_id, user_id=user.user_id)
        return _paginate_checkpoints(checkpoints, limit, offset)

    @app.delete(
        "/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_path:path}",
        status_code=status.HTTP_204_NO_CONTENT,
    )
    async def delete_checkpoint(
        model_id: str,
        checkpoint_path: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> None:
        """Delete a checkpoint (``weights/<id>``, ``sampler_weights/<id>`` or bare id)."""
        state.delete_checkpoint(model_id, user.user_id, _normalize_checkpoint_id(checkpoint_path))

    @app.post(
        "/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_path:path}/publish",
        status_code=status.HTTP_204_NO_CONTENT,
    )
    async def publish_checkpoint(
        model_id: str,
        checkpoint_path: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> None:
        """Mark a checkpoint as public."""
        state.set_checkpoint_visibility(
            model_id,
            user.user_id,
            _normalize_checkpoint_id(checkpoint_path),
            public=True,
        )

    @app.delete(
        "/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_path:path}/publish",
        status_code=status.HTTP_204_NO_CONTENT,
    )
    async def unpublish_checkpoint(
        model_id: str,
        checkpoint_path: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> None:
        """Mark a checkpoint as private."""
        state.set_checkpoint_visibility(
            model_id,
            user.user_id,
            _normalize_checkpoint_id(checkpoint_path),
            public=False,
        )

    @app.get(
        "/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_path:path}/archive",
        status_code=status.HTTP_302_FOUND,
    )
    async def checkpoint_archive(
        model_id: str,
        checkpoint_path: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> Response:
        """Redirect (302) to a download URL for the checkpoint archive."""
        archive = state.build_archive_url(
            model_id,
            user_id=user.user_id,
            checkpoint_id=_normalize_checkpoint_id(checkpoint_path),
        )
        expires = archive.expires.astimezone(timezone.utc)
        return Response(
            status_code=status.HTTP_302_FOUND,
            headers={
                "Location": archive.url,
                # HTTP dates need English day/month names regardless of the
                # process locale, which strftime("%a"/"%b") does not guarantee.
                "Expires": format_datetime(expires, usegmt=True),
            },
        )

    @app.get(
        "/api/v1/checkpoints",
        response_model=types.CheckpointsListResponse,
    )
    async def list_user_checkpoints(
        limit: int = Query(100, ge=1, le=1000),
        offset: int = Query(0, ge=0),
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.CheckpointsListResponse:
        """Page through all checkpoints belonging to the caller."""
        checkpoints = state.list_user_checkpoints(user.user_id)
        return _paginate_checkpoints(checkpoints, limit, offset)

    @app.post(
        "/api/v1/weights_info",
        response_model=types.WeightsInfoResponse,
    )
    async def weights_info(
        body: WeightsInfoBody,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.WeightsInfoResponse:
        """Describe the weights referenced by ``body.tinker_path``."""
        return state.get_weights_info(body.tinker_path, user.user_id)

    @app.get(
        "/api/v1/sessions/{session_id}",
        response_model=types.GetSessionResponse,
    )
    async def get_session(
        session_id: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.GetSessionResponse:
        """Summarise one of the caller's sessions."""
        return state.get_session_overview(session_id, user.user_id)

    @app.get(
        "/api/v1/sessions",
        response_model=types.ListSessionsResponse,
    )
    async def list_sessions(
        limit: int = Query(20, ge=1, le=500),
        offset: int = Query(0, ge=0),
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.ListSessionsResponse:
        """Page through the caller's sessions."""
        return state.list_sessions(user_id=user.user_id, limit=limit, offset=offset)

    @app.post(
        "/api/v1/telemetry",
        response_model=types.TelemetryResponse,
        dependencies=require_user,
    )
    async def send_telemetry(
        body: types.TelemetrySendRequest,
    ) -> types.TelemetryResponse:
        """Accept client telemetry events."""
        # We currently accept telemetry events for protocol compatibility but do not persist them.
        return types.TelemetryResponse(status="accepted")

    @app.get(
        "/api/v1/samplers/{sampler_id}",
        response_model=types.GetSamplerResponse,
    )
    async def get_sampler(
        sampler_id: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.GetSamplerResponse:
        """Describe one of the caller's samplers."""
        return state.get_sampler_info(sampler_id, user.user_id)

    return app
|