tuft 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
- tuft/__init__.py +5 -2
- tuft/auth.py +35 -0
- tuft/backend.py +254 -0
- tuft/backends/__init__.py +10 -0
- tuft/backends/base_backend.py +112 -0
- tuft/backends/hf_training_model.py +404 -0
- tuft/backends/sampling_backend.py +253 -0
- tuft/backends/training_backend.py +327 -0
- tuft/checkpoints.py +193 -0
- tuft/cli.py +91 -0
- tuft/config.py +121 -0
- tuft/exceptions.py +138 -0
- tuft/futures.py +431 -0
- tuft/loss_fn/__init__.py +48 -0
- tuft/loss_fn/cispo.py +40 -0
- tuft/loss_fn/cross_entropy.py +26 -0
- tuft/loss_fn/dro.py +37 -0
- tuft/loss_fn/importance_sampling.py +33 -0
- tuft/loss_fn/ppo.py +43 -0
- tuft/persistence/__init__.py +32 -0
- tuft/persistence/file_redis.py +268 -0
- tuft/persistence/redis_store.py +488 -0
- tuft/sampling_controller.py +366 -0
- tuft/server.py +720 -0
- tuft/state.py +352 -0
- tuft/telemetry/__init__.py +17 -0
- tuft/telemetry/metrics.py +335 -0
- tuft/telemetry/provider.py +198 -0
- tuft/telemetry/tracing.py +43 -0
- tuft/training_controller.py +723 -0
- tuft-0.1.1.dist-info/METADATA +633 -0
- tuft-0.1.1.dist-info/RECORD +35 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/WHEEL +1 -2
- tuft-0.1.1.dist-info/entry_points.txt +2 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/licenses/LICENSE +2 -2
- tuft-0.1.0.dist-info/METADATA +0 -77
- tuft-0.1.0.dist-info/RECORD +0 -6
- tuft-0.1.0.dist-info/top_level.txt +0 -1
tuft/training_controller.py
@@ -0,0 +1,723 @@
"""Training controller for managing training runs and routing requests."""

from __future__ import annotations

import asyncio
import logging
import time
import uuid
from datetime import datetime, timedelta, timezone
from typing import Awaitable, Callable, Dict, List, TypeVar

from opentelemetry.trace import StatusCode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from tinker import types

from .backends import BaseTrainingBackend
from .checkpoints import CheckpointRecord
from .config import AppConfig, ModelConfig
from .exceptions import (
    CheckpointAccessDeniedException,
    CheckpointMetadataReadException,
    CheckpointNotFoundException,
    SequenceConflictException,
    UnknownModelException,
    UserMismatchException,
)
from .persistence import (
    delete_record,
    get_redis_store,
    is_persistence_enabled,
    load_record,
    save_record,
    save_records_atomic,
)
from .telemetry.metrics import get_metrics
from .telemetry.tracing import get_tracer


_get_tracer = lambda: get_tracer("tuft.training_controller")  # noqa: E731
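# Tracer is resolved lazily via a lambda, presumably so it picks up the
# OpenTelemetry provider configured at startup (see telemetry.provider)
# instead of binding at import time.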

logger = logging.getLogger(__name__)

T = TypeVar("T")


def _now() -> datetime:
    return datetime.now(timezone.utc)


class TrainingRunRecord(BaseModel):
    """Training run record with persistence support.

    Runtime-only fields (backend, _execution_lock) are excluded from serialization.
    Checkpoints are stored separately with their own keys.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    training_run_id: str
    base_model: str
    lora_rank: int
    session_id: str
    model_owner: str
    user_metadata: dict[str, str] | None = None
    created_at: datetime = Field(default_factory=_now)
    last_request_time: datetime = Field(default_factory=_now)
    # Checkpoints are stored separately, excluded from serialization
    checkpoints: Dict[str, CheckpointRecord] = Field(default_factory=dict, exclude=True)
    sampler_checkpoints: Dict[str, CheckpointRecord] = Field(default_factory=dict, exclude=True)
    next_training_checkpoint: int = 1
    next_sampler_checkpoint: int = 1
    corrupted: bool = False
    next_seq_id: int = 1
    # Runtime-only fields, excluded from serialization
    backend: BaseTrainingBackend | None = Field(default=None, exclude=True)
    # Private attribute for execution lock (not a model field)
    _execution_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
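    # Note: asyncio.Lock cannot be serialized, so it lives in a PrivateAttr
    # rather than a model field; every instance, including records restored
    # from Redis, starts with a fresh unlocked lock.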

    def to_training_run(self) -> types.TrainingRun:
        training_checkpoint = self._latest_checkpoint(self.checkpoints)
        sampler_checkpoint = self._latest_checkpoint(self.sampler_checkpoints)
        return types.TrainingRun(
            training_run_id=self.training_run_id,
            base_model=self.base_model,
            model_owner=self.model_owner,
            is_lora=True,
            corrupted=self.corrupted,
            lora_rank=self.lora_rank,
            last_request_time=self.last_request_time,
            last_checkpoint=training_checkpoint,
            last_sampler_checkpoint=sampler_checkpoint,
            user_metadata=self.user_metadata,
        )

    def _latest_checkpoint(self, items: Dict[str, CheckpointRecord]) -> types.Checkpoint | None:
        if not items:
            return None
        latest = max(items.values(), key=lambda record: record.created_at)
        return latest.tinker_checkpoint


class TrainingController:
    """Tracks training runs, enforces request ordering.

    Routes work into ModelBackend instances.
    """

    REDIS_KEY_PREFIX = "training_run"

    def __init__(self, config: AppConfig) -> None:
        self.config = config
        self.training_backends = self._create_backends(config.supported_models)
        # TODO: add a mechanism to manage training_runs
        self.training_runs: Dict[str, TrainingRunRecord] = {}
        self._restore_from_redis()

    def _create_backends(self, model_configs: List[ModelConfig]) -> Dict[str, BaseTrainingBackend]:
        backends: Dict[str, BaseTrainingBackend] = {}
        for config in model_configs:
            backends[config.model_name] = BaseTrainingBackend.create_backend(config)
        return backends

    def _build_key(self, model_id: str) -> str:
        return get_redis_store().build_key(self.REDIS_KEY_PREFIX, model_id)

    def _build_checkpoint_key(self, model_id: str, checkpoint_id: str) -> str:
        return get_redis_store().build_key(self.REDIS_KEY_PREFIX, model_id, "ckpt", checkpoint_id)

    def _build_sampler_checkpoint_key(self, model_id: str, checkpoint_id: str) -> str:
        return get_redis_store().build_key(
            self.REDIS_KEY_PREFIX, model_id, "sampler_ckpt", checkpoint_id
        )

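    # Key layout (assuming build_key joins the store namespace and segments
    # with "::", as the parsing in _restore_from_redis implies):
    #   run:                 <ns>::training_run::<model_id>
    #   training checkpoint: <ns>::training_run::<model_id>::ckpt::<ckpt_id>
    #   sampler checkpoint:  <ns>::training_run::<model_id>::sampler_ckpt::<ckpt_id>
    # _restore_from_redis keeps only 3-part keys, which skips checkpoint keys.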
    def _restore_from_redis(self) -> None:
        """Restore training runs from Redis on startup."""
        if not is_persistence_enabled():
            return
        store = get_redis_store()
        # Match only top-level training runs (3 parts: namespace::prefix::model_id)
        for key in store.keys(store.build_key(self.REDIS_KEY_PREFIX, "*")):
            parts = key.split("::")
            if len(parts) != 3:
                continue
            record = load_record(key, TrainingRunRecord)
            if record is None:
                continue
            model_id = record.training_run_id
            # Restore checkpoints (stored separately, not subject to TTL)
            self._restore_checkpoints(model_id, record)
            self._restore_sampler_checkpoints(model_id, record)
            # Restore backend reference
            if record.base_model in self.training_backends:
                record.backend = self.training_backends[record.base_model]
            else:
                record.corrupted = True
            self.training_runs[model_id] = record

    def _restore_checkpoints(self, model_id: str, record: TrainingRunRecord) -> None:
        store = get_redis_store()
        pattern = self._build_checkpoint_key(model_id, "*")
        record.checkpoints = {}
        for key in store.keys(pattern):
            ckpt = load_record(key, CheckpointRecord)
            if ckpt is not None:
                record.checkpoints[ckpt.checkpoint_id] = ckpt

    def _restore_sampler_checkpoints(self, model_id: str, record: TrainingRunRecord) -> None:
        store = get_redis_store()
        pattern = self._build_sampler_checkpoint_key(model_id, "*")
        record.sampler_checkpoints = {}
        for key in store.keys(pattern):
            ckpt = load_record(key, CheckpointRecord)
            if ckpt is not None:
                record.sampler_checkpoints[ckpt.checkpoint_id] = ckpt

    def _save_training_run(self, model_id: str) -> None:
        """Save training run to Redis (no TTL - permanent record)."""
        if not is_persistence_enabled():
            return
        record = self.training_runs.get(model_id)
        if record is not None:
            save_record(self._build_key(model_id), record)

    def _save_checkpoint(self, model_id: str, checkpoint_id: str) -> None:
        """Save checkpoint to Redis (no TTL - permanent record)."""
        if not is_persistence_enabled():
            return
        record = self.training_runs.get(model_id)
        if record is not None:
            ckpt = record.checkpoints.get(checkpoint_id)
            if ckpt is not None:
                save_record(self._build_checkpoint_key(model_id, checkpoint_id), ckpt)

    def _save_sampler_checkpoint(self, model_id: str, checkpoint_id: str) -> None:
        """Save sampler checkpoint to Redis (no TTL - permanent record)."""
        if not is_persistence_enabled():
            return
        record = self.training_runs.get(model_id)
        if record is not None:
            ckpt = record.sampler_checkpoints.get(checkpoint_id)
            if ckpt is not None:
                save_record(self._build_sampler_checkpoint_key(model_id, checkpoint_id), ckpt)

    def _save_training_run_with_checkpoint(
        self, model_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType
    ) -> None:
        """Save training run and checkpoint atomically using a Redis transaction.

        This ensures consistency if the server crashes between saves.
        No TTL is used for these records as they are permanent.
        """
        if not is_persistence_enabled():
            return
        record = self.training_runs.get(model_id)
        if record is None:
            return

        if checkpoint_type == "training":
            ckpt = record.checkpoints.get(checkpoint_id)
            ckpt_key = self._build_checkpoint_key(model_id, checkpoint_id)
        else:
            ckpt = record.sampler_checkpoints.get(checkpoint_id)
            ckpt_key = self._build_sampler_checkpoint_key(model_id, checkpoint_id)

        if ckpt is None:
            # Defensive fallback: the checkpoint should exist at this point since
            # _save_training_run_with_checkpoint is called after adding the checkpoint
            # to the target_map. This branch handles unexpected edge cases (e.g., code
            # refactoring that changes call order) to ensure the training run is still
            # persisted even if the checkpoint lookup fails.
            logger.warning(
                "Checkpoint %s not found for model %s during persistence, "
                "saving training run without checkpoint",
                checkpoint_id,
                model_id,
            )
            save_record(self._build_key(model_id), record)
            return

        # Save both atomically (no TTL for permanent records)
        save_records_atomic(
            [
                (self._build_key(model_id), record),
                (ckpt_key, ckpt),
            ]
        )

    def _delete_training_run(self, model_id: str) -> None:
        if not is_persistence_enabled():
            return
        store = get_redis_store()
        store.delete(self._build_key(model_id))
        store.delete_pattern(self._build_checkpoint_key(model_id, "*"))
        store.delete_pattern(self._build_sampler_checkpoint_key(model_id, "*"))

    def _delete_checkpoint_record(self, model_id: str, checkpoint_id: str) -> None:
        if not is_persistence_enabled():
            return
        delete_record(self._build_checkpoint_key(model_id, checkpoint_id))

    def _delete_sampler_checkpoint_record(self, model_id: str, checkpoint_id: str) -> None:
        if not is_persistence_enabled():
            return
        delete_record(self._build_sampler_checkpoint_key(model_id, checkpoint_id))

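    # Ordering guard: requests carry monotonically increasing seq_ids starting
    # at 1, and seq_id=None skips the check. A hypothetical call sequence:
    #
    #     await controller.run_optim_step(model_id, user, params, seq_id=1)  # ok
    #     await controller.run_optim_step(model_id, user, params, seq_id=3)  # raises
    #     # SequenceConflictException(expected=2, got=3)
    #
    # Under the per-run lock this serializes forward/optim/checkpoint work.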
    async def _with_sequence_guard(
        self,
        record: TrainingRunRecord,
        seq_id: int | None,
        operation: Callable[[], Awaitable[T]],
    ) -> T:
        async with record._execution_lock:
            if seq_id is not None:
                self._reserve_seq_id(record, seq_id)
                # Save the updated next_seq_id to Redis
                self._save_training_run(record.training_run_id)
            return await operation()

    def _reserve_seq_id(self, record: TrainingRunRecord, seq_id: int) -> None:
        expected = record.next_seq_id
        if seq_id != expected:
            raise SequenceConflictException(expected=expected, got=seq_id)
        record.next_seq_id += 1

    async def create_model(
        self,
        session_id: str,
        base_model: str,
        lora_config: types.LoraConfig,
        model_owner: str,
        user_metadata: dict[str, str] | None,
    ) -> TrainingRunRecord:
        model_id = str(uuid.uuid4())
        with _get_tracer().start_as_current_span("training_controller.create_model") as span:
            span.set_attribute("tuft.training_run_id", model_id)
            span.set_attribute("tuft.session_id", session_id)
            span.set_attribute("tuft.base_model", base_model)
            span.set_attribute("tuft.lora_rank", lora_config.rank)
            try:
                logger.info("Creating model %s", model_id)

                if base_model not in self.training_backends:
                    raise UnknownModelException(model_name=base_model)
                backend = self.training_backends[base_model]
                record = TrainingRunRecord(
                    training_run_id=model_id,
                    base_model=base_model,
                    lora_rank=lora_config.rank,
                    session_id=session_id,
                    model_owner=model_owner,
                    user_metadata=user_metadata,
                    backend=backend,
                )
                await backend.create_adapter(model_id, lora_config)
                self.training_runs[model_id] = record
                self._save_training_run(model_id)

                # Update metrics
                get_metrics().training_models_active.add(1, {"base_model": base_model})
                return record
            except Exception as e:
                span.record_exception(e)
                span.set_status(StatusCode.ERROR)
                raise

    def get_run_record(
        self,
        model_id: str,
        user_id: str,
        enforce_user_match: bool = True,
    ) -> TrainingRunRecord:
        record = self.training_runs.get(model_id)
        if record is None:
            raise UnknownModelException(model_name=model_id)
        if enforce_user_match and record.model_owner != user_id:
            raise UserMismatchException()
        return record

    def build_supported_models(self) -> list[types.SupportedModel]:
        return [
            types.SupportedModel(model_name=model.model_name)
            for model in self.config.supported_models
        ]

    def update_activity(self, model_id: str, user_id: str) -> None:
        record = self.get_run_record(model_id, user_id)
        record.last_request_time = datetime.now(timezone.utc)
        self._save_training_run(model_id)

    async def run_forward(
        self,
        model_id: str,
        user_id: str,
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None,
        seq_id: int | None,
        *,
        backward: bool,
    ) -> types.ForwardBackwardOutput:
        record = self.get_run_record(model_id, user_id)
        self.update_activity(model_id, user_id)

        span_name = (
            "training_controller.run_forward_backward"
            if backward
            else "training_controller.run_forward"
        )
        with _get_tracer().start_as_current_span(span_name) as span:
            span.set_attribute("tuft.training_run_id", model_id)
            span.set_attribute("tuft.session_id", record.session_id)
            span.set_attribute("tuft.backward", backward)
            span.set_attribute("tuft.data_count", len(data))
            span.set_attribute("tuft.loss_fn", loss_fn)

            logger.info("Forward/backward begin for %s", model_id)
            start_time = time.perf_counter()

            # Count total input tokens for metrics
            total_tokens = sum(len(datum.model_input.to_ints()) for datum in data)

            async def _operation() -> types.ForwardBackwardOutput:
                if record.backend is None:
                    raise UnknownModelException(model_name=model_id)
                result = await record.backend.forward(
                    data,
                    lora_id=model_id,
                    loss_fn=loss_fn,
                    loss_fn_config=loss_fn_config,
                    backward=backward,
                )
                logger.info("Forward/backward completed for %s", model_id)
                return result

            result = await self._with_sequence_guard(record, seq_id, _operation)

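            # (Note: total_tokens above counts input tokens only, and the timed
            # window includes any wait on the run's execution lock, so the
            # recorded figure is end-to-end request throughput rather than raw
            # backend compute speed.)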
            # Record tokens per second metric
            duration = time.perf_counter() - start_time
            if total_tokens > 0 and duration > 0:
                tokens_per_second = total_tokens / duration
                get_metrics().training_tokens_per_second.record(
                    tokens_per_second, {"base_model": record.base_model}
                )

            return result

    async def run_optim_step(
        self, model_id: str, user_id: str, params: types.AdamParams, seq_id: int | None
    ) -> types.OptimStepResponse:
        record = self.get_run_record(model_id, user_id)
        self.update_activity(model_id, user_id)

        with _get_tracer().start_as_current_span("training_controller.run_optim_step") as span:
            span.set_attribute("tuft.training_run_id", model_id)
            span.set_attribute("tuft.session_id", record.session_id)
            span.set_attribute("tuft.learning_rate", params.learning_rate)

            logger.info("Optimizer step begin for %s", model_id)

            async def _operation() -> types.OptimStepResponse:
                if record.backend is None:
                    raise UnknownModelException(model_name=model_id)
                result = await record.backend.optim_step(adam_params=params, lora_id=model_id)
                logger.info("Optimizer step completed for %s", model_id)
                return result

            return await self._with_sequence_guard(record, seq_id, _operation)

    async def unload_model(self, model_id: str, user_id: str) -> None:
        # TODO: Ensure that all created training runs can be unloaded to reduce
        # GPU memory usage.
        if model_id not in self.training_runs:
            raise UnknownModelException(model_name=model_id)
        record = self.training_runs[model_id]
        if record.model_owner != user_id:
            raise UserMismatchException()
        base_model = record.base_model
        if record.backend is not None:
            await record.backend.remove_adapter(model_id)
        del self.training_runs[model_id]
        self._delete_training_run(model_id)

        # Update metrics
        get_metrics().training_models_active.add(-1, {"base_model": base_model})

    def list_training_runs(
        self, *, user_id: str, limit: int | None = None, offset: int = 0
    ) -> types.TrainingRunsResponse:
        runs = [
            record.to_training_run()
            for record in self.training_runs.values()
            if record.model_owner == user_id
        ]
        runs.sort(key=lambda run: run.last_request_time, reverse=True)
        total = len(runs)
        start = min(offset, total)
        end = total if limit is None else min(start + limit, total)
        paged = runs[start:end]
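        # e.g. total=5, offset=2, limit=2 -> paged = runs[2:4], cursor =
        # (offset=2, limit=2, total_count=5); with limit=None the slice runs to
        # the end and the cursor's limit falls back to total.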
        cursor = types.Cursor(offset=offset, limit=limit or total, total_count=total)
        return types.TrainingRunsResponse(training_runs=paged, cursor=cursor)

    def get_training_run_view(self, model_id: str, user_id: str) -> types.TrainingRun:
        record = self.get_run_record(model_id=model_id, user_id=user_id)
        return record.to_training_run()

    def get_model_info(self, model_id: str, user_id: str) -> types.GetInfoResponse:
        record = self.get_run_record(model_id=model_id, user_id=user_id)
        model_data = types.ModelData(
            arch="toy-transformer",
            model_name=record.base_model,
            tokenizer_id=record.base_model,
        )
        return types.GetInfoResponse(
            model_data=model_data,
            model_id=model_id,
            is_lora=True,
            lora_rank=record.lora_rank,
            model_name=record.base_model,
        )

    async def save_checkpoint(
        self,
        model_id: str,
        user_id: str,
        name: str | None,
        checkpoint_type: types.CheckpointType,
        future_id: int = 0,
        seq_id: int | None = None,
    ) -> CheckpointRecord:
        """Save a checkpoint for the given training run."""
        training_run = self.get_run_record(model_id=model_id, user_id=user_id)

        with _get_tracer().start_as_current_span("training_controller.save_checkpoint") as span:
            span.set_attribute("tuft.training_run_id", model_id)
            span.set_attribute("tuft.session_id", training_run.session_id)
            span.set_attribute("tuft.checkpoint_type", checkpoint_type)

            async def _operation() -> CheckpointRecord:
                counter_attr = (
                    "next_training_checkpoint"
                    if checkpoint_type == "training"
                    else "next_sampler_checkpoint"
                )
                counter = getattr(training_run, counter_attr)
                checkpoint_name = name or f"checkpoint-{counter:04d}"
                checkpoint_id = f"{model_id}/{checkpoint_name}"
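                # e.g. with counter == 1 and no explicit name:
                #   checkpoint_name == "checkpoint-0001"
                #   checkpoint_id == "<model_id>/checkpoint-0001"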
                logger.info("Checkpoint save begin: %s", checkpoint_id)

                setattr(training_run, counter_attr, counter + 1)
                checkpoint = CheckpointRecord.from_training_run(
                    training_run_id=training_run.training_run_id,
                    checkpoint_name=checkpoint_name,
                    owner_name=training_run.model_owner,
                    checkpoint_type=checkpoint_type,
                    checkpoint_root_dir=self.config.checkpoint_dir,
                    exist_ok=True,
                )
                checkpoint.future_id = future_id
                checkpoint.seq_id = seq_id
                target_map = (
                    training_run.checkpoints
                    if checkpoint_type == "training"
                    else training_run.sampler_checkpoints
                )
                if training_run.backend is not None:
                    await training_run.backend.save_state(
                        lora_id=training_run.training_run_id,
                        checkpoint_record=checkpoint,
                        optimizer=(checkpoint_type == "training"),
                    )
                checkpoint.size_bytes = checkpoint.path.stat().st_size
                checkpoint.save_metadata(
                    base_model=training_run.base_model,
                    session_id=training_run.session_id,
                    lora_rank=training_run.lora_rank,
                )
                # save the checkpoint record in the training run
                target_map[checkpoint_name] = checkpoint

                # Save training run and checkpoint atomically to prevent
                # inconsistency if the server crashes between saves
                self._save_training_run_with_checkpoint(model_id, checkpoint_name, checkpoint_type)

                # Update metrics
                metrics = get_metrics()
                metrics.training_checkpoints_saved.add(
                    1, {"model_id": model_id, "checkpoint_type": checkpoint_type}
                )
                logger.info("Checkpoint saved: %s", checkpoint_id)
                metrics.training_checkpoint_size.record(
                    checkpoint.size_bytes,
                    {"model_id": model_id, "checkpoint_type": checkpoint_type},
                )

                return checkpoint

            return await self._with_sequence_guard(training_run, seq_id, _operation)

    async def load_checkpoint(
        self,
        model_id: str,
        user_id: str,
        path: str,
        optimizer: bool,
        seq_id: int | None = None,
    ) -> None:
        """Load a checkpoint."""
        try:
            parsed_checkpoint = CheckpointRecord.from_tinker_path(path, self.config.checkpoint_dir)
        except FileNotFoundError as exc:
            raise CheckpointNotFoundException(checkpoint_id=model_id) from exc
        source_model_id = parsed_checkpoint.training_run_id or model_id
        training_run = self.get_run_record(source_model_id, user_id, enforce_user_match=False)

        collection = (
            training_run.checkpoints
            if parsed_checkpoint.checkpoint_type == "training"
            else training_run.sampler_checkpoints
        )

        checkpoint = collection.get(parsed_checkpoint.checkpoint_id)
        if checkpoint is None:
            raise CheckpointNotFoundException(checkpoint_id=parsed_checkpoint.checkpoint_id)
        try:
            metadata = checkpoint.metadata
        except FileNotFoundError as exc:
            raise CheckpointMetadataReadException(
                checkpoint_id=parsed_checkpoint.checkpoint_id
            ) from exc
        if metadata.public or (metadata.owner_name == user_id):
            if training_run.backend is None:
                raise UnknownModelException(model_name=model_id)

            checkpoint_id = parsed_checkpoint.checkpoint_id
            logger.info("Checkpoint load begin: %s", checkpoint_id)

            async def _operation() -> None:
                assert training_run.backend is not None
                await training_run.backend.load_state(
                    lora_id=training_run.training_run_id,
                    checkpoint_record=checkpoint,
                    optimizer=optimizer,
                )
                logger.info("Checkpoint loaded: %s", checkpoint_id)

            await self._with_sequence_guard(training_run, seq_id, _operation)
        else:
            raise CheckpointAccessDeniedException(checkpoint_id=parsed_checkpoint.checkpoint_id)

    def delete_checkpoint(self, model_id: str, user_id: str, checkpoint_id: str) -> None:
        training_run = self.get_run_record(model_id, user_id)
        removed = training_run.checkpoints.pop(checkpoint_id, None)
        is_sampler = False
        if removed is None:
            removed = training_run.sampler_checkpoints.pop(checkpoint_id, None)
            is_sampler = True
        if removed is None:
            raise CheckpointNotFoundException(checkpoint_id=checkpoint_id)
        removed.delete()

        self._save_training_run(model_id)
        if is_sampler:
            self._delete_sampler_checkpoint_record(model_id, checkpoint_id)
        else:
            self._delete_checkpoint_record(model_id, checkpoint_id)

    def list_checkpoints(self, model_id: str, user_id: str) -> list[types.Checkpoint]:
        training_run = self.get_run_record(model_id, user_id)
        checkpoints = [item.tinker_checkpoint for item in training_run.checkpoints.values()]
        checkpoints += [
            item.tinker_checkpoint for item in training_run.sampler_checkpoints.values()
        ]
        checkpoints.sort(key=lambda ckpt: ckpt.time)
        return checkpoints

    def list_user_checkpoints(
        self,
        user_id: str,
    ) -> list[types.Checkpoint]:
        checkpoints: list[types.Checkpoint] = []
        training_runs = [run for run in self.training_runs.values() if run.model_owner == user_id]
        for run in training_runs:
            checkpoints.extend([item.tinker_checkpoint for item in run.checkpoints.values()])
        checkpoints.sort(key=lambda item: item.time, reverse=True)
        return checkpoints

    def set_visibility(
        self, model_id: str, checkpoint_id: str, user_id: str, *, public: bool
    ) -> None:
        training_run = self.get_run_record(model_id=model_id, user_id=user_id)
        target = training_run.checkpoints.get(checkpoint_id)
        is_sampler = False
        if target is None:
            target = training_run.sampler_checkpoints.get(checkpoint_id)
            is_sampler = True
        if target is None:
            raise CheckpointNotFoundException(checkpoint_id=checkpoint_id)
        target.set_visibility(public)

        if is_sampler:
            self._save_sampler_checkpoint(model_id, checkpoint_id)
        else:
            self._save_checkpoint(model_id, checkpoint_id)

    def build_archive_url(
        self,
        model_id: str,
        user_id: str,
        checkpoint_id: str,
    ) -> types.CheckpointArchiveUrlResponse:
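        # Note: the returned "archive URL" is a local file:// URI
        # (checkpoint.path.as_uri()), and the 15-minute expiry is advisory
        # metadata only; nothing in this method signs or enforces it.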
        training_run = self.get_run_record(model_id, user_id)
        checkpoint = training_run.checkpoints.get(
            checkpoint_id
        ) or training_run.sampler_checkpoints.get(checkpoint_id)
        if checkpoint is None:
            raise CheckpointNotFoundException(checkpoint_id=checkpoint_id)
        expires = datetime.now(timezone.utc) + timedelta(minutes=15)
        return types.CheckpointArchiveUrlResponse(url=checkpoint.path.as_uri(), expires=expires)

    def get_weights_info(self, model_id: str, user_id: str) -> types.WeightsInfoResponse:
        training_run = self.get_run_record(model_id, user_id)
        return types.WeightsInfoResponse(
            base_model=training_run.base_model,
            is_lora=True,
            lora_rank=training_run.lora_rank,
        )

    def get_latest_checkpoint(self, model_id: str) -> CheckpointRecord | None:
        record = self.training_runs.get(model_id)
        if record is None:
            return None
        all_checkpoints = list(record.checkpoints.values()) + list(
            record.sampler_checkpoints.values()
        )
        if not all_checkpoints:
            return None
        return max(all_checkpoints, key=lambda c: c.created_at)

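    # Crash-recovery sketch: pick the newest checkpoint across both maps,
    # recreate the LoRA adapter, then load its state (optimizer state included
    # for "training" checkpoints). A failed load marks the run corrupted and
    # persists the flag instead of raising.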
    async def restore_from_checkpoint(self, model_id: str) -> CheckpointRecord | None:
        latest_ckpt = self.get_latest_checkpoint(model_id)
        if latest_ckpt is None:
            return None
        record = self.training_runs.get(model_id)
        if record is None or record.backend is None:
            return None
        try:
            await record.backend.create_adapter(model_id, types.LoraConfig(rank=record.lora_rank))
        except Exception:
            logger.exception("Failed to create adapter for model %s during restore", model_id)
        try:
            await record.backend.load_state(
                lora_id=model_id,
                checkpoint_record=latest_ckpt,
                optimizer=(latest_ckpt.checkpoint_type == "training"),
            )
        except Exception:  # pylint: disable=broad-except
            # If loading fails, mark as corrupted
            record.corrupted = True
            self._save_training_run(model_id)
            return None

        return latest_ckpt