tuft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tuft/server.py ADDED
@@ -0,0 +1,720 @@
1
+ """FastAPI application exposing a local-compatible Tinker API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from contextlib import asynccontextmanager
7
+ from datetime import timezone
8
+ from functools import partial
9
+ from typing import Any, Callable
10
+
11
+ from fastapi import Depends, FastAPI, HTTPException, Query, Request, status
12
+ from fastapi.responses import Response
13
+ from fastapi.security import APIKeyHeader
14
+ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
15
+ from pydantic import BaseModel
16
+ from tinker import types
17
+
18
+ from .auth import User
19
+ from .config import AppConfig
20
+ from .exceptions import TuFTException
21
+ from .persistence import get_redis_store
22
+ from .state import ServerState
23
+ from .telemetry import shutdown_telemetry
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
29
+
30
+
31
async def _get_user(
    request: Request,
    api_key: str = Depends(api_key_header),
) -> User:
    """Authenticate the caller from the API key extracted by ``api_key_header``.

    Args:
        request: Incoming request; used to reach ``app.state.server_state``.
        api_key: Value supplied by the ``api_key_header`` dependency. It may be
            empty/``None`` because the header scheme is created with
            ``auto_error=False``.

    Returns:
        The user object returned by ``server_state.get_user(api_key)``.

    Raises:
        HTTPException: 401 when no key was supplied, 403 when the key does not
            resolve to a user.
    """
    if api_key:
        resolved = request.app.state.server_state.get_user(api_key)
        if resolved is not None:
            return resolved
        # A key was sent, but the server state does not recognise it.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API key",
        )
    # No key at all: the caller has not attempted to authenticate.
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Missing API key",
    )
47
+
48
+
49
class WeightsInfoBody(BaseModel):
    """Request body accepted by the ``/api/v1/weights_info`` endpoint."""

    # Path string forwarded unchanged to ``ServerState.get_weights_info``.
    tinker_path: str
51
+
52
+
53
+ def _normalize_checkpoint_id(raw: str) -> str:
54
+ if "/" not in raw:
55
+ return raw
56
+ prefix, remainder = raw.split("/", 1)
57
+ if prefix not in {"weights", "sampler_weights"} or not remainder:
58
+ raise HTTPException(
59
+ status_code=status.HTTP_400_BAD_REQUEST,
60
+ detail="Invalid checkpoint reference",
61
+ )
62
+ return remainder
63
+
64
+
65
+ def _get_state(request: Request) -> ServerState:
66
+ state = getattr(request.app.state, "server_state", None)
67
+ if state is None:
68
+ raise RuntimeError("Server state has not been initialized")
69
+ return state
70
+
71
+
72
def _instrument_fastapi(app: FastAPI) -> None:
    """Hand *app* to ``FastAPIInstrumentor.instrument_app`` and log that it ran.

    Args:
        app: The FastAPI application to instrument.
    """
    FastAPIInstrumentor.instrument_app(app)
    logger.debug("FastAPI instrumentation enabled")
76
+
77
+
78
def create_root_app(config: AppConfig | None = None) -> FastAPI:
    """Build the TuFT FastAPI application and register every ``/api/v1`` route.

    Args:
        config: Application configuration. A default ``AppConfig()`` is created
            when ``None`` is given.

    Returns:
        The configured ``FastAPI`` instance. A ``ServerState`` built from the
        resolved configuration is stored on ``app.state.server_state``.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Initialise the server state on startup and release resources on shutdown."""
        try:
            await app.state.server_state.async_init()
            logger.info("Server initialized successfully")
            yield
        finally:
            logger.info("Server shutting down")
            # Nested try/finally: a failure in one cleanup step must not skip
            # the remaining ones. The first exception still propagates.
            try:
                await app.state.server_state.future_store.shutdown()
            finally:
                try:
                    store = get_redis_store()
                    if store.is_enabled:
                        store.close()
                finally:
                    shutdown_telemetry()

    def require_user_dependency(route):
        """Append ``Depends(_get_user)`` to *route* unless it is already declared."""
        declared = getattr(route, "dependencies", [])
        if not any(dep.dependency == _get_user for dep in declared):
            route.dependencies = declared + [Depends(_get_user)]
        return route

    def _internal_error(action: str, exc: Exception) -> HTTPException:
        """Build the 500 response ``"Failed to <action>: <reason>"`` for *exc*.

        ``TuFTException`` contributes its ``detail`` attribute; any other
        exception contributes ``str(exc)``.
        """
        reason = exc.detail if isinstance(exc, TuFTException) else str(exc)
        return HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to {action}: {reason}",
        )

    resolved_config = config if config is not None else AppConfig()
    if resolved_config.persistence.enabled:
        store = get_redis_store()
        store.configure(resolved_config.persistence)

    app = FastAPI(
        title="TuFT",
        # NOTE(review): hard-coded; confirm whether this should track the
        # installed package version.
        version="0.1.0",
        lifespan=lifespan,
    )
    app.state.server_state = ServerState(resolved_config)

    # Instrument FastAPI with OpenTelemetry if enabled
    if resolved_config.telemetry.enabled:
        _instrument_fastapi(app)

    @app.get("/api/v1/healthz", response_model=types.HealthResponse)
    async def healthz() -> types.HealthResponse:
        """Liveness probe; the only route exempt from the user dependency."""
        return types.HealthResponse(status="ok")

    @app.get(
        "/api/v1/get_server_capabilities",
        response_model=types.GetServerCapabilitiesResponse,
        # Declared here because this handler takes no ``user`` parameter.
        dependencies=[Depends(_get_user)],
    )
    async def get_server_capabilities(
        state: ServerState = Depends(_get_state),
    ) -> types.GetServerCapabilitiesResponse:
        """Report the models returned by ``state.build_supported_models()``."""
        return types.GetServerCapabilitiesResponse(supported_models=state.build_supported_models())

    @app.post(
        "/api/v1/create_session",
        response_model=types.CreateSessionResponse,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_session(
        request: types.CreateSessionRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.CreateSessionResponse:
        """Create a session for the authenticated user and return its id."""
        record = state.create_session(request, user)
        return types.CreateSessionResponse(session_id=record.session_id)

    @app.post(
        "/api/v1/session_heartbeat",
        response_model=types.SessionHeartbeatResponse,
    )
    async def session_heartbeat(
        request: types.SessionHeartbeatRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.SessionHeartbeatResponse:
        """Forward a heartbeat for ``request.session_id`` to the server state."""
        state.heartbeat(request.session_id, user_id=user.user_id)
        return types.SessionHeartbeatResponse()

    @app.post(
        "/api/v1/create_sampling_session",
        response_model=types.CreateSamplingSessionResponse,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_sampling_session(
        request: types.CreateSamplingSessionRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.CreateSamplingSessionResponse:
        """Create a sampling session; any failure is reported as a 500."""
        try:
            sampling_session_id = await state.create_sampling_session(
                session_id=request.session_id,
                user_id=user.user_id,
                base_model=request.base_model,
                model_path=request.model_path,
                session_seq_id=request.sampling_session_seq_id,
            )
        except Exception as exc:  # pylint: disable=broad-except
            raise _internal_error("create sampling session", exc) from exc
        return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id)

    @app.post(
        "/api/v1/create_model",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def create_model(
        request: types.CreateModelRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Create a training run and return an already-resolved future for it.

        Raises:
            HTTPException: 400 when ``lora_config`` is missing, 500 when the
                server state fails to create the model.
        """
        if request.lora_config is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail="Missing LoRA config"
            )
        try:
            training_record = await state.create_model(
                session_id=request.session_id,
                base_model=request.base_model,
                lora_config=request.lora_config,
                model_owner=user.user_id,
                user_metadata=request.user_metadata,
            )
        except Exception as exc:  # pylint: disable=broad-except
            raise _internal_error("create model", exc) from exc
        response = types.CreateModelResponse(model_id=training_record.training_run_id)
        return await state.future_store.create_ready_future(
            response,
            model_id=training_record.training_run_id,
            user_id=user.user_id,
        )

    @app.post(
        "/api/v1/get_info",
        response_model=types.GetInfoResponse,
    )
    async def get_info(
        request: types.GetInfoRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.GetInfoResponse:
        """Return ``state.get_model_info`` for the requested model and caller."""
        return state.get_model_info(request.model_id, user_id=user.user_id)

    @app.post(
        "/api/v1/unload_model",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def unload_model(
        request: types.UnloadModelRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Unload a model and return an already-resolved future; failures map to 500."""
        try:
            await state.unload_model(request.model_id, user_id=user.user_id)
        except Exception as exc:  # pylint: disable=broad-except
            raise _internal_error("unload model", exc) from exc
        response = types.UnloadModelResponse(model_id=request.model_id)
        return await state.future_store.create_ready_future(
            response, model_id=request.model_id, user_id=user.user_id
        )

    async def _queue_future(
        operation: Callable[[], Any],
        state: ServerState,
        user_id: str,
        *,
        model_id: str | None = None,
        operation_type: str | None = None,
        operation_args: dict[str, Any] | None = None,
    ) -> types.UntypedAPIFuture:
        """Enqueue *operation* on ``state.future_store`` and return its future handle.

        ``operation_type`` and ``operation_args`` are passed through unchanged.
        NOTE(review): presumably they let the store describe or replay the
        operation — confirm against ``future_store.enqueue``.
        """
        return await state.future_store.enqueue(
            operation,
            model_id=model_id,
            user_id=user_id,
            operation_type=operation_type,  # type: ignore[arg-type]
            operation_args=operation_args,
        )

    @app.post(
        "/api/v1/forward",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def forward(
        request: types.ForwardRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue a forward pass without a backward step (``backward=False``)."""
        inp = request.forward_input

        async def _operation() -> types.ForwardBackwardOutput:
            return await state.run_forward(
                request.model_id,
                user.user_id,
                inp.data,
                inp.loss_fn,
                inp.loss_fn_config,
                request.seq_id,
                backward=False,
            )

        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="forward",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "data": inp.data,
                "loss_fn": inp.loss_fn,
                "loss_fn_config": inp.loss_fn_config,
                "seq_id": request.seq_id,
                "backward": False,
            },
        )

    @app.post(
        "/api/v1/forward_backward",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def forward_backward(
        request: types.ForwardBackwardRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue a forward pass followed by a backward step (``backward=True``)."""
        inp = request.forward_backward_input

        async def _operation() -> types.ForwardBackwardOutput:
            return await state.run_forward(
                request.model_id,
                user.user_id,
                inp.data,
                inp.loss_fn,
                inp.loss_fn_config,
                request.seq_id,
                backward=True,
            )

        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="forward_backward",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "data": inp.data,
                "loss_fn": inp.loss_fn,
                "loss_fn_config": inp.loss_fn_config,
                "seq_id": request.seq_id,
                "backward": True,
            },
        )

    @app.post(
        "/api/v1/optim_step",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def optim_step(
        request: types.OptimStepRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue an optimizer step using ``request.adam_params``."""

        async def _operation() -> types.OptimStepResponse:
            return await state.run_optim_step(
                request.model_id,
                user.user_id,
                request.adam_params,
                request.seq_id,
            )

        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="optim_step",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "params": request.adam_params,
                "seq_id": request.seq_id,
            },
        )

    @app.post(
        "/api/v1/save_weights",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def save_weights(
        request: types.SaveWeightsRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue saving a ``"training"`` checkpoint; the result carries its tinker path."""

        async def _operation() -> types.SaveWeightsResponse:
            checkpoint = await state.save_checkpoint(
                request.model_id,
                user.user_id,
                request.path,
                "training",
                seq_id=request.seq_id,
            )
            return types.SaveWeightsResponse(path=checkpoint.tinker_checkpoint.tinker_path)

        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="save_weights",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "name": request.path,
                "checkpoint_type": "training",
            },
        )

    @app.post(
        "/api/v1/save_weights_for_sampler",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def save_weights_for_sampler(
        request: types.SaveWeightsForSamplerRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue saving a ``"sampler"`` checkpoint; the result carries its tinker path."""

        async def _operation() -> types.SaveWeightsForSamplerResponse:
            checkpoint = await state.save_checkpoint(
                request.model_id,
                user.user_id,
                request.path,
                "sampler",
                seq_id=request.seq_id,
            )
            return types.SaveWeightsForSamplerResponse(
                path=checkpoint.tinker_checkpoint.tinker_path
            )

        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="save_weights_for_sampler",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "name": request.path,
                "checkpoint_type": "sampler",
            },
        )

    @app.post(
        "/api/v1/load_weights",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def load_weights(
        request: types.LoadWeightsRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue loading a checkpoint from ``request.path`` into the model."""

        async def _operation() -> types.LoadWeightsResponse:
            await state.load_checkpoint(
                model_id=request.model_id,
                user_id=user.user_id,
                path=request.path,
                optimizer=request.optimizer,
                seq_id=request.seq_id,
            )
            return types.LoadWeightsResponse(path=request.path)

        return await _queue_future(
            _operation,
            state,
            model_id=request.model_id,
            user_id=user.user_id,
            operation_type="load_weights",
            operation_args={
                "model_id": request.model_id,
                "user_id": user.user_id,
                "path": request.path,
                "optimizer": request.optimizer,
            },
        )

    @app.post(
        "/api/v1/asample",
        response_model=types.UntypedAPIFuture,
        status_code=status.HTTP_202_ACCEPTED,
    )
    async def asample(
        request: types.SampleRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.UntypedAPIFuture:
        """Queue a sampling request; no ``model_id`` is attached to the future."""
        return await _queue_future(
            partial(state.run_sample, request=request, user_id=user.user_id),
            state=state,
            user_id=user.user_id,
            operation_type="sample",
            operation_args={
                "request": request,
                "user_id": user.user_id,
            },
        )

    @app.post("/api/v1/retrieve_future")
    async def retrieve_future(
        request: types.FutureRetrieveRequest,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> Any:
        """Return the stored payload for ``request.request_id``.

        Raises:
            HTTPException: 404 when the future store raises ``KeyError``.
        """
        try:
            payload = await state.future_store.retrieve(
                request_id=request.request_id, user_id=user.user_id
            )
        except KeyError as exc:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Unknown request_id"
            ) from exc
        return payload  # FastAPI will serialize the stored Tinker type

    @app.get(
        "/api/v1/training_runs",
        response_model=types.TrainingRunsResponse,
    )
    async def list_training_runs(
        limit: int = Query(20, ge=1, le=500),
        offset: int = Query(0, ge=0),
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.TrainingRunsResponse:
        """List the caller's training runs; paging is delegated to the server state."""
        return state.list_training_runs(user_id=user.user_id, limit=limit, offset=offset)

    @app.get(
        "/api/v1/training_runs/{model_id}",
        response_model=types.TrainingRun,
    )
    async def get_training_run(
        model_id: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.TrainingRun:
        """Return the training-run view for ``model_id``."""
        return state.get_training_run_view(model_id=model_id, user_id=user.user_id)

    def _build_checkpoint_cursor(total: int, limit: int, offset: int) -> types.Cursor:
        """Build a cursor echoing the requested ``offset``/``limit`` plus the total count."""
        return types.Cursor(offset=offset, limit=limit, total_count=total)

    def _checkpoint_page(
        checkpoints: Any, limit: int, offset: int
    ) -> types.CheckpointsListResponse:
        """Slice *checkpoints* to the requested window and attach a cursor."""
        total = len(checkpoints)
        # Clamp both bounds so an offset past the end yields an empty page.
        start = min(offset, total)
        end = min(start + limit, total)
        return types.CheckpointsListResponse(
            checkpoints=checkpoints[start:end],
            cursor=_build_checkpoint_cursor(total, limit, offset),
        )

    @app.get(
        "/api/v1/training_runs/{model_id}/checkpoints",
        response_model=types.CheckpointsListResponse,
    )
    async def list_training_run_checkpoints(
        model_id: str,
        limit: int = Query(100, ge=1, le=1000),
        offset: int = Query(0, ge=0),
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.CheckpointsListResponse:
        """Return one page of the checkpoints belonging to ``model_id``."""
        checkpoints = state.list_checkpoints(model_id, user_id=user.user_id)
        return _checkpoint_page(checkpoints, limit, offset)

    @app.delete(
        "/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_path:path}",
        status_code=status.HTTP_204_NO_CONTENT,
    )
    async def delete_checkpoint(
        model_id: str,
        checkpoint_path: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> None:
        """Delete the checkpoint identified by the normalized ``checkpoint_path``."""
        state.delete_checkpoint(model_id, user.user_id, _normalize_checkpoint_id(checkpoint_path))

    @app.post(
        "/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_path:path}/publish",
        status_code=status.HTTP_204_NO_CONTENT,
    )
    async def publish_checkpoint(
        model_id: str,
        checkpoint_path: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> None:
        """Mark the checkpoint as public (``public=True``)."""
        state.set_checkpoint_visibility(
            model_id,
            user.user_id,
            _normalize_checkpoint_id(checkpoint_path),
            public=True,
        )

    @app.delete(
        "/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_path:path}/publish",
        status_code=status.HTTP_204_NO_CONTENT,
    )
    async def unpublish_checkpoint(
        model_id: str,
        checkpoint_path: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> None:
        """Mark the checkpoint as not public (``public=False``)."""
        state.set_checkpoint_visibility(
            model_id,
            user.user_id,
            _normalize_checkpoint_id(checkpoint_path),
            public=False,
        )

    @app.get(
        "/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_path:path}/archive",
        status_code=status.HTTP_302_FOUND,
    )
    async def checkpoint_archive(
        model_id: str,
        checkpoint_path: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> Response:
        """Redirect (302) to the archive URL built for the checkpoint.

        The ``Expires`` header carries ``archive.expires`` converted to UTC.
        """
        archive = state.build_archive_url(
            model_id,
            user_id=user.user_id,
            checkpoint_id=_normalize_checkpoint_id(checkpoint_path),
        )
        expires = archive.expires.astimezone(timezone.utc)
        return Response(
            status_code=status.HTTP_302_FOUND,
            headers={
                "Location": archive.url,
                "Expires": expires.strftime("%a, %d %b %Y %H:%M:%S GMT"),
            },
        )

    @app.get(
        "/api/v1/checkpoints",
        response_model=types.CheckpointsListResponse,
    )
    async def list_user_checkpoints(
        limit: int = Query(100, ge=1, le=1000),
        offset: int = Query(0, ge=0),
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.CheckpointsListResponse:
        """Return one page of all checkpoints returned for the caller."""
        checkpoints = state.list_user_checkpoints(user.user_id)
        return _checkpoint_page(checkpoints, limit, offset)

    @app.post(
        "/api/v1/weights_info",
        response_model=types.WeightsInfoResponse,
    )
    async def weights_info(
        body: WeightsInfoBody,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.WeightsInfoResponse:
        """Return ``state.get_weights_info`` for ``body.tinker_path``."""
        return state.get_weights_info(body.tinker_path, user.user_id)

    @app.get(
        "/api/v1/sessions/{session_id}",
        response_model=types.GetSessionResponse,
    )
    async def get_session(
        session_id: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.GetSessionResponse:
        """Return the overview of a single session."""
        return state.get_session_overview(session_id, user.user_id)

    @app.get(
        "/api/v1/sessions",
        response_model=types.ListSessionsResponse,
    )
    async def list_sessions(
        limit: int = Query(20, ge=1, le=500),
        offset: int = Query(0, ge=0),
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.ListSessionsResponse:
        """List the caller's sessions; paging is delegated to the server state."""
        return state.list_sessions(user_id=user.user_id, limit=limit, offset=offset)

    @app.post(
        "/api/v1/telemetry",
        response_model=types.TelemetryResponse,
        # Declared here because this handler takes no ``user`` parameter.
        dependencies=[Depends(_get_user)],
    )
    async def send_telemetry(
        body: types.TelemetrySendRequest,
    ) -> types.TelemetryResponse:
        """Acknowledge a telemetry payload without storing it."""
        # We currently accept telemetry events for protocol compatibility but do not persist them.
        return types.TelemetryResponse(status="accepted")

    @app.get(
        "/api/v1/samplers/{sampler_id}",
        response_model=types.GetSamplerResponse,
    )
    async def get_sampler(
        sampler_id: str,
        state: ServerState = Depends(_get_state),
        user: User = Depends(_get_user),
    ) -> types.GetSamplerResponse:
        """Return ``state.get_sampler_info`` for ``sampler_id``."""
        return state.get_sampler_info(sampler_id, user.user_id)

    # Backstop: make sure every route except the health check lists the user
    # dependency.
    # NOTE(review): FastAPI appears to resolve a route's dependencies when the
    # route is constructed, so appending here may not affect request handling.
    # The two routes without a ``user`` parameter therefore declare the
    # dependency in their decorator — confirm before relying on this loop.
    for route in app.routes:
        if getattr(route, "path", None) != "/api/v1/healthz" and hasattr(route, "dependencies"):
            require_user_dependency(route)
    return app