tuft 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__main__.py +7 -0
- tuft/backends/hf_training_model.py +184 -64
- tuft/cli.py +161 -8
- tuft/config.py +63 -59
- tuft/exceptions.py +66 -0
- tuft/futures.py +22 -2
- tuft/loss_fn/__init__.py +33 -0
- tuft/persistence/__init__.py +10 -2
- tuft/persistence/redis_store.py +352 -31
- tuft/sampling_controller.py +37 -11
- tuft/sequence_executor.py +72 -0
- tuft/server.py +9 -2
- tuft/state.py +3 -0
- tuft/training_controller.py +20 -5
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/METADATA +10 -66
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/RECORD +19 -17
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/WHEEL +0 -0
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/entry_points.txt +0 -0
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/licenses/LICENSE +0 -0
tuft/config.py
CHANGED
|
@@ -2,23 +2,20 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from dataclasses import dataclass, field
|
|
6
5
|
from pathlib import Path
|
|
7
|
-
from typing import
|
|
8
|
-
|
|
9
|
-
from .persistence import PersistenceConfig
|
|
6
|
+
from typing import Any, Iterable
|
|
10
7
|
|
|
8
|
+
from pydantic import BaseModel, Field, model_validator
|
|
11
9
|
|
|
12
|
-
|
|
13
|
-
return Path.home() / ".cache" / "tuft" / "checkpoints"
|
|
10
|
+
from .persistence import PersistenceConfig
|
|
14
11
|
|
|
15
12
|
|
|
16
|
-
def
|
|
17
|
-
|
|
13
|
+
def _default_checkpoint_dir() -> Path | None:
|
|
14
|
+
"""Return None to let CLI set the default based on TUFT_HOME."""
|
|
15
|
+
return None
|
|
18
16
|
|
|
19
17
|
|
|
20
|
-
|
|
21
|
-
class TelemetryConfig:
|
|
18
|
+
class TelemetryConfig(BaseModel):
|
|
22
19
|
"""Configuration for OpenTelemetry integration.
|
|
23
20
|
|
|
24
21
|
Attributes:
|
|
@@ -31,29 +28,65 @@ class TelemetryConfig:
|
|
|
31
28
|
enabled: bool = False
|
|
32
29
|
service_name: str = "tuft"
|
|
33
30
|
otlp_endpoint: str | None = None
|
|
34
|
-
resource_attributes:
|
|
31
|
+
resource_attributes: dict[str, str] = Field(default_factory=dict)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ModelConfig(BaseModel):
|
|
35
|
+
model_config = {"arbitrary_types_allowed": True}
|
|
36
|
+
|
|
37
|
+
model_name: str # name used in APIs
|
|
38
|
+
model_path: Path # path to model checkpoint
|
|
39
|
+
max_model_len: int # maximum context length supported by the model
|
|
40
|
+
tensor_parallel_size: int = 1 # tensor parallel size
|
|
41
|
+
|
|
42
|
+
# default sampling parameters for this model
|
|
43
|
+
temperature: float = 1.0
|
|
44
|
+
top_p: float = 1.0
|
|
45
|
+
top_k: int = -1
|
|
46
|
+
logprobs: int = 0
|
|
47
|
+
seed: int = 42
|
|
48
|
+
min_response_tokens: int = 0
|
|
49
|
+
|
|
50
|
+
# default lora setting
|
|
51
|
+
max_lora_rank: int = 16 # maximum rank for LoRA adapters
|
|
52
|
+
max_loras: int = 1 # maximum number of LoRA adapters that can be applied simultaneously
|
|
35
53
|
|
|
54
|
+
# default training setting
|
|
55
|
+
micro_batch_size: int = 1 # micro-batch size for training
|
|
36
56
|
|
|
37
|
-
|
|
38
|
-
|
|
57
|
+
# whether to colocate sampling and training on the same device
|
|
58
|
+
# only for local testing purposes
|
|
59
|
+
colocate: bool = False
|
|
60
|
+
sampling_memory_fraction: float = 0.2 # fraction of GPU memory for sampling
|
|
61
|
+
|
|
62
|
+
@model_validator(mode="after")
|
|
63
|
+
def validate_colocate(self) -> "ModelConfig":
|
|
64
|
+
if self.colocate and self.tensor_parallel_size != 1:
|
|
65
|
+
raise ValueError("Colocate option is only supported for tensor_parallel_size=1.")
|
|
66
|
+
return self
|
|
39
67
|
|
|
40
68
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
"""Runtime configuration for the TuFT server."""
|
|
69
|
+
class AppConfig(BaseModel):
|
|
70
|
+
"""Runtime configuration for the TuFT server.
|
|
44
71
|
|
|
45
|
-
|
|
46
|
-
|
|
72
|
+
This is a Pydantic model that can be serialized/deserialized for persistence.
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
model_config = {"arbitrary_types_allowed": True}
|
|
76
|
+
|
|
77
|
+
checkpoint_dir: Path | None = Field(default_factory=_default_checkpoint_dir)
|
|
78
|
+
supported_models: list[ModelConfig] = Field(default_factory=list)
|
|
47
79
|
model_owner: str = "local-user"
|
|
48
80
|
toy_backend_seed: int = 0
|
|
49
81
|
# TODO: Temporary implementation for user authorization,
|
|
50
82
|
# replace with proper auth system later
|
|
51
|
-
authorized_users:
|
|
52
|
-
persistence: PersistenceConfig =
|
|
53
|
-
telemetry: TelemetryConfig =
|
|
83
|
+
authorized_users: dict[str, str] = Field(default_factory=dict)
|
|
84
|
+
persistence: PersistenceConfig = Field(default_factory=PersistenceConfig)
|
|
85
|
+
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig)
|
|
54
86
|
|
|
55
87
|
def ensure_directories(self) -> None:
|
|
56
|
-
self.checkpoint_dir
|
|
88
|
+
if self.checkpoint_dir is not None:
|
|
89
|
+
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
|
57
90
|
|
|
58
91
|
def check_validity(self) -> None:
|
|
59
92
|
if not self.supported_models:
|
|
@@ -72,50 +105,21 @@ class AppConfig:
|
|
|
72
105
|
self.supported_models = updated
|
|
73
106
|
return self
|
|
74
107
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
"""Configuration for a specific model."""
|
|
79
|
-
|
|
80
|
-
model_name: str # name used in APIs
|
|
81
|
-
model_path: Path # path to model checkpoint
|
|
82
|
-
max_model_len: int # maximum context length supported by the model
|
|
83
|
-
tensor_parallel_size: int = 1 # tensor parallel size
|
|
84
|
-
|
|
85
|
-
# default sampling parameters for this model
|
|
86
|
-
temperature: float = 1.0
|
|
87
|
-
top_p: float = 1.0
|
|
88
|
-
top_k: int = -1
|
|
89
|
-
logprobs: int = 0
|
|
90
|
-
seed: int = 42
|
|
91
|
-
min_response_tokens: int = 0
|
|
92
|
-
|
|
93
|
-
# default lora setting
|
|
94
|
-
max_lora_rank: int = 16 # maximum rank for LoRA adapters
|
|
95
|
-
max_loras: int = 1 # maximum number of LoRA adapters that can be applied simultaneously
|
|
96
|
-
|
|
97
|
-
# whether to colocate sampling and training on the same device
|
|
98
|
-
# only for local testing purposes
|
|
99
|
-
colocate: bool = False
|
|
100
|
-
sampling_memory_fraction: float = 0.2 # fraction of GPU memory for sampling
|
|
101
|
-
|
|
102
|
-
def __post_init__(self) -> None:
|
|
103
|
-
if self.colocate and self.tensor_parallel_size != 1:
|
|
104
|
-
raise ValueError("Colocate option is only supported for tensor_parallel_size=1.")
|
|
108
|
+
def get_config_for_persistence(self) -> dict[str, Any]:
|
|
109
|
+
"""Get config fields for persistence signature (excludes persistence config itself)."""
|
|
110
|
+
return self.model_dump(mode="json", exclude={"persistence"})
|
|
105
111
|
|
|
106
112
|
|
|
107
113
|
def load_yaml_config(config_path: Path) -> AppConfig:
|
|
108
114
|
"""Loads an AppConfig from a YAML file."""
|
|
109
115
|
from omegaconf import OmegaConf
|
|
110
116
|
|
|
111
|
-
schema = OmegaConf.structured(AppConfig)
|
|
112
117
|
loaded = OmegaConf.load(config_path)
|
|
113
118
|
try:
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
"
|
|
118
|
-
)
|
|
119
|
-
return app_config
|
|
119
|
+
# Convert OmegaConf to plain dict for Pydantic
|
|
120
|
+
config_dict = OmegaConf.to_container(loaded, resolve=True)
|
|
121
|
+
if not isinstance(config_dict, dict):
|
|
122
|
+
raise ValueError("Config file must contain a dictionary at root level")
|
|
123
|
+
return AppConfig.model_validate(config_dict)
|
|
120
124
|
except Exception as e:
|
|
121
125
|
raise ValueError(f"Failed to load config from {config_path}: {e}") from e
|
tuft/exceptions.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Some custom exceptions."""
|
|
2
2
|
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
3
5
|
|
|
4
6
|
class TuFTException(Exception):
|
|
5
7
|
"""Base exception for TuFT errors."""
|
|
@@ -79,6 +81,15 @@ class SequenceConflictException(FutureException):
|
|
|
79
81
|
self.got = got
|
|
80
82
|
|
|
81
83
|
|
|
84
|
+
class SequenceTimeoutException(FutureException):
|
|
85
|
+
"""Timeout waiting for the expected sequence ID."""
|
|
86
|
+
|
|
87
|
+
def __init__(self, expected_sequence_id: int):
|
|
88
|
+
detail = f"Timeout when waiting for sequence ID {expected_sequence_id}."
|
|
89
|
+
super().__init__(detail)
|
|
90
|
+
self.sequence_id = expected_sequence_id
|
|
91
|
+
|
|
92
|
+
|
|
82
93
|
class MissingSequenceIDException(FutureException):
|
|
83
94
|
"""Missing sequence ID in the request."""
|
|
84
95
|
|
|
@@ -136,3 +147,58 @@ class LossFunctionInputShapeMismatchException(LossFunctionException):
|
|
|
136
147
|
detail = f"Input tensors must have the same shape. Got shapes: {shapes}"
|
|
137
148
|
super().__init__(detail)
|
|
138
149
|
self.shapes = shapes
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class LossFunctionUnknownMetricReductionException(LossFunctionException):
|
|
153
|
+
def __init__(self, reduction_type: str):
|
|
154
|
+
detail = f"Unknown metric reduction type: {reduction_type}"
|
|
155
|
+
super().__init__(detail)
|
|
156
|
+
self.reduction_type = reduction_type
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class PersistenceException(TuFTException):
|
|
160
|
+
"""Base exception for Persistence related errors."""
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class ConfigMismatchError(PersistenceException):
|
|
164
|
+
"""Raised when current config doesn't match the stored config in Redis.
|
|
165
|
+
|
|
166
|
+
This error occurs during server startup when persistence is enabled and
|
|
167
|
+
the configuration has changed since the last run. This can cause data
|
|
168
|
+
corruption when restoring persisted state.
|
|
169
|
+
"""
|
|
170
|
+
|
|
171
|
+
def __init__(
|
|
172
|
+
self,
|
|
173
|
+
diff: dict[str, dict[str, Any]],
|
|
174
|
+
):
|
|
175
|
+
self.diff = diff
|
|
176
|
+
|
|
177
|
+
# Build detailed diff message
|
|
178
|
+
diff_parts = []
|
|
179
|
+
for field_name, field_diff in diff.items():
|
|
180
|
+
# Handle scalar fields (current/stored)
|
|
181
|
+
current = field_diff.get("current")
|
|
182
|
+
stored = field_diff.get("stored")
|
|
183
|
+
|
|
184
|
+
parts = []
|
|
185
|
+
if current is not None or stored is not None:
|
|
186
|
+
parts.append(f"current: {current}, stored: {stored}")
|
|
187
|
+
|
|
188
|
+
if parts:
|
|
189
|
+
diff_parts.append(f"{field_name} ({', '.join(parts)})")
|
|
190
|
+
|
|
191
|
+
diff_str = "; ".join(diff_parts) if diff_parts else "unknown difference"
|
|
192
|
+
|
|
193
|
+
message = (
|
|
194
|
+
f"Configuration mismatch detected: {diff_str}.\n"
|
|
195
|
+
"The current configuration does not match the stored configuration in Redis.\n"
|
|
196
|
+
"This can cause data corruption when restoring persisted state.\n\n"
|
|
197
|
+
"Options:\n"
|
|
198
|
+
" 1. Use a different Redis database (change redis_url in config)\n"
|
|
199
|
+
" 2. Run `tuft clear persistence -c <config_path>` to clear existing data\n"
|
|
200
|
+
" Use `--force` or `-f` to skip confirmation prompt.\n"
|
|
201
|
+
" (WARNING: This will delete all persisted sessions, training runs, etc.)\n"
|
|
202
|
+
" 3. Restore the original configuration that matches the stored data"
|
|
203
|
+
)
|
|
204
|
+
super().__init__(message)
|
tuft/futures.py
CHANGED
|
@@ -189,6 +189,24 @@ class FutureStore:
|
|
|
189
189
|
count += 1
|
|
190
190
|
return count
|
|
191
191
|
|
|
192
|
+
def mark_pending_sample_futures_failed(
|
|
193
|
+
self,
|
|
194
|
+
error_message: str = "Server restarted while sample request was pending. Please retry.",
|
|
195
|
+
) -> int:
|
|
196
|
+
"""Mark all pending sample futures as failed."""
|
|
197
|
+
count = 0
|
|
198
|
+
for record in self._records.values():
|
|
199
|
+
if record.status == "pending" and record.operation_type == "sample":
|
|
200
|
+
record.status = "failed"
|
|
201
|
+
record.error = types.RequestFailedResponse(
|
|
202
|
+
error=error_message,
|
|
203
|
+
category=types.RequestErrorCategory.Server,
|
|
204
|
+
)
|
|
205
|
+
record.event.set()
|
|
206
|
+
self._save_future(record.request_id)
|
|
207
|
+
count += 1
|
|
208
|
+
return count
|
|
209
|
+
|
|
192
210
|
def _store_record(self, record: FutureRecord) -> None:
|
|
193
211
|
self._records[record.request_id] = record
|
|
194
212
|
self._save_future(record.request_id)
|
|
@@ -327,8 +345,9 @@ class FutureStore:
|
|
|
327
345
|
record.payload = payload
|
|
328
346
|
record.status = "ready"
|
|
329
347
|
record.error = None
|
|
348
|
+
loop = asyncio.get_event_loop()
|
|
349
|
+
await loop.run_in_executor(None, self._save_future, request_id)
|
|
330
350
|
record.event.set()
|
|
331
|
-
self._save_future(request_id)
|
|
332
351
|
|
|
333
352
|
# Update metrics
|
|
334
353
|
get_metrics().futures_completed.add(
|
|
@@ -352,8 +371,9 @@ class FutureStore:
|
|
|
352
371
|
return
|
|
353
372
|
record.status = "failed"
|
|
354
373
|
record.error = failure
|
|
374
|
+
loop = asyncio.get_event_loop()
|
|
375
|
+
await loop.run_in_executor(None, self._save_future, request_id)
|
|
355
376
|
record.event.set()
|
|
356
|
-
self._save_future(request_id)
|
|
357
377
|
|
|
358
378
|
# Update metrics
|
|
359
379
|
get_metrics().futures_completed.add(
|
tuft/loss_fn/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from typing import Callable, Dict, Tuple
|
|
2
2
|
|
|
3
|
+
from tinker.lib.chunked_fwdbwd_helpers import REDUCE_MAP
|
|
3
4
|
from torch import Tensor
|
|
4
5
|
from typing_extensions import TypeAlias
|
|
5
6
|
|
|
@@ -7,6 +8,7 @@ from ..exceptions import (
|
|
|
7
8
|
LossFunctionInputShapeMismatchException,
|
|
8
9
|
LossFunctionMissingInputException,
|
|
9
10
|
LossFunctionNotFoundException,
|
|
11
|
+
LossFunctionUnknownMetricReductionException,
|
|
10
12
|
)
|
|
11
13
|
|
|
12
14
|
|
|
@@ -46,3 +48,34 @@ def _check_loss_fn_inputs(
|
|
|
46
48
|
shapes = [loss_fn_inputs[key].shape for key in required_keys]
|
|
47
49
|
if not all(shape == shapes[0] for shape in shapes):
|
|
48
50
|
raise LossFunctionInputShapeMismatchException(shapes)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def metrics_reduction(
|
|
54
|
+
metric_list: list[dict[str, float]],
|
|
55
|
+
weights: list[float],
|
|
56
|
+
) -> dict[str, float]:
|
|
57
|
+
"""Aggregate metrics from multiple batches.
|
|
58
|
+
|
|
59
|
+
Modified from tinker.lib.chunked_fwdbwd_helpers._metrics_reduction
|
|
60
|
+
"""
|
|
61
|
+
if not metric_list:
|
|
62
|
+
return {}
|
|
63
|
+
keys = metric_list[0].keys()
|
|
64
|
+
result = {}
|
|
65
|
+
for key in keys:
|
|
66
|
+
_, reduction = key.split(":")
|
|
67
|
+
if reduction not in REDUCE_MAP:
|
|
68
|
+
raise LossFunctionUnknownMetricReductionException(reduction)
|
|
69
|
+
if not all(key in m for m in metric_list):
|
|
70
|
+
continue
|
|
71
|
+
reduce_fn = REDUCE_MAP[reduction]
|
|
72
|
+
values = [m[key] for m in metric_list]
|
|
73
|
+
|
|
74
|
+
if reduction in ["mean", "slack"]:
|
|
75
|
+
result[key] = reduce_fn(values, weights)
|
|
76
|
+
elif reduction in ["unique"]:
|
|
77
|
+
result[key] = values[0]
|
|
78
|
+
result.update({f"{key}_{i + 1}": v for i, v in enumerate(values[1:])})
|
|
79
|
+
else:
|
|
80
|
+
result[key] = reduce_fn(values)
|
|
81
|
+
return result
|
tuft/persistence/__init__.py
CHANGED
|
@@ -3,30 +3,38 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from .redis_store import (
|
|
6
|
-
|
|
6
|
+
ConfigCheckField,
|
|
7
7
|
PersistenceConfig,
|
|
8
8
|
PersistenceMode,
|
|
9
9
|
RedisPipeline,
|
|
10
10
|
RedisStore,
|
|
11
11
|
delete_record,
|
|
12
|
+
flush_all_data,
|
|
13
|
+
get_current_namespace,
|
|
12
14
|
get_redis_store,
|
|
13
15
|
is_persistence_enabled,
|
|
14
16
|
load_record,
|
|
17
|
+
save_config_signature,
|
|
15
18
|
save_record,
|
|
16
19
|
save_records_atomic,
|
|
20
|
+
validate_config_signature,
|
|
17
21
|
)
|
|
18
22
|
|
|
19
23
|
|
|
20
24
|
__all__ = [
|
|
21
|
-
"
|
|
25
|
+
"ConfigCheckField",
|
|
22
26
|
"PersistenceConfig",
|
|
23
27
|
"PersistenceMode",
|
|
24
28
|
"RedisPipeline",
|
|
25
29
|
"RedisStore",
|
|
26
30
|
"delete_record",
|
|
31
|
+
"flush_all_data",
|
|
32
|
+
"get_current_namespace",
|
|
27
33
|
"get_redis_store",
|
|
28
34
|
"is_persistence_enabled",
|
|
29
35
|
"load_record",
|
|
36
|
+
"save_config_signature",
|
|
30
37
|
"save_record",
|
|
31
38
|
"save_records_atomic",
|
|
39
|
+
"validate_config_signature",
|
|
32
40
|
]
|