tuft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/backends/hf_training_model.py +184 -64
- tuft/cli.py +120 -0
- tuft/config.py +58 -56
- tuft/exceptions.py +66 -0
- tuft/futures.py +22 -2
- tuft/loss_fn/__init__.py +33 -0
- tuft/persistence/__init__.py +10 -2
- tuft/persistence/redis_store.py +352 -31
- tuft/sampling_controller.py +34 -10
- tuft/sequence_executor.py +72 -0
- tuft/server.py +9 -2
- tuft/state.py +3 -0
- tuft/training_controller.py +14 -4
- {tuft-0.1.2.dist-info → tuft-0.1.3.dist-info}/METADATA +9 -65
- {tuft-0.1.2.dist-info → tuft-0.1.3.dist-info}/RECORD +18 -17
- {tuft-0.1.2.dist-info → tuft-0.1.3.dist-info}/WHEEL +0 -0
- {tuft-0.1.2.dist-info → tuft-0.1.3.dist-info}/entry_points.txt +0 -0
- {tuft-0.1.2.dist-info → tuft-0.1.3.dist-info}/licenses/LICENSE +0 -0
tuft/exceptions.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Some custom exceptions."""
|
|
2
2
|
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
3
5
|
|
|
4
6
|
class TuFTException(Exception):
|
|
5
7
|
"""Base exception for TuFT errors."""
|
|
@@ -79,6 +81,15 @@ class SequenceConflictException(FutureException):
|
|
|
79
81
|
self.got = got
|
|
80
82
|
|
|
81
83
|
|
|
84
|
+
class SequenceTimeoutException(FutureException):
    """Raised when waiting for the expected sequence ID runs out of time.

    Attributes:
        sequence_id: The sequence ID that was being waited for.
    """

    def __init__(self, expected_sequence_id: int):
        super().__init__(
            f"Timeout when waiting for sequence ID {expected_sequence_id}."
        )
        self.sequence_id = expected_sequence_id
|
|
91
|
+
|
|
92
|
+
|
|
82
93
|
class MissingSequenceIDException(FutureException):
|
|
83
94
|
"""Missing sequence ID in the request."""
|
|
84
95
|
|
|
@@ -136,3 +147,58 @@ class LossFunctionInputShapeMismatchException(LossFunctionException):
|
|
|
136
147
|
detail = f"Input tensors must have the same shape. Got shapes: {shapes}"
|
|
137
148
|
super().__init__(detail)
|
|
138
149
|
self.shapes = shapes
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class LossFunctionUnknownMetricReductionException(LossFunctionException):
    """Raised when a metric's reduction suffix is not a recognized reduction type.

    Attributes:
        reduction_type: The unrecognized reduction type.
    """

    def __init__(self, reduction_type: str):
        super().__init__(f"Unknown metric reduction type: {reduction_type}")
        self.reduction_type = reduction_type
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class PersistenceException(TuFTException):
    """Common base class for errors raised by the persistence layer."""
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class ConfigMismatchError(PersistenceException):
    """Raised when current config doesn't match the stored config in Redis.

    This error occurs during server startup when persistence is enabled and
    the configuration has changed since the last run. This can cause data
    corruption when restoring persisted state.

    Args:
        diff: Mapping of field name -> per-field mismatch details. An entry
            with a non-None ``"current"`` and/or ``"stored"`` value is
            rendered as ``"<field> (current: X, stored: Y)"``. Any other
            entry is listed by field name alone, so no mismatching field is
            dropped from the message.

    Attributes:
        diff: The ``diff`` mapping exactly as passed in.
    """

    def __init__(
        self,
        diff: dict[str, dict[str, Any]],
    ):
        self.diff = diff

        # Build detailed diff message, one entry per mismatching field.
        diff_parts = []
        for field_name, field_diff in diff.items():
            current = field_diff.get("current")
            stored = field_diff.get("stored")
            if current is not None or stored is not None:
                diff_parts.append(f"{field_name} (current: {current}, stored: {stored})")
            else:
                # No scalar values to show for this field. Still name it, so the
                # message never falls back to "unknown difference" while the
                # offending fields are actually known.
                diff_parts.append(field_name)

        diff_str = "; ".join(diff_parts) if diff_parts else "unknown difference"

        message = (
            f"Configuration mismatch detected: {diff_str}.\n"
            "The current configuration does not match the stored configuration in Redis.\n"
            "This can cause data corruption when restoring persisted state.\n\n"
            "Options:\n"
            "  1. Use a different Redis database (change redis_url in config)\n"
            "  2. Run `tuft clear persistence -c <config_path>` to clear existing data\n"
            "     Use `--force` or `-f` to skip confirmation prompt.\n"
            "     (WARNING: This will delete all persisted sessions, training runs, etc.)\n"
            "  3. Restore the original configuration that matches the stored data"
        )
        super().__init__(message)
|
tuft/futures.py
CHANGED
|
@@ -189,6 +189,24 @@ class FutureStore:
|
|
|
189
189
|
count += 1
|
|
190
190
|
return count
|
|
191
191
|
|
|
192
|
+
def mark_pending_sample_futures_failed(
    self,
    error_message: str = "Server restarted while sample request was pending. Please retry.",
) -> int:
    """Mark all pending sample futures as failed.

    Every record whose status is ``"pending"`` and whose operation type is
    ``"sample"`` gets status ``"failed"``. It also gets a
    ``RequestFailedResponse`` with category ``Server`` that carries
    ``error_message``.

    Each record is persisted via ``_save_future`` before its event is set.
    This is the same save-then-signal order the async result/exception paths
    use, so anything woken by the event sees a record that is already saved.

    Args:
        error_message: Error text attached to every failed future.

    Returns:
        Number of futures that were marked as failed.
    """
    failed = 0
    for record in self._records.values():
        if record.status != "pending" or record.operation_type != "sample":
            continue
        record.status = "failed"
        record.error = types.RequestFailedResponse(
            error=error_message,
            category=types.RequestErrorCategory.Server,
        )
        # Persist first, then wake waiters.
        self._save_future(record.request_id)
        record.event.set()
        failed += 1
    return failed
|
|
209
|
+
|
|
192
210
|
def _store_record(self, record: FutureRecord) -> None:
|
|
193
211
|
self._records[record.request_id] = record
|
|
194
212
|
self._save_future(record.request_id)
|
|
@@ -327,8 +345,9 @@ class FutureStore:
|
|
|
327
345
|
record.payload = payload
|
|
328
346
|
record.status = "ready"
|
|
329
347
|
record.error = None
|
|
348
|
+
loop = asyncio.get_event_loop()
|
|
349
|
+
await loop.run_in_executor(None, self._save_future, request_id)
|
|
330
350
|
record.event.set()
|
|
331
|
-
self._save_future(request_id)
|
|
332
351
|
|
|
333
352
|
# Update metrics
|
|
334
353
|
get_metrics().futures_completed.add(
|
|
@@ -352,8 +371,9 @@ class FutureStore:
|
|
|
352
371
|
return
|
|
353
372
|
record.status = "failed"
|
|
354
373
|
record.error = failure
|
|
374
|
+
loop = asyncio.get_event_loop()
|
|
375
|
+
await loop.run_in_executor(None, self._save_future, request_id)
|
|
355
376
|
record.event.set()
|
|
356
|
-
self._save_future(request_id)
|
|
357
377
|
|
|
358
378
|
# Update metrics
|
|
359
379
|
get_metrics().futures_completed.add(
|
tuft/loss_fn/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from typing import Callable, Dict, Tuple
|
|
2
2
|
|
|
3
|
+
from tinker.lib.chunked_fwdbwd_helpers import REDUCE_MAP
|
|
3
4
|
from torch import Tensor
|
|
4
5
|
from typing_extensions import TypeAlias
|
|
5
6
|
|
|
@@ -7,6 +8,7 @@ from ..exceptions import (
|
|
|
7
8
|
LossFunctionInputShapeMismatchException,
|
|
8
9
|
LossFunctionMissingInputException,
|
|
9
10
|
LossFunctionNotFoundException,
|
|
11
|
+
LossFunctionUnknownMetricReductionException,
|
|
10
12
|
)
|
|
11
13
|
|
|
12
14
|
|
|
@@ -46,3 +48,34 @@ def _check_loss_fn_inputs(
|
|
|
46
48
|
shapes = [loss_fn_inputs[key].shape for key in required_keys]
|
|
47
49
|
if not all(shape == shapes[0] for shape in shapes):
|
|
48
50
|
raise LossFunctionInputShapeMismatchException(shapes)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def metrics_reduction(
    metric_list: list[dict[str, float]],
    weights: list[float],
) -> dict[str, float]:
    """Aggregate metrics from multiple batches.

    Modified from tinker.lib.chunked_fwdbwd_helpers._metrics_reduction

    Every metric key must have the form ``"<name>:<reduction>"``, where
    ``<reduction>`` (the text after the LAST colon) is a key of ``REDUCE_MAP``.
    Only keys from the first batch are considered. A key is skipped unless
    every batch reports it.

    - ``mean`` / ``slack``: the reducer is called as ``reduce_fn(values, weights)``.
    - ``unique``: the first batch's value is kept under ``key``, and the remaining
      values under ``f"{key}_1"``, ``f"{key}_2"``, ...
    - any other reduction: the reducer is called as ``reduce_fn(values)``.

    Args:
        metric_list: One metrics dict per batch.
        weights: Weights passed to the ``mean``/``slack`` reducers. Presumably
            one weight per batch, aligned with ``metric_list``; this is not
            checked here.

    Returns:
        The reduced metrics, or ``{}`` if ``metric_list`` is empty.

    Raises:
        LossFunctionUnknownMetricReductionException: If a key has no
            ``":<reduction>"`` suffix, or names a reduction not in ``REDUCE_MAP``.
    """
    if not metric_list:
        return {}
    result: dict[str, float] = {}
    for key in metric_list[0]:
        # Split on the last colon so metric names containing ":" still resolve
        # to their reduction suffix. A key with no colon gets the dedicated
        # exception instead of a tuple-unpacking ValueError. With no separator,
        # rpartition puts the whole key in `reduction`, which is then reported.
        _, sep, reduction = key.rpartition(":")
        if not sep or reduction not in REDUCE_MAP:
            raise LossFunctionUnknownMetricReductionException(reduction)
        if not all(key in metrics for metrics in metric_list):
            continue
        reduce_fn = REDUCE_MAP[reduction]
        values = [metrics[key] for metrics in metric_list]

        if reduction in ("mean", "slack"):
            result[key] = reduce_fn(values, weights)
        elif reduction == "unique":
            result[key] = values[0]
            result.update({f"{key}_{i}": v for i, v in enumerate(values[1:], start=1)})
        else:
            result[key] = reduce_fn(values)
    return result
|
tuft/persistence/__init__.py
CHANGED
|
@@ -3,30 +3,38 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from .redis_store import (
|
|
6
|
-
|
|
6
|
+
ConfigCheckField,
|
|
7
7
|
PersistenceConfig,
|
|
8
8
|
PersistenceMode,
|
|
9
9
|
RedisPipeline,
|
|
10
10
|
RedisStore,
|
|
11
11
|
delete_record,
|
|
12
|
+
flush_all_data,
|
|
13
|
+
get_current_namespace,
|
|
12
14
|
get_redis_store,
|
|
13
15
|
is_persistence_enabled,
|
|
14
16
|
load_record,
|
|
17
|
+
save_config_signature,
|
|
15
18
|
save_record,
|
|
16
19
|
save_records_atomic,
|
|
20
|
+
validate_config_signature,
|
|
17
21
|
)
|
|
18
22
|
|
|
19
23
|
|
|
20
24
|
# Names re-exported from .redis_store (kept in ASCII-sorted order).
__all__ = [
    # PascalCase names
    "ConfigCheckField",
    "PersistenceConfig",
    "PersistenceMode",
    "RedisPipeline",
    "RedisStore",
    # lower_case names
    "delete_record",
    "flush_all_data",
    "get_current_namespace",
    "get_redis_store",
    "is_persistence_enabled",
    "load_record",
    "save_config_signature",
    "save_record",
    "save_records_atomic",
    "validate_config_signature",
]
|