tuft 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tuft/config.py ADDED
@@ -0,0 +1,121 @@
1
+ """Configuration helpers for the TuFT service."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import Dict, Iterable, List
8
+
9
+ from .persistence import PersistenceConfig
10
+
11
+
12
+ def _default_checkpoint_dir() -> Path:
13
+ return Path.home() / ".cache" / "tuft" / "checkpoints"
14
+
15
+
16
def _default_persistence_config() -> PersistenceConfig:
    """Build a ``PersistenceConfig`` with all defaults.

    Exists so the dataclass below can use it as a ``default_factory``
    (mutable defaults cannot be used directly).
    """
    config = PersistenceConfig()
    return config
18
+
19
+
20
@dataclass
class TelemetryConfig:
    """Configuration for OpenTelemetry integration.

    Attributes:
        enabled: Whether telemetry is enabled.
        service_name: Name of the service for tracing.
        otlp_endpoint: OTLP exporter endpoint. If None, uses TUFT_OTLP_ENDPOINT env var.
        resource_attributes: Additional resource attributes as key-value pairs.
    """

    # Telemetry is opt-in; everything below is ignored unless this is True.
    enabled: bool = False
    service_name: str = "tuft"
    otlp_endpoint: str | None = None
    resource_attributes: Dict[str, str] = field(default_factory=dict)
35
+
36
+
37
def _default_telemetry_config() -> TelemetryConfig:
    """Build a ``TelemetryConfig`` with all defaults (telemetry disabled).

    Used as a ``default_factory`` for the ``AppConfig.telemetry`` field.
    """
    config = TelemetryConfig()
    return config
39
+
40
+
41
@dataclass
class AppConfig:
    """Runtime configuration for the TuFT server."""

    # Where checkpoints are written; defaults to ~/.cache/tuft/checkpoints.
    checkpoint_dir: Path = field(default_factory=_default_checkpoint_dir)
    # Models this server serves; must be non-empty and unique (see check_validity).
    supported_models: List[ModelConfig] = field(default_factory=list)
    model_owner: str = "local-user"
    toy_backend_seed: int = 0
    # TODO: Temporary implementation for user authorization,
    # replace with proper auth system later
    authorized_users: Dict[str, str] = field(default_factory=dict)
    persistence: PersistenceConfig = field(default_factory=_default_persistence_config)
    telemetry: TelemetryConfig = field(default_factory=_default_telemetry_config)

    def ensure_directories(self) -> None:
        """Create the checkpoint directory (and any parents) if missing."""
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def check_validity(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: If no models are configured, model names are not
                unique, or ``colocate`` is set while serving multiple models.
        """
        if not self.supported_models:
            raise ValueError("At least one supported model must be configured.")
        model_names = {model.model_name for model in self.supported_models}
        if len(model_names) != len(self.supported_models):
            raise ValueError("Model names in supported_models must be unique.")
        if len(model_names) > 1 and any(model.colocate for model in self.supported_models):
            raise ValueError(
                "Colocate option is only allowed when there is a single supported model."
            )

    def with_supported_models(self, models: Iterable[ModelConfig]) -> "AppConfig":
        """Replace ``supported_models`` in place and return ``self`` for chaining.

        An empty iterable is ignored, keeping the existing model list.
        """
        updated = list(models)
        if updated:
            self.supported_models = updated
        return self
74
+
75
+
76
@dataclass
class ModelConfig:
    """Configuration for a specific model."""

    model_name: str  # name used in APIs
    model_path: Path  # path to model checkpoint
    max_model_len: int  # maximum context length supported by the model
    tensor_parallel_size: int = 1  # tensor parallel size

    # default sampling parameters for this model
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # presumably -1 disables top-k truncation — TODO confirm against backend
    logprobs: int = 0
    seed: int = 42
    min_response_tokens: int = 0

    # default lora setting
    max_lora_rank: int = 16  # maximum rank for LoRA adapters
    max_loras: int = 1  # maximum number of LoRA adapters that can be applied simultaneously

    # whether to colocate sampling and training on the same device
    # only for local testing purposes
    colocate: bool = False
    sampling_memory_fraction: float = 0.2  # fraction of GPU memory for sampling

    def __post_init__(self) -> None:
        """Reject invalid option combinations at construction time.

        Raises:
            ValueError: If ``colocate`` is combined with tensor parallelism.
        """
        if self.colocate and self.tensor_parallel_size != 1:
            raise ValueError("Colocate option is only supported for tensor_parallel_size=1.")
105
+
106
+
107
def load_yaml_config(config_path: Path) -> AppConfig:
    """Loads an AppConfig from a YAML file.

    The YAML content is merged over the ``AppConfig`` structured schema so
    unknown keys are rejected and types are coerced by OmegaConf.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The fully materialized ``AppConfig`` instance.

    Raises:
        ValueError: If the file cannot be read, fails schema validation,
            or does not materialize into an ``AppConfig``.
    """
    from omegaconf import OmegaConf

    schema = OmegaConf.structured(AppConfig)
    try:
        # Load inside the try so I/O and parse errors also surface as the
        # documented ValueError instead of leaking raw OmegaConf exceptions.
        loaded = OmegaConf.load(config_path)
        config = OmegaConf.merge(schema, loaded)
        app_config = OmegaConf.to_object(config)
        if not isinstance(app_config, AppConfig):
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would silently skip this validation.
            raise TypeError(
                "Loaded config is not of type AppConfig, which should not happen."
            )
        return app_config
    except Exception as e:
        raise ValueError(f"Failed to load config from {config_path}: {e}") from e
tuft/exceptions.py ADDED
@@ -0,0 +1,138 @@
1
+ """Some custom exceptions."""
2
+
3
+
4
class TuFTException(Exception):
    """Root of the TuFT exception hierarchy.

    Args:
        detail: Human-readable description of the failure. Stored on the
            instance as ``detail`` and also forwarded to ``Exception`` so
            ``str(exc)`` yields the same text.
    """

    def __init__(self, detail: str = ""):
        self.detail = detail
        super().__init__(detail)
10
+
11
+
12
# Category bases: handlers catch one of these to match a whole family of errors
# without enumerating every concrete subclass below.


class ModelException(TuFTException):
    """Base exception for Model related errors."""


class CheckpointException(TuFTException):
    """Base exception for Checkpoint related errors."""


class FutureException(TuFTException):
    """Base exception for Future related errors."""


class SessionException(TuFTException):
    """Base exception for Session related errors."""


class AuthenticationException(TuFTException):
    """Base exception for Authentication related errors."""


class LossFunctionException(TuFTException):
    """Base exception for Loss Function related errors."""
34
+
35
+
36
class UnknownModelException(ModelException):
    """Raised when a request names a model that is not configured."""

    def __init__(self, model_name: str | None):
        # Keep the offending name so API handlers can report/log it.
        self.model_name = model_name
        super().__init__(f"Unknown model: {model_name}")
43
+
44
+
45
class CheckpointNotFoundException(CheckpointException):
    """Checkpoint not found."""

    def __init__(self, checkpoint_id: str):
        detail = f"Checkpoint {checkpoint_id} not found."
        super().__init__(detail)
        # Retained so callers can reference the missing checkpoint programmatically.
        self.checkpoint_id = checkpoint_id


class CheckpointAccessDeniedException(CheckpointException):
    """Access to the checkpoint is denied."""

    def __init__(self, checkpoint_id: str):
        detail = f"Access to checkpoint {checkpoint_id} is denied."
        super().__init__(detail)
        self.checkpoint_id = checkpoint_id


class CheckpointMetadataReadException(CheckpointException):
    """Failed to read checkpoint metadata."""

    def __init__(self, checkpoint_id: str):
        detail = f"Failed to read metadata for checkpoint {checkpoint_id}."
        super().__init__(detail)
        self.checkpoint_id = checkpoint_id
70
+
71
+
72
class SequenceConflictException(FutureException):
    """A sequence conflict occurred."""

    def __init__(self, expected: int, got: int):
        detail = f"Sequence conflict: expected {expected}, got {got}."
        super().__init__(detail)
        # Both sides of the conflict are kept for callers that want to retry
        # or resynchronize rather than just surface the message.
        self.expected = expected
        self.got = got


class MissingSequenceIDException(FutureException):
    """Missing sequence ID in the request."""

    def __init__(self):
        detail = "Missing sequence ID in the request."
        super().__init__(detail)


class FutureNotFoundException(FutureException):
    """Future not found."""

    def __init__(self, request_id: str):
        detail = f"Future with request ID {request_id} not found."
        super().__init__(detail)
        self.request_id = request_id
97
+
98
+
99
class SessionNotFoundException(SessionException):
    """Raised when the referenced session does not exist."""

    def __init__(self, session_id: str):
        # Keep the ID available for handlers/logging.
        self.session_id = session_id
        super().__init__(f"Session {session_id} not found.")
106
+
107
+
108
class UserMismatchException(AuthenticationException):
    """Raised when the caller does not own the requested resource.

    The detail message deliberately contains no user IDs so they can
    never leak to clients through error responses.
    """

    def __init__(self):
        super().__init__("You do not have permission to access this resource.")
116
+
117
+
118
class LossFunctionNotFoundException(LossFunctionException):
    """Loss function not found."""

    def __init__(self, loss_function_name: str):
        detail = f"Loss function {loss_function_name} not found."
        super().__init__(detail)
        self.loss_function_name = loss_function_name


class LossFunctionMissingInputException(LossFunctionException):
    """A required entry is absent from ``loss_fn_inputs``."""

    def __init__(self, missing_input_name: str):
        detail = f"Missing '{missing_input_name}' in loss_fn_inputs."
        super().__init__(detail)
        # Note: stored under the shorter name `input_name`, not `missing_input_name`.
        self.input_name = missing_input_name


class LossFunctionInputShapeMismatchException(LossFunctionException):
    """Input tensors passed to a loss function have inconsistent shapes."""

    def __init__(self, shapes: list):
        detail = f"Input tensors must have the same shape. Got shapes: {shapes}"
        super().__init__(detail)
        self.shapes = shapes
tuft/futures.py ADDED
@@ -0,0 +1,431 @@
1
+ """Simple in-memory future registry for the synthetic Tinker API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ import time
8
+ import uuid
9
+ from datetime import datetime, timezone
10
+ from typing import Any, Callable, Literal
11
+
12
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
13
+ from tinker import types
14
+ from tinker.types.try_again_response import TryAgainResponse
15
+
16
+ from .exceptions import FutureNotFoundException, TuFTException, UserMismatchException
17
+ from .persistence import get_redis_store, is_persistence_enabled, load_record, save_record
18
+ from .telemetry.metrics import get_metrics
19
+ from .telemetry.tracing import get_tracer
20
+
21
+
22
logger = logging.getLogger(__name__)


def _get_tracer():
    """Return this module's tracer, resolved lazily.

    Resolving on each call (rather than at import time) lets telemetry be
    configured after this module is imported. Replaces the previous
    `_get_tracer = lambda: ...` assignment, which is the E731 anti-pattern
    the original had to `# noqa` away.
    """
    return get_tracer("tuft.futures")


# Queue states reported back to clients while a future is pending.
QueueState = Literal["active", "paused_capacity", "paused_rate_limit"]

# Operation kinds recorded on each future; used for metrics labels and
# for recovery after a restart.
OperationType = Literal[
    "forward",
    "forward_backward",
    "optim_step",
    "save_weights",
    "save_weights_for_sampler",
    "load_weights",
    "sample",
]
38
+
39
+
40
+ def _now() -> datetime:
41
+ return datetime.now(timezone.utc)
42
+
43
+
44
class FutureRecord(BaseModel):
    """Future record with persistence support.

    Fields:
        event: Not serialized (excluded) - created fresh on each instance.
            After restore, if status is ready/failed, event is auto-set.
        operation_type: Type of operation for recovery purposes.
        operation_args: Serializable arguments for the operation.
        future_id: Globally incrementing sequence number for ordering futures.
            Used instead of timestamps to avoid timezone/clock issues.
        created_at: Timestamp when the future was created (for logging only).
    """

    # Needed because asyncio.Event is not a pydantic-native type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    future_id: int = 0
    model_id: str | None = None
    user_id: str | None = None
    queue_state: QueueState = "active"
    status: Literal["pending", "ready", "failed"] = "pending"
    payload: Any | None = None
    error: types.RequestFailedResponse | None = None
    operation_type: OperationType | None = None
    operation_args: dict[str, Any] | None = None
    created_at: datetime = Field(default_factory=_now)
    # Runtime-only field, excluded from serialization
    event: asyncio.Event = Field(default_factory=asyncio.Event, exclude=True)

    @model_validator(mode="after")
    def _set_event_if_completed(self) -> "FutureRecord":
        """Set the event if the future is already completed.

        Runs after every construction, including deserialization from Redis,
        so waiters on a restored completed future are released immediately.
        """
        if self.status in ("ready", "failed"):
            self.event.set()
        return self
79
+
80
+
81
class FutureStore:
    """Runs controller work asynchronously and tracks each request's lifecycle."""

    REDIS_KEY_PREFIX = "future"

    def __init__(self) -> None:
        """Create an empty store, restoring persisted futures if persistence is on."""
        self._records: dict[str, FutureRecord] = {}
        # Guards record creation/mutation across concurrent coroutines.
        self._lock = asyncio.Lock()
        # Strong references to in-flight runner tasks (prevents premature GC).
        self._tasks: set[asyncio.Task[None]] = set()
        self._next_future_id: int = 1
        self._restore_from_redis()

    def _build_key(self, request_id: str) -> str:
        """Return the Redis key under which this request's record is stored."""
        return get_redis_store().build_key(self.REDIS_KEY_PREFIX, request_id)

    def _restore_from_redis(self) -> None:
        """Reload all persisted future records and recompute the next future_id."""
        if not is_persistence_enabled():
            return
        store = get_redis_store()
        pattern = store.build_key(self.REDIS_KEY_PREFIX, "*")
        for key in store.keys(pattern):
            record = load_record(key, FutureRecord)
            if record is None:
                # Record may have expired (TTL) or failed to deserialize
                # This is expected for expired futures, just skip them
                continue
            if record.status != "pending":
                record.event.set()
            self._records[record.request_id] = record
            # Keep future_id allocation monotonic across restarts.
            if record.future_id >= self._next_future_id:
                self._next_future_id = record.future_id + 1

    def _save_future(self, request_id: str) -> None:
        """Persist the named record to Redis (no-op when persistence is off)."""
        if not is_persistence_enabled():
            return
        record = self._records.get(request_id)
        if record is not None:
            # Use TTL for futures to prevent Redis from growing indefinitely
            # Futures are short-lived and can be safely expired
            ttl = get_redis_store().future_ttl
            save_record(self._build_key(request_id), record, ttl_seconds=ttl)

    def _allocate_future_id(self) -> int:
        """Allocate and return a new globally unique future_id."""
        future_id = self._next_future_id
        self._next_future_id += 1
        return future_id

    def get_current_future_id(self) -> int:
        """Get the current (latest allocated) future_id, or 0 if none allocated."""
        return self._next_future_id - 1 if self._next_future_id > 1 else 0

    def _delete_future(self, request_id: str) -> None:
        """Remove the persisted record from Redis (no-op when persistence is off)."""
        if not is_persistence_enabled():
            return
        get_redis_store().delete(self._build_key(request_id))

    def get_pending_futures_by_model(self) -> dict[str | None, list[FutureRecord]]:
        """Group all pending futures by model_id.

        Each group is sorted by future_id, i.e. in allocation order.
        """
        by_model: dict[str | None, list[FutureRecord]] = {}
        for record in self._records.values():
            if record.status == "pending":
                if record.model_id not in by_model:
                    by_model[record.model_id] = []
                by_model[record.model_id].append(record)

        for model_id in by_model:
            by_model[model_id].sort(key=lambda r: r.future_id)

        return by_model

    def mark_futures_failed_after_checkpoint(
        self,
        model_id: str | None,
        checkpoint_future_id: int | None,
        error_message: str = "Server restored from checkpoint. Please retry.",
    ) -> int:
        """Mark all futures for a model after a checkpoint as failed.

        Futures with future_id > checkpoint_future_id (or all of the model's
        futures when checkpoint_future_id is None) are failed regardless of
        their current status. Returns the number of futures marked.
        """
        count = 0
        for record in self._records.values():
            if record.model_id != model_id:
                continue
            if checkpoint_future_id is None or record.future_id > checkpoint_future_id:
                record.status = "failed"
                record.error = types.RequestFailedResponse(
                    error=error_message,
                    category=types.RequestErrorCategory.Server,
                )
                record.event.set()
                self._save_future(record.request_id)
                count += 1
        return count

    def mark_all_pending_failed(
        self,
        error_message: str = "Server restarted while task was pending. Please retry.",
    ) -> int:
        """Mark all pending futures as failed. Returns the number marked."""
        count = 0
        for record in self._records.values():
            if record.status == "pending":
                record.status = "failed"
                record.error = types.RequestFailedResponse(
                    error=error_message,
                    category=types.RequestErrorCategory.Server,
                )
                record.event.set()
                self._save_future(record.request_id)
                count += 1
        return count

    def _store_record(self, record: FutureRecord) -> None:
        """Register the record in memory and persist it."""
        self._records[record.request_id] = record
        self._save_future(record.request_id)

    async def enqueue(
        self,
        operation: Callable[[], Any],
        user_id: str,
        *,
        model_id: str | None = None,
        queue_state: QueueState = "active",
        operation_type: OperationType | None = None,
        operation_args: dict[str, Any] | None = None,
    ) -> types.UntypedAPIFuture:
        """Enqueue a task (sync or async) and return a future immediately.

        Args:
            operation: The callable to execute.
            user_id: The user ID making the request.
            model_id: Optional model ID associated with this operation.
            queue_state: State of the queue.
            operation_type: Type of operation for recovery purposes.
            operation_args: Serializable arguments for recovery.
        """
        async with self._lock:
            future_id = self._allocate_future_id()
            record = FutureRecord(
                future_id=future_id,
                model_id=model_id,
                user_id=user_id,
                queue_state=queue_state,
                operation_type=operation_type,
                operation_args=operation_args,
            )
            self._store_record(record)

        # Update metrics
        metrics = get_metrics()
        metrics.futures_created.add(
            1, {"operation_type": operation_type or "unknown", "model_id": model_id or ""}
        )
        metrics.futures_queue_length.add(1, {"queue_state": queue_state})

        logger.info("Future enqueued: %s", record.request_id)
        enqueue_time = time.perf_counter()

        async def _runner() -> None:
            # Executes the operation, records the outcome on the FutureRecord,
            # and emits wait/execution-time metrics.
            start_time = time.perf_counter()
            wait_time = start_time - enqueue_time

            with _get_tracer().start_as_current_span("future_store.execute_operation") as span:
                span.set_attribute("tuft.request_id", record.request_id)
                span.set_attribute("tuft.operation_type", operation_type or "unknown")
                if model_id:
                    span.set_attribute("tuft.model_id", model_id)

                logger.info("Future begin: %s", record.request_id)
                try:
                    if asyncio.iscoroutinefunction(operation):
                        payload = await operation()
                    else:
                        # Run sync operation in thread pool to avoid blocking
                        loop = asyncio.get_running_loop()
                        payload = await loop.run_in_executor(None, operation)
                except TuFTException as exc:
                    # Domain errors are attributed to the user.
                    message = exc.detail
                    failure = types.RequestFailedResponse(
                        error=message,
                        category=types.RequestErrorCategory.User,
                    )
                    span.record_exception(exc)
                    logger.error("Future failed: %s", record.request_id)
                    await self._mark_failed(record.request_id, failure, operation_type)
                except Exception as exc:  # pylint: disable=broad-except
                    # Anything else is a server-side failure.
                    failure = types.RequestFailedResponse(
                        error=str(exc),
                        category=types.RequestErrorCategory.Server,
                    )
                    span.record_exception(exc)
                    logger.error("Future failed: %s", record.request_id)
                    await self._mark_failed(record.request_id, failure, operation_type)
                else:
                    logger.info("Future completed: %s", record.request_id)
                    await self._mark_ready(record.request_id, payload, operation_type)
                finally:
                    # Record execution time
                    execution_time = time.perf_counter() - start_time
                    metrics.futures_wait_time.record(
                        wait_time, {"operation_type": operation_type or "unknown"}
                    )
                    metrics.futures_execution_time.record(
                        execution_time, {"operation_type": operation_type or "unknown"}
                    )
                    metrics.futures_queue_length.add(-1, {"queue_state": queue_state})

                    # Clean up task reference
                    task = asyncio.current_task()
                    if task:
                        self._tasks.discard(task)

        # Create and track the task
        task = asyncio.create_task(_runner())
        self._tasks.add(task)
        return types.UntypedAPIFuture(request_id=record.request_id, model_id=model_id)

    async def create_ready_future(
        self,
        payload: Any,
        user_id: str,
        *,
        model_id: str | None = None,
    ) -> types.UntypedAPIFuture:
        """Create a future that's already completed."""
        async with self._lock:
            future_id = self._allocate_future_id()
            record = FutureRecord(
                future_id=future_id,
                payload=payload,
                model_id=model_id,
                user_id=user_id,
                status="ready",
            )
            record.event.set()
            self._store_record(record)

        return types.UntypedAPIFuture(request_id=record.request_id, model_id=model_id)

    async def _mark_ready(
        self, request_id: str, payload: Any, operation_type: str | None = None
    ) -> None:
        """Mark a future as ready with the given payload."""
        async with self._lock:
            record = self._records.get(request_id)
            if record is None:
                return
            record.payload = payload
            record.status = "ready"
            record.error = None
            record.event.set()
            self._save_future(request_id)

        # Update metrics
        get_metrics().futures_completed.add(
            1,
            {
                "operation_type": operation_type or record.operation_type or "unknown",
                "status": "ready",
            },
        )

    async def _mark_failed(
        self,
        request_id: str,
        failure: types.RequestFailedResponse,
        operation_type: str | None = None,
    ) -> None:
        """Mark a future as failed with the given error."""
        async with self._lock:
            record = self._records.get(request_id)
            if record is None:
                return
            record.status = "failed"
            record.error = failure
            record.event.set()
            self._save_future(request_id)

        # Update metrics
        get_metrics().futures_completed.add(
            1,
            {
                "operation_type": operation_type or record.operation_type or "unknown",
                "status": "failed",
            },
        )

    async def retrieve(
        self,
        request_id: str,
        user_id: str,
        *,
        timeout: float = 120,
    ) -> Any:
        """
        Retrieve the result of a future, waiting if it's still pending.

        Args:
            request_id: The ID of the request to retrieve
            user_id: The ID of the user making the request
            timeout: Maximum time to wait in seconds (None for no timeout)

        Returns:
            The payload if ready, or error response if failed

        Raises:
            FutureNotFoundException: If request_id not found (may have expired due to TTL)
            UserMismatchException: If user_id does not match the owner
            asyncio.TimeoutError: If timeout is exceeded
        """
        # Get the record
        async with self._lock:
            record = self._records.get(request_id)

        if record is None:
            # Record not found - may have expired due to TTL or never existed
            raise FutureNotFoundException(request_id)
        if record.user_id != user_id:
            raise UserMismatchException()
        # Wait for completion if still pending
        if record.status == "pending":
            try:
                await asyncio.wait_for(record.event.wait(), timeout=timeout)
            except asyncio.TimeoutError:
                # Return TryAgainResponse on timeout for backwards compatibility
                return TryAgainResponse(request_id=request_id, queue_state=record.queue_state)

        # Return result
        if record.status == "failed" and record.error is not None:
            return record.error

        return record.payload

    async def cleanup(self, request_id: str) -> None:
        """Remove a completed request from the store to free memory."""
        async with self._lock:
            self._records.pop(request_id, None)
            self._delete_future(request_id)

    async def shutdown(self) -> None:
        """Cancel all pending tasks and clean up."""
        # Cancel all running tasks
        for task in self._tasks:
            if not task.done():
                task.cancel()

        # Wait for all tasks to complete (with cancellation)
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)

        self._tasks.clear()
        self._records.clear()