tuft 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__init__.py +5 -2
- tuft/auth.py +35 -0
- tuft/backend.py +254 -0
- tuft/backends/__init__.py +10 -0
- tuft/backends/base_backend.py +112 -0
- tuft/backends/hf_training_model.py +404 -0
- tuft/backends/sampling_backend.py +253 -0
- tuft/backends/training_backend.py +327 -0
- tuft/checkpoints.py +193 -0
- tuft/cli.py +91 -0
- tuft/config.py +121 -0
- tuft/exceptions.py +138 -0
- tuft/futures.py +431 -0
- tuft/loss_fn/__init__.py +48 -0
- tuft/loss_fn/cispo.py +40 -0
- tuft/loss_fn/cross_entropy.py +26 -0
- tuft/loss_fn/dro.py +37 -0
- tuft/loss_fn/importance_sampling.py +33 -0
- tuft/loss_fn/ppo.py +43 -0
- tuft/persistence/__init__.py +32 -0
- tuft/persistence/file_redis.py +268 -0
- tuft/persistence/redis_store.py +488 -0
- tuft/sampling_controller.py +366 -0
- tuft/server.py +720 -0
- tuft/state.py +352 -0
- tuft/telemetry/__init__.py +17 -0
- tuft/telemetry/metrics.py +335 -0
- tuft/telemetry/provider.py +198 -0
- tuft/telemetry/tracing.py +43 -0
- tuft/training_controller.py +723 -0
- tuft-0.1.1.dist-info/METADATA +633 -0
- tuft-0.1.1.dist-info/RECORD +35 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/WHEEL +1 -2
- tuft-0.1.1.dist-info/entry_points.txt +2 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/licenses/LICENSE +2 -2
- tuft-0.1.0.dist-info/METADATA +0 -77
- tuft-0.1.0.dist-info/RECORD +0 -6
- tuft-0.1.0.dist-info/top_level.txt +0 -1
tuft/config.py
ADDED
```diff
@@ -0,0 +1,121 @@
+"""Configuration helpers for the TuFT service."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, Iterable, List
+
+from .persistence import PersistenceConfig
+
+
+def _default_checkpoint_dir() -> Path:
+    return Path.home() / ".cache" / "tuft" / "checkpoints"
+
+
+def _default_persistence_config() -> PersistenceConfig:
+    return PersistenceConfig()
+
+
+@dataclass
+class TelemetryConfig:
+    """Configuration for OpenTelemetry integration.
+
+    Attributes:
+        enabled: Whether telemetry is enabled.
+        service_name: Name of the service for tracing.
+        otlp_endpoint: OTLP exporter endpoint. If None, uses TUFT_OTLP_ENDPOINT env var.
+        resource_attributes: Additional resource attributes as key-value pairs.
+    """
+
+    enabled: bool = False
+    service_name: str = "tuft"
+    otlp_endpoint: str | None = None
+    resource_attributes: Dict[str, str] = field(default_factory=dict)
+
+
+def _default_telemetry_config() -> TelemetryConfig:
+    return TelemetryConfig()
+
+
+@dataclass
+class AppConfig:
+    """Runtime configuration for the TuFT server."""
+
+    checkpoint_dir: Path = field(default_factory=_default_checkpoint_dir)
+    supported_models: List[ModelConfig] = field(default_factory=list)
+    model_owner: str = "local-user"
+    toy_backend_seed: int = 0
+    # TODO: Temporary implementation for user authorization,
+    # replace with proper auth system later
+    authorized_users: Dict[str, str] = field(default_factory=dict)
+    persistence: PersistenceConfig = field(default_factory=_default_persistence_config)
+    telemetry: TelemetryConfig = field(default_factory=_default_telemetry_config)
+
+    def ensure_directories(self) -> None:
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    def check_validity(self) -> None:
+        if not self.supported_models:
+            raise ValueError("At least one supported model must be configured.")
+        model_names = {model.model_name for model in self.supported_models}
+        if len(model_names) != len(self.supported_models):
+            raise ValueError("Model names in supported_models must be unique.")
+        if len(model_names) > 1 and any(model.colocate for model in self.supported_models):
+            raise ValueError(
+                "Colocate option is only allowed when there is a single supported model."
+            )
+
+    def with_supported_models(self, models: Iterable[ModelConfig]) -> "AppConfig":
+        updated = list(models)
+        if updated:
+            self.supported_models = updated
+        return self
+
+
+@dataclass
+class ModelConfig:
+    """Configuration for a specific model."""
+
+    model_name: str  # name used in APIs
+    model_path: Path  # path to model checkpoint
+    max_model_len: int  # maximum context length supported by the model
+    tensor_parallel_size: int = 1  # tensor parallel size
+
+    # default sampling parameters for this model
+    temperature: float = 1.0
+    top_p: float = 1.0
+    top_k: int = -1
+    logprobs: int = 0
+    seed: int = 42
+    min_response_tokens: int = 0
+
+    # default lora setting
+    max_lora_rank: int = 16  # maximum rank for LoRA adapters
+    max_loras: int = 1  # maximum number of LoRA adapters that can be applied simultaneously
+
+    # whether to colocate sampling and training on the same device
+    # only for local testing purposes
+    colocate: bool = False
+    sampling_memory_fraction: float = 0.2  # fraction of GPU memory for sampling
+
+    def __post_init__(self) -> None:
+        if self.colocate and self.tensor_parallel_size != 1:
+            raise ValueError("Colocate option is only supported for tensor_parallel_size=1.")
+
+
+def load_yaml_config(config_path: Path) -> AppConfig:
+    """Loads an AppConfig from a YAML file."""
+    from omegaconf import OmegaConf
+
+    schema = OmegaConf.structured(AppConfig)
+    loaded = OmegaConf.load(config_path)
+    try:
+        config = OmegaConf.merge(schema, loaded)
+        app_config = OmegaConf.to_object(config)
+        assert isinstance(app_config, AppConfig), (
+            "Loaded config is not of type AppConfig, which should not happen."
+        )
+        return app_config
+    except Exception as e:
+        raise ValueError(f"Failed to load config from {config_path}: {e}") from e
```
tuft/exceptions.py
ADDED
```diff
@@ -0,0 +1,138 @@
+"""Some custom exceptions."""
+
+
+class TuFTException(Exception):
+    """Base exception for TuFT errors."""
+
+    def __init__(self, detail: str = ""):
+        super().__init__(detail)
+        self.detail = detail
+
+
+class ModelException(TuFTException):
+    """Base exception for Model related errors."""
+
+
+class CheckpointException(TuFTException):
+    """Base exception for Checkpoint related errors."""
+
+
+class FutureException(TuFTException):
+    """Base exception for Future related errors."""
+
+
+class SessionException(TuFTException):
+    """Base exception for Session related errors."""
+
+
+class AuthenticationException(TuFTException):
+    """Base exception for Authentication related errors."""
+
+
+class LossFunctionException(TuFTException):
+    """Base exception for Loss Function related errors."""
+
+
+class UnknownModelException(ModelException):
+    """A model was requested that is not known."""
+
+    def __init__(self, model_name: str | None):
+        detail = f"Unknown model: {model_name}"
+        super().__init__(detail)
+        self.model_name = model_name
+
+
+class CheckpointNotFoundException(CheckpointException):
+    """Checkpoint not found."""
+
+    def __init__(self, checkpoint_id: str):
+        detail = f"Checkpoint {checkpoint_id} not found."
+        super().__init__(detail)
+        self.checkpoint_id = checkpoint_id
+
+
+class CheckpointAccessDeniedException(CheckpointException):
+    """Access to the checkpoint is denied."""
+
+    def __init__(self, checkpoint_id: str):
+        detail = f"Access to checkpoint {checkpoint_id} is denied."
+        super().__init__(detail)
+        self.checkpoint_id = checkpoint_id
+
+
+class CheckpointMetadataReadException(CheckpointException):
+    """Failed to read checkpoint metadata."""
+
+    def __init__(self, checkpoint_id: str):
+        detail = f"Failed to read metadata for checkpoint {checkpoint_id}."
+        super().__init__(detail)
+        self.checkpoint_id = checkpoint_id
+
+
+class SequenceConflictException(FutureException):
+    """A sequence conflict occurred."""
+
+    def __init__(self, expected: int, got: int):
+        detail = f"Sequence conflict: expected {expected}, got {got}."
+        super().__init__(detail)
+        self.expected = expected
+        self.got = got
+
+
+class MissingSequenceIDException(FutureException):
+    """Missing sequence ID in the request."""
+
+    def __init__(self):
+        detail = "Missing sequence ID in the request."
+        super().__init__(detail)
+
+
+class FutureNotFoundException(FutureException):
+    """Future not found."""
+
+    def __init__(self, request_id: str):
+        detail = f"Future with request ID {request_id} not found."
+        super().__init__(detail)
+        self.request_id = request_id
+
+
+class SessionNotFoundException(SessionException):
+    """Session not found."""
+
+    def __init__(self, session_id: str):
+        detail = f"Session {session_id} not found."
+        super().__init__(detail)
+        self.session_id = session_id
+
+
+class UserMismatchException(AuthenticationException):
+    """User ID does not match the owner of the resource.
+    Do not expose user IDs in the detail message for security reasons.
+    """
+
+    def __init__(self):
+        detail = "You do not have permission to access this resource."
+        super().__init__(detail)
+
+
+class LossFunctionNotFoundException(LossFunctionException):
+    """Loss function not found."""
+
+    def __init__(self, loss_function_name: str):
+        detail = f"Loss function {loss_function_name} not found."
+        super().__init__(detail)
+        self.loss_function_name = loss_function_name
+
+
+class LossFunctionMissingInputException(LossFunctionException):
+    def __init__(self, missing_input_name: str):
+        detail = f"Missing '{missing_input_name}' in loss_fn_inputs."
+        super().__init__(detail)
+        self.input_name = missing_input_name
+
+
+class LossFunctionInputShapeMismatchException(LossFunctionException):
+    def __init__(self, shapes: list):
+        detail = f"Input tensors must have the same shape. Got shapes: {shapes}"
+        super().__init__(detail)
+        self.shapes = shapes
```
tuft/futures.py
ADDED
```diff
@@ -0,0 +1,431 @@
+"""Simple in-memory future registry for the synthetic Tinker API."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import time
+import uuid
+from datetime import datetime, timezone
+from typing import Any, Callable, Literal
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from tinker import types
+from tinker.types.try_again_response import TryAgainResponse
+
+from .exceptions import FutureNotFoundException, TuFTException, UserMismatchException
+from .persistence import get_redis_store, is_persistence_enabled, load_record, save_record
+from .telemetry.metrics import get_metrics
+from .telemetry.tracing import get_tracer
+
+
+logger = logging.getLogger(__name__)
+
+
+_get_tracer = lambda: get_tracer("tuft.futures")  # noqa: E731
+
+QueueState = Literal["active", "paused_capacity", "paused_rate_limit"]
+
+OperationType = Literal[
+    "forward",
+    "forward_backward",
+    "optim_step",
+    "save_weights",
+    "save_weights_for_sampler",
+    "load_weights",
+    "sample",
+]
+
+
+def _now() -> datetime:
+    return datetime.now(timezone.utc)
+
+
+class FutureRecord(BaseModel):
+    """Future record with persistence support.
+
+    Fields:
+        event: Not serialized (excluded) - created fresh on each instance.
+            After restore, if status is ready/failed, event is auto-set.
+        operation_type: Type of operation for recovery purposes.
+        operation_args: Serializable arguments for the operation.
+        future_id: Globally incrementing sequence number for ordering futures.
+            Used instead of timestamps to avoid timezone/clock issues.
+        created_at: Timestamp when the future was created (for logging only).
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    request_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    future_id: int = 0
+    model_id: str | None = None
+    user_id: str | None = None
+    queue_state: QueueState = "active"
+    status: Literal["pending", "ready", "failed"] = "pending"
+    payload: Any | None = None
+    error: types.RequestFailedResponse | None = None
+    operation_type: OperationType | None = None
+    operation_args: dict[str, Any] | None = None
+    created_at: datetime = Field(default_factory=_now)
+    # Runtime-only field, excluded from serialization
+    event: asyncio.Event = Field(default_factory=asyncio.Event, exclude=True)
+
+    @model_validator(mode="after")
+    def _set_event_if_completed(self) -> "FutureRecord":
+        """Set the event if the future is already completed."""
+        if self.status in ("ready", "failed"):
+            self.event.set()
+        return self
+
+
+class FutureStore:
+    """Runs controller work asynchronously and tracks each request's lifecycle."""
+
+    REDIS_KEY_PREFIX = "future"
+
+    def __init__(self) -> None:
+        self._records: dict[str, FutureRecord] = {}
+        self._lock = asyncio.Lock()
+        self._tasks: set[asyncio.Task[None]] = set()
+        self._next_future_id: int = 1
+        self._restore_from_redis()
+
+    def _build_key(self, request_id: str) -> str:
+        return get_redis_store().build_key(self.REDIS_KEY_PREFIX, request_id)
+
+    def _restore_from_redis(self) -> None:
+        if not is_persistence_enabled():
+            return
+        store = get_redis_store()
+        pattern = store.build_key(self.REDIS_KEY_PREFIX, "*")
+        for key in store.keys(pattern):
+            record = load_record(key, FutureRecord)
+            if record is None:
+                # Record may have expired (TTL) or failed to deserialize.
+                # This is expected for expired futures, just skip them.
+                continue
+            if record.status != "pending":
+                record.event.set()
+            self._records[record.request_id] = record
+            if record.future_id >= self._next_future_id:
+                self._next_future_id = record.future_id + 1
+
+    def _save_future(self, request_id: str) -> None:
+        if not is_persistence_enabled():
+            return
+        record = self._records.get(request_id)
+        if record is not None:
+            # Use TTL for futures to prevent Redis from growing indefinitely.
+            # Futures are short-lived and can be safely expired.
+            ttl = get_redis_store().future_ttl
+            save_record(self._build_key(request_id), record, ttl_seconds=ttl)
+
+    def _allocate_future_id(self) -> int:
+        """Allocate and return a new globally unique future_id."""
+        future_id = self._next_future_id
+        self._next_future_id += 1
+        return future_id
+
+    def get_current_future_id(self) -> int:
+        """Get the current (latest allocated) future_id, or 0 if none allocated."""
+        return self._next_future_id - 1 if self._next_future_id > 1 else 0
+
+    def _delete_future(self, request_id: str) -> None:
+        if not is_persistence_enabled():
+            return
+        get_redis_store().delete(self._build_key(request_id))
+
+    def get_pending_futures_by_model(self) -> dict[str | None, list[FutureRecord]]:
+        """Group all pending futures by model_id."""
+        by_model: dict[str | None, list[FutureRecord]] = {}
+        for record in self._records.values():
+            if record.status == "pending":
+                if record.model_id not in by_model:
+                    by_model[record.model_id] = []
+                by_model[record.model_id].append(record)
+
+        for model_id in by_model:
+            by_model[model_id].sort(key=lambda r: r.future_id)
+
+        return by_model
+
+    def mark_futures_failed_after_checkpoint(
+        self,
+        model_id: str | None,
+        checkpoint_future_id: int | None,
+        error_message: str = "Server restored from checkpoint. Please retry.",
+    ) -> int:
+        """Mark all futures for a model after a checkpoint as failed."""
+        count = 0
+        for record in self._records.values():
+            if record.model_id != model_id:
+                continue
+            if checkpoint_future_id is None or record.future_id > checkpoint_future_id:
+                record.status = "failed"
+                record.error = types.RequestFailedResponse(
+                    error=error_message,
+                    category=types.RequestErrorCategory.Server,
+                )
+                record.event.set()
+                self._save_future(record.request_id)
+                count += 1
+        return count
+
+    def mark_all_pending_failed(
+        self,
+        error_message: str = "Server restarted while task was pending. Please retry.",
+    ) -> int:
+        """Mark all pending futures as failed."""
+        count = 0
+        for record in self._records.values():
+            if record.status == "pending":
+                record.status = "failed"
+                record.error = types.RequestFailedResponse(
+                    error=error_message,
+                    category=types.RequestErrorCategory.Server,
+                )
+                record.event.set()
+                self._save_future(record.request_id)
+                count += 1
+        return count
+
+    def _store_record(self, record: FutureRecord) -> None:
+        self._records[record.request_id] = record
+        self._save_future(record.request_id)
+
+    async def enqueue(
+        self,
+        operation: Callable[[], Any],
+        user_id: str,
+        *,
+        model_id: str | None = None,
+        queue_state: QueueState = "active",
+        operation_type: OperationType | None = None,
+        operation_args: dict[str, Any] | None = None,
+    ) -> types.UntypedAPIFuture:
+        """Enqueue a task (sync or async) and return a future immediately.
+
+        Args:
+            operation: The callable to execute.
+            user_id: The user ID making the request.
+            model_id: Optional model ID associated with this operation.
+            queue_state: State of the queue.
+            operation_type: Type of operation for recovery purposes.
+            operation_args: Serializable arguments for recovery.
+        """
+        async with self._lock:
+            future_id = self._allocate_future_id()
+            record = FutureRecord(
+                future_id=future_id,
+                model_id=model_id,
+                user_id=user_id,
+                queue_state=queue_state,
+                operation_type=operation_type,
+                operation_args=operation_args,
+            )
+            self._store_record(record)
+
+        # Update metrics
+        metrics = get_metrics()
+        metrics.futures_created.add(
+            1, {"operation_type": operation_type or "unknown", "model_id": model_id or ""}
+        )
+        metrics.futures_queue_length.add(1, {"queue_state": queue_state})
+
+        logger.info("Future enqueued: %s", record.request_id)
+        enqueue_time = time.perf_counter()
+
+        async def _runner() -> None:
+            start_time = time.perf_counter()
+            wait_time = start_time - enqueue_time
+
+            with _get_tracer().start_as_current_span("future_store.execute_operation") as span:
+                span.set_attribute("tuft.request_id", record.request_id)
+                span.set_attribute("tuft.operation_type", operation_type or "unknown")
+                if model_id:
+                    span.set_attribute("tuft.model_id", model_id)
+
+                logger.info("Future begin: %s", record.request_id)
+                try:
+                    if asyncio.iscoroutinefunction(operation):
+                        payload = await operation()
+                    else:
+                        # Run sync operation in thread pool to avoid blocking
+                        loop = asyncio.get_running_loop()
+                        payload = await loop.run_in_executor(None, operation)
+                except TuFTException as exc:
+                    message = exc.detail
+                    failure = types.RequestFailedResponse(
+                        error=message,
+                        category=types.RequestErrorCategory.User,
+                    )
+                    span.record_exception(exc)
+                    logger.error("Future failed: %s", record.request_id)
+                    await self._mark_failed(record.request_id, failure, operation_type)
+                except Exception as exc:  # pylint: disable=broad-except
+                    failure = types.RequestFailedResponse(
+                        error=str(exc),
+                        category=types.RequestErrorCategory.Server,
+                    )
+                    span.record_exception(exc)
+                    logger.error("Future failed: %s", record.request_id)
+                    await self._mark_failed(record.request_id, failure, operation_type)
+                else:
+                    logger.info("Future completed: %s", record.request_id)
+                    await self._mark_ready(record.request_id, payload, operation_type)
+                finally:
+                    # Record execution time
+                    execution_time = time.perf_counter() - start_time
+                    metrics.futures_wait_time.record(
+                        wait_time, {"operation_type": operation_type or "unknown"}
+                    )
+                    metrics.futures_execution_time.record(
+                        execution_time, {"operation_type": operation_type or "unknown"}
+                    )
+                    metrics.futures_queue_length.add(-1, {"queue_state": queue_state})
+
+                    # Clean up task reference
+                    task = asyncio.current_task()
+                    if task:
+                        self._tasks.discard(task)
+
+        # Create and track the task
+        task = asyncio.create_task(_runner())
+        self._tasks.add(task)
+        return types.UntypedAPIFuture(request_id=record.request_id, model_id=model_id)
+
+    async def create_ready_future(
+        self,
+        payload: Any,
+        user_id: str,
+        *,
+        model_id: str | None = None,
+    ) -> types.UntypedAPIFuture:
+        """Create a future that's already completed."""
+        async with self._lock:
+            future_id = self._allocate_future_id()
+            record = FutureRecord(
+                future_id=future_id,
+                payload=payload,
+                model_id=model_id,
+                user_id=user_id,
+                status="ready",
+            )
+            record.event.set()
+            self._store_record(record)
+
+        return types.UntypedAPIFuture(request_id=record.request_id, model_id=model_id)
+
+    async def _mark_ready(
+        self, request_id: str, payload: Any, operation_type: str | None = None
+    ) -> None:
+        """Mark a future as ready with the given payload."""
+        async with self._lock:
+            record = self._records.get(request_id)
+            if record is None:
+                return
+            record.payload = payload
+            record.status = "ready"
+            record.error = None
+            record.event.set()
+            self._save_future(request_id)
+
+        # Update metrics
+        get_metrics().futures_completed.add(
+            1,
+            {
+                "operation_type": operation_type or record.operation_type or "unknown",
+                "status": "ready",
+            },
+        )
+
+    async def _mark_failed(
+        self,
+        request_id: str,
+        failure: types.RequestFailedResponse,
+        operation_type: str | None = None,
+    ) -> None:
+        """Mark a future as failed with the given error."""
+        async with self._lock:
+            record = self._records.get(request_id)
+            if record is None:
+                return
+            record.status = "failed"
+            record.error = failure
+            record.event.set()
+            self._save_future(request_id)
+
+        # Update metrics
+        get_metrics().futures_completed.add(
+            1,
+            {
+                "operation_type": operation_type or record.operation_type or "unknown",
+                "status": "failed",
+            },
+        )
+
+    async def retrieve(
+        self,
+        request_id: str,
+        user_id: str,
+        *,
+        timeout: float = 120,
+    ) -> Any:
+        """
+        Retrieve the result of a future, waiting if it's still pending.
+
+        Args:
+            request_id: The ID of the request to retrieve
+            user_id: The ID of the user making the request
+            timeout: Maximum time to wait in seconds (None for no timeout)
+
+        Returns:
+            The payload if ready, or error response if failed
+
+        Raises:
+            FutureNotFoundException: If request_id not found (may have expired due to TTL)
+            UserMismatchException: If user_id does not match the owner
+            asyncio.TimeoutError: If timeout is exceeded
+        """
+        # Get the record
+        async with self._lock:
+            record = self._records.get(request_id)
+
+        if record is None:
+            # Record not found - may have expired due to TTL or never existed
+            raise FutureNotFoundException(request_id)
+        if record.user_id != user_id:
+            raise UserMismatchException()
+        # Wait for completion if still pending
+        if record.status == "pending":
+            try:
+                await asyncio.wait_for(record.event.wait(), timeout=timeout)
+            except asyncio.TimeoutError:
+                # Return TryAgainResponse on timeout for backwards compatibility
+                return TryAgainResponse(request_id=request_id, queue_state=record.queue_state)
+
+        # Return result
+        if record.status == "failed" and record.error is not None:
+            return record.error
+
+        return record.payload
+
+    async def cleanup(self, request_id: str) -> None:
+        """Remove a completed request from the store to free memory."""
+        async with self._lock:
+            self._records.pop(request_id, None)
+            self._delete_future(request_id)
+
+    async def shutdown(self) -> None:
+        """Cancel all pending tasks and clean up."""
+        # Cancel all running tasks
+        for task in self._tasks:
+            if not task.done():
+                task.cancel()
+
+        # Wait for all tasks to complete (with cancellation)
+        if self._tasks:
+            await asyncio.gather(*self._tasks, return_exceptions=True)
+
+        self._tasks.clear()
+        self._records.clear()
```