tuft 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,723 @@
+ """Training controller for managing training runs and routing requests."""
+
+ from __future__ import annotations
+
+ import asyncio
+ import logging
+ import time
+ import uuid
+ from datetime import datetime, timedelta, timezone
+ from typing import Awaitable, Callable, Dict, List, TypeVar
+
+ from opentelemetry.trace import StatusCode
+ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
+ from tinker import types
+
+ from .backends import BaseTrainingBackend
+ from .checkpoints import CheckpointRecord
+ from .config import AppConfig, ModelConfig
+ from .exceptions import (
+     CheckpointAccessDeniedException,
+     CheckpointMetadataReadException,
+     CheckpointNotFoundException,
+     SequenceConflictException,
+     UnknownModelException,
+     UserMismatchException,
+ )
+ from .persistence import (
+     delete_record,
+     get_redis_store,
+     is_persistence_enabled,
+     load_record,
+     save_record,
+     save_records_atomic,
+ )
+ from .telemetry.metrics import get_metrics
+ from .telemetry.tracing import get_tracer
+
+
+ def _get_tracer():
+     return get_tracer("tuft.training_controller")
+
+
+ logger = logging.getLogger(__name__)
+
+ T = TypeVar("T")
+
+
+ def _now() -> datetime:
+     return datetime.now(timezone.utc)
+
+
+ class TrainingRunRecord(BaseModel):
+     """Training run record with persistence support.
+
+     Runtime-only fields (backend, _execution_lock) are excluded from serialization.
+     Checkpoints are stored separately with their own keys.
+     """
+
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     training_run_id: str
+     base_model: str
+     lora_rank: int
+     session_id: str
+     model_owner: str
+     user_metadata: dict[str, str] | None = None
+     created_at: datetime = Field(default_factory=_now)
+     last_request_time: datetime = Field(default_factory=_now)
+     # Checkpoints are stored separately, excluded from serialization
+     checkpoints: Dict[str, CheckpointRecord] = Field(default_factory=dict, exclude=True)
+     sampler_checkpoints: Dict[str, CheckpointRecord] = Field(default_factory=dict, exclude=True)
+     next_training_checkpoint: int = 1
+     next_sampler_checkpoint: int = 1
+     corrupted: bool = False
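+     # Monotonic counter used to validate client-supplied seq_ids; see
+     # TrainingController._reserve_seq_id.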
+     next_seq_id: int = 1
+     # Runtime-only fields, excluded from serialization
+     backend: BaseTrainingBackend | None = Field(default=None, exclude=True)
+     # Private attribute for execution lock (not a model field)
+     _execution_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
+
+     def to_training_run(self) -> types.TrainingRun:
+         training_checkpoint = self._latest_checkpoint(self.checkpoints)
+         sampler_checkpoint = self._latest_checkpoint(self.sampler_checkpoints)
+         return types.TrainingRun(
+             training_run_id=self.training_run_id,
+             base_model=self.base_model,
+             model_owner=self.model_owner,
+             is_lora=True,
+             corrupted=self.corrupted,
+             lora_rank=self.lora_rank,
+             last_request_time=self.last_request_time,
+             last_checkpoint=training_checkpoint,
+             last_sampler_checkpoint=sampler_checkpoint,
+             user_metadata=self.user_metadata,
+         )
+
+     def _latest_checkpoint(self, items: Dict[str, CheckpointRecord]) -> types.Checkpoint | None:
+         if not items:
+             return None
+         latest = max(items.values(), key=lambda record: record.created_at)
+         return latest.tinker_checkpoint
+
+
+ class TrainingController:
+     """Tracks training runs and enforces request ordering.
+
+     Routes work to the appropriate BaseTrainingBackend instance.
+     """
+
+     REDIS_KEY_PREFIX = "training_run"
+
+     def __init__(self, config: AppConfig) -> None:
+         self.config = config
+         self.training_backends = self._create_backends(config.supported_models)
+         # TODO: add a mechanism to manage training_runs
+         self.training_runs: Dict[str, TrainingRunRecord] = {}
+         self._restore_from_redis()
+
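+     # One backend is created per configured base model and shared by all
+     # training runs of that model.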
+     def _create_backends(self, model_configs: List[ModelConfig]) -> Dict[str, BaseTrainingBackend]:
+         backends: Dict[str, BaseTrainingBackend] = {}
+         for config in model_configs:
+             backends[config.model_name] = BaseTrainingBackend.create_backend(config)
+         return backends
+
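+     # Redis key layout: <namespace>::training_run::<model_id> for the run record;
+     # checkpoint records add a "ckpt" or "sampler_ckpt" segment plus the checkpoint id.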
+     def _build_key(self, model_id: str) -> str:
+         return get_redis_store().build_key(self.REDIS_KEY_PREFIX, model_id)
+
+     def _build_checkpoint_key(self, model_id: str, checkpoint_id: str) -> str:
+         return get_redis_store().build_key(self.REDIS_KEY_PREFIX, model_id, "ckpt", checkpoint_id)
+
+     def _build_sampler_checkpoint_key(self, model_id: str, checkpoint_id: str) -> str:
+         return get_redis_store().build_key(
+             self.REDIS_KEY_PREFIX, model_id, "sampler_ckpt", checkpoint_id
+         )
+
+     def _restore_from_redis(self) -> None:
+         """Restore training runs from Redis on startup."""
+         if not is_persistence_enabled():
+             return
+         store = get_redis_store()
+         # Match only top-level training runs (3 parts: namespace::prefix::model_id)
+         for key in store.keys(store.build_key(self.REDIS_KEY_PREFIX, "*")):
+             parts = key.split("::")
+             if len(parts) != 3:
+                 continue
+             record = load_record(key, TrainingRunRecord)
+             if record is None:
+                 continue
+             model_id = record.training_run_id
+             # Restore checkpoints (stored separately, not subject to TTL)
+             self._restore_checkpoints(model_id, record)
+             self._restore_sampler_checkpoints(model_id, record)
+             # Restore backend reference
+             if record.base_model in self.training_backends:
+                 record.backend = self.training_backends[record.base_model]
+             else:
+                 record.corrupted = True
+             self.training_runs[model_id] = record
+
+     def _restore_checkpoints(self, model_id: str, record: TrainingRunRecord) -> None:
+         store = get_redis_store()
+         pattern = self._build_checkpoint_key(model_id, "*")
+         record.checkpoints = {}
+         for key in store.keys(pattern):
+             ckpt = load_record(key, CheckpointRecord)
+             if ckpt is not None:
+                 record.checkpoints[ckpt.checkpoint_id] = ckpt
+
+     def _restore_sampler_checkpoints(self, model_id: str, record: TrainingRunRecord) -> None:
+         store = get_redis_store()
+         pattern = self._build_sampler_checkpoint_key(model_id, "*")
+         record.sampler_checkpoints = {}
+         for key in store.keys(pattern):
+             ckpt = load_record(key, CheckpointRecord)
+             if ckpt is not None:
+                 record.sampler_checkpoints[ckpt.checkpoint_id] = ckpt
+
+     def _save_training_run(self, model_id: str) -> None:
+         """Save training run to Redis (no TTL - permanent record)."""
+         if not is_persistence_enabled():
+             return
+         record = self.training_runs.get(model_id)
+         if record is not None:
+             save_record(self._build_key(model_id), record)
+
+     def _save_checkpoint(self, model_id: str, checkpoint_id: str) -> None:
+         """Save checkpoint to Redis (no TTL - permanent record)."""
+         if not is_persistence_enabled():
+             return
+         record = self.training_runs.get(model_id)
+         if record is not None:
+             ckpt = record.checkpoints.get(checkpoint_id)
+             if ckpt is not None:
+                 save_record(self._build_checkpoint_key(model_id, checkpoint_id), ckpt)
+
+     def _save_sampler_checkpoint(self, model_id: str, checkpoint_id: str) -> None:
+         """Save sampler checkpoint to Redis (no TTL - permanent record)."""
+         if not is_persistence_enabled():
+             return
+         record = self.training_runs.get(model_id)
+         if record is not None:
+             ckpt = record.sampler_checkpoints.get(checkpoint_id)
+             if ckpt is not None:
+                 save_record(self._build_sampler_checkpoint_key(model_id, checkpoint_id), ckpt)
+
+     def _save_training_run_with_checkpoint(
+         self, model_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType
+     ) -> None:
+         """Save the training run and checkpoint atomically using a Redis transaction.
+
+         This ensures consistency if the server crashes between saves.
+         No TTL is used for these records as they are permanent.
+         """
+         if not is_persistence_enabled():
+             return
+         record = self.training_runs.get(model_id)
+         if record is None:
+             return
+
+         if checkpoint_type == "training":
+             ckpt = record.checkpoints.get(checkpoint_id)
+             ckpt_key = self._build_checkpoint_key(model_id, checkpoint_id)
+         else:
+             ckpt = record.sampler_checkpoints.get(checkpoint_id)
+             ckpt_key = self._build_sampler_checkpoint_key(model_id, checkpoint_id)
+
+         if ckpt is None:
+             # Defensive fallback: the checkpoint should exist at this point, since
+             # _save_training_run_with_checkpoint is called after adding the checkpoint
+             # to the target_map. This branch handles unexpected edge cases (e.g., code
+             # refactoring that changes call order) to ensure the training run is still
+             # persisted even if the checkpoint lookup fails.
+             logger.warning(
+                 "Checkpoint %s not found for model %s during persistence, "
+                 "saving training run without checkpoint",
+                 checkpoint_id,
+                 model_id,
+             )
+             save_record(self._build_key(model_id), record)
+             return
+
+         # Save both atomically (no TTL for permanent records)
+         save_records_atomic(
+             [
+                 (self._build_key(model_id), record),
+                 (ckpt_key, ckpt),
+             ]
+         )
+
+     def _delete_training_run(self, model_id: str) -> None:
+         if not is_persistence_enabled():
+             return
+         store = get_redis_store()
+         store.delete(self._build_key(model_id))
+         store.delete_pattern(self._build_checkpoint_key(model_id, "*"))
+         store.delete_pattern(self._build_sampler_checkpoint_key(model_id, "*"))
+
+     def _delete_checkpoint_record(self, model_id: str, checkpoint_id: str) -> None:
+         if not is_persistence_enabled():
+             return
+         delete_record(self._build_checkpoint_key(model_id, checkpoint_id))
+
+     def _delete_sampler_checkpoint_record(self, model_id: str, checkpoint_id: str) -> None:
+         if not is_persistence_enabled():
+             return
+         delete_record(self._build_sampler_checkpoint_key(model_id, checkpoint_id))
+
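+     # Serializes operations on a single run behind its per-run lock. When the
+     # client supplies a seq_id it must equal the run's next_seq_id exactly;
+     # otherwise SequenceConflictException is raised before the operation runs.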
+     async def _with_sequence_guard(
+         self,
+         record: TrainingRunRecord,
+         seq_id: int | None,
+         operation: Callable[[], Awaitable[T]],
+     ) -> T:
+         async with record._execution_lock:
+             if seq_id is not None:
+                 self._reserve_seq_id(record, seq_id)
+                 # Save the updated next_seq_id to Redis
+                 self._save_training_run(record.training_run_id)
+             return await operation()
+
+     def _reserve_seq_id(self, record: TrainingRunRecord, seq_id: int) -> None:
+         expected = record.next_seq_id
+         if seq_id != expected:
+             raise SequenceConflictException(expected=expected, got=seq_id)
+         record.next_seq_id += 1
+
+     async def create_model(
+         self,
+         session_id: str,
+         base_model: str,
+         lora_config: types.LoraConfig,
+         model_owner: str,
+         user_metadata: dict[str, str] | None,
+     ) -> TrainingRunRecord:
+         model_id = str(uuid.uuid4())
+         with _get_tracer().start_as_current_span("training_controller.create_model") as span:
+             span.set_attribute("tuft.training_run_id", model_id)
+             span.set_attribute("tuft.session_id", session_id)
+             span.set_attribute("tuft.base_model", base_model)
+             span.set_attribute("tuft.lora_rank", lora_config.rank)
+             try:
+                 logger.info("Creating model %s", model_id)
+
+                 if base_model not in self.training_backends:
+                     raise UnknownModelException(model_name=base_model)
+                 backend = self.training_backends[base_model]
+                 record = TrainingRunRecord(
+                     training_run_id=model_id,
+                     base_model=base_model,
+                     lora_rank=lora_config.rank,
+                     session_id=session_id,
+                     model_owner=model_owner,
+                     user_metadata=user_metadata,
+                     backend=backend,
+                 )
+                 await backend.create_adapter(model_id, lora_config)
+                 self.training_runs[model_id] = record
+                 self._save_training_run(model_id)
+
+                 # Update metrics
+                 get_metrics().training_models_active.add(1, {"base_model": base_model})
+                 return record
+             except Exception as e:
+                 span.record_exception(e)
+                 span.set_status(StatusCode.ERROR)
+                 raise
+
+     def get_run_record(
+         self,
+         model_id: str,
+         user_id: str,
+         enforce_user_match: bool = True,
+     ) -> TrainingRunRecord:
+         record = self.training_runs.get(model_id)
+         if record is None:
+             raise UnknownModelException(model_name=model_id)
+         if enforce_user_match and record.model_owner != user_id:
+             raise UserMismatchException()
+         return record
+
+     def build_supported_models(self) -> list[types.SupportedModel]:
+         return [
+             types.SupportedModel(model_name=model.model_name)
+             for model in self.config.supported_models
+         ]
+
+     def update_activity(self, model_id: str, user_id: str) -> None:
+         record = self.get_run_record(model_id, user_id)
+         record.last_request_time = _now()
+         self._save_training_run(model_id)
+
+     async def run_forward(
+         self,
+         model_id: str,
+         user_id: str,
+         data: list[types.Datum],
+         loss_fn: types.LossFnType,
+         loss_fn_config: dict[str, float] | None,
+         seq_id: int | None,
+         *,
+         backward: bool,
+     ) -> types.ForwardBackwardOutput:
+         record = self.get_run_record(model_id, user_id)
+         self.update_activity(model_id, user_id)
+
+         span_name = (
+             "training_controller.run_forward_backward"
+             if backward
+             else "training_controller.run_forward"
+         )
+         with _get_tracer().start_as_current_span(span_name) as span:
+             span.set_attribute("tuft.training_run_id", model_id)
+             span.set_attribute("tuft.session_id", record.session_id)
+             span.set_attribute("tuft.backward", backward)
+             span.set_attribute("tuft.data_count", len(data))
+             span.set_attribute("tuft.loss_fn", loss_fn)
+
+             logger.info("Forward/backward begin for %s", model_id)
+             start_time = time.perf_counter()
+
+             # Count total input tokens for metrics
+             total_tokens = sum(len(datum.model_input.to_ints()) for datum in data)
+
+             async def _operation() -> types.ForwardBackwardOutput:
+                 if record.backend is None:
+                     raise UnknownModelException(model_name=model_id)
+                 result = await record.backend.forward(
+                     data,
+                     lora_id=model_id,
+                     loss_fn=loss_fn,
+                     loss_fn_config=loss_fn_config,
+                     backward=backward,
+                 )
+                 logger.info("Forward/backward completed for %s", model_id)
+                 return result
+
+             result = await self._with_sequence_guard(record, seq_id, _operation)
+
+             # Record tokens per second metric
+             duration = time.perf_counter() - start_time
+             if total_tokens > 0 and duration > 0:
+                 tokens_per_second = total_tokens / duration
+                 get_metrics().training_tokens_per_second.record(
+                     tokens_per_second, {"base_model": record.base_model}
+                 )
+
+             return result
+
+     async def run_optim_step(
+         self, model_id: str, user_id: str, params: types.AdamParams, seq_id: int | None
+     ) -> types.OptimStepResponse:
+         record = self.get_run_record(model_id, user_id)
+         self.update_activity(model_id, user_id)
+
+         with _get_tracer().start_as_current_span("training_controller.run_optim_step") as span:
+             span.set_attribute("tuft.training_run_id", model_id)
+             span.set_attribute("tuft.session_id", record.session_id)
+             span.set_attribute("tuft.learning_rate", params.learning_rate)
+
+             logger.info("Optimizer step begin for %s", model_id)
+
+             async def _operation() -> types.OptimStepResponse:
+                 if record.backend is None:
+                     raise UnknownModelException(model_name=model_id)
+                 result = await record.backend.optim_step(adam_params=params, lora_id=model_id)
+                 logger.info("Optimizer step completed for %s", model_id)
+                 return result
+
+             return await self._with_sequence_guard(record, seq_id, _operation)
+
+     async def unload_model(self, model_id: str, user_id: str) -> None:
+         # TODO: Ensure that all created training runs can be unloaded to reduce
+         # GPU memory usage.
+         if model_id not in self.training_runs:
+             raise UnknownModelException(model_name=model_id)
+         record = self.training_runs[model_id]
+         if record.model_owner != user_id:
+             raise UserMismatchException()
+         base_model = record.base_model
+         if record.backend is not None:
+             await record.backend.remove_adapter(model_id)
+         del self.training_runs[model_id]
+         self._delete_training_run(model_id)
+
+         # Update metrics
+         get_metrics().training_models_active.add(-1, {"base_model": base_model})
+
+     def list_training_runs(
+         self, *, user_id: str, limit: int | None = None, offset: int = 0
+     ) -> types.TrainingRunsResponse:
+         runs = [
+             record.to_training_run()
+             for record in self.training_runs.values()
+             if record.model_owner == user_id
+         ]
+         runs.sort(key=lambda run: run.last_request_time, reverse=True)
+         total = len(runs)
+         start = min(offset, total)
+         end = total if limit is None else min(start + limit, total)
+         paged = runs[start:end]
+         cursor = types.Cursor(offset=offset, limit=limit or total, total_count=total)
+         return types.TrainingRunsResponse(training_runs=paged, cursor=cursor)
+
+     def get_training_run_view(self, model_id: str, user_id: str) -> types.TrainingRun:
+         record = self.get_run_record(model_id=model_id, user_id=user_id)
+         return record.to_training_run()
+
+     def get_model_info(self, model_id: str, user_id: str) -> types.GetInfoResponse:
+         record = self.get_run_record(model_id=model_id, user_id=user_id)
+         model_data = types.ModelData(
+             arch="toy-transformer",
+             model_name=record.base_model,
+             tokenizer_id=record.base_model,
+         )
+         return types.GetInfoResponse(
+             model_data=model_data,
+             model_id=model_id,
+             is_lora=True,
+             lora_rank=record.lora_rank,
+             model_name=record.base_model,
+         )
+
+     async def save_checkpoint(
+         self,
+         model_id: str,
+         user_id: str,
+         name: str | None,
+         checkpoint_type: types.CheckpointType,
+         future_id: int = 0,
+         seq_id: int | None = None,
+     ) -> CheckpointRecord:
+         """Save a checkpoint for the given training run."""
+         training_run = self.get_run_record(model_id=model_id, user_id=user_id)
+
+         with _get_tracer().start_as_current_span("training_controller.save_checkpoint") as span:
+             span.set_attribute("tuft.training_run_id", model_id)
+             span.set_attribute("tuft.session_id", training_run.session_id)
+             span.set_attribute("tuft.checkpoint_type", checkpoint_type)
+
+             async def _operation() -> CheckpointRecord:
+                 counter_attr = (
+                     "next_training_checkpoint"
+                     if checkpoint_type == "training"
+                     else "next_sampler_checkpoint"
+                 )
+                 counter = getattr(training_run, counter_attr)
+                 checkpoint_name = name or f"checkpoint-{counter:04d}"
+                 checkpoint_id = f"{model_id}/{checkpoint_name}"
+                 logger.info("Checkpoint save begin: %s", checkpoint_id)
+
+                 setattr(training_run, counter_attr, counter + 1)
+                 checkpoint = CheckpointRecord.from_training_run(
+                     training_run_id=training_run.training_run_id,
+                     checkpoint_name=checkpoint_name,
+                     owner_name=training_run.model_owner,
+                     checkpoint_type=checkpoint_type,
+                     checkpoint_root_dir=self.config.checkpoint_dir,
+                     exist_ok=True,
+                 )
+                 checkpoint.future_id = future_id
+                 checkpoint.seq_id = seq_id
+                 target_map = (
+                     training_run.checkpoints
+                     if checkpoint_type == "training"
+                     else training_run.sampler_checkpoints
+                 )
+                 if training_run.backend is not None:
+                     await training_run.backend.save_state(
+                         lora_id=training_run.training_run_id,
+                         checkpoint_record=checkpoint,
+                         optimizer=(checkpoint_type == "training"),
+                     )
+                 checkpoint.size_bytes = checkpoint.path.stat().st_size
+                 checkpoint.save_metadata(
+                     base_model=training_run.base_model,
+                     session_id=training_run.session_id,
+                     lora_rank=training_run.lora_rank,
+                 )
+                 # Save the checkpoint record in the training run
+                 target_map[checkpoint_name] = checkpoint
+
+                 # Save training run and checkpoint atomically to prevent inconsistency
+                 # if the server crashes between saves
+                 self._save_training_run_with_checkpoint(model_id, checkpoint_name, checkpoint_type)
+
+                 # Update metrics
+                 metrics = get_metrics()
+                 metrics.training_checkpoints_saved.add(
+                     1, {"model_id": model_id, "checkpoint_type": checkpoint_type}
+                 )
+                 logger.info("Checkpoint saved: %s", checkpoint_id)
+                 metrics.training_checkpoint_size.record(
+                     checkpoint.size_bytes,
+                     {"model_id": model_id, "checkpoint_type": checkpoint_type},
+                 )
+
+                 return checkpoint
+
+             return await self._with_sequence_guard(training_run, seq_id, _operation)
+
+     async def load_checkpoint(
+         self,
+         model_id: str,
+         user_id: str,
+         path: str,
+         optimizer: bool,
+         seq_id: int | None = None,
+     ) -> None:
+         """Load a checkpoint."""
+         try:
+             parsed_checkpoint = CheckpointRecord.from_tinker_path(path, self.config.checkpoint_dir)
+         except FileNotFoundError as exc:
+             raise CheckpointNotFoundException(checkpoint_id=path) from exc
+         source_model_id = parsed_checkpoint.training_run_id or model_id
+         training_run = self.get_run_record(source_model_id, user_id, enforce_user_match=False)
+
+         collection = (
+             training_run.checkpoints
+             if parsed_checkpoint.checkpoint_type == "training"
+             else training_run.sampler_checkpoints
+         )
+
+         checkpoint = collection.get(parsed_checkpoint.checkpoint_id)
+         if checkpoint is None:
+             raise CheckpointNotFoundException(checkpoint_id=parsed_checkpoint.checkpoint_id)
+         try:
+             metadata = checkpoint.metadata
+         except FileNotFoundError as exc:
+             raise CheckpointMetadataReadException(
+                 checkpoint_id=parsed_checkpoint.checkpoint_id
+             ) from exc
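+         # A checkpoint may be loaded by its owner or, when marked public, by any
+         # user; get_run_record above is called with enforce_user_match=False so
+         # that public checkpoints from other users' runs can be resolved.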
+         if metadata.public or (metadata.owner_name == user_id):
+             if training_run.backend is None:
+                 raise UnknownModelException(model_name=model_id)
+
+             checkpoint_id = parsed_checkpoint.checkpoint_id
+             logger.info("Checkpoint load begin: %s", checkpoint_id)
+
+             async def _operation() -> None:
+                 assert training_run.backend is not None
+                 await training_run.backend.load_state(
+                     lora_id=training_run.training_run_id,
+                     checkpoint_record=checkpoint,
+                     optimizer=optimizer,
+                 )
+                 logger.info("Checkpoint loaded: %s", checkpoint_id)
+
+             await self._with_sequence_guard(training_run, seq_id, _operation)
+         else:
+             raise CheckpointAccessDeniedException(checkpoint_id=parsed_checkpoint.checkpoint_id)
+
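+     # Checkpoints live in two maps (training and sampler); deletion checks both
+     # and tracks which map matched so the corresponding Redis record is removed.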
+     def delete_checkpoint(self, model_id: str, user_id: str, checkpoint_id: str) -> None:
+         training_run = self.get_run_record(model_id, user_id)
+         removed = training_run.checkpoints.pop(checkpoint_id, None)
+         is_sampler = False
+         if removed is None:
+             removed = training_run.sampler_checkpoints.pop(checkpoint_id, None)
+             is_sampler = True
+         if removed is None:
+             raise CheckpointNotFoundException(checkpoint_id=checkpoint_id)
+         removed.delete()
+
+         self._save_training_run(model_id)
+         if is_sampler:
+             self._delete_sampler_checkpoint_record(model_id, checkpoint_id)
+         else:
+             self._delete_checkpoint_record(model_id, checkpoint_id)
+
+     def list_checkpoints(self, model_id: str, user_id: str) -> list[types.Checkpoint]:
+         training_run = self.get_run_record(model_id, user_id)
+         checkpoints = [item.tinker_checkpoint for item in training_run.checkpoints.values()]
+         checkpoints += [
+             item.tinker_checkpoint for item in training_run.sampler_checkpoints.values()
+         ]
+         checkpoints.sort(key=lambda ckpt: ckpt.time)
+         return checkpoints
+
+     def list_user_checkpoints(
+         self,
+         user_id: str,
+     ) -> list[types.Checkpoint]:
+         checkpoints: list[types.Checkpoint] = []
+         training_runs = [run for run in self.training_runs.values() if run.model_owner == user_id]
+         for run in training_runs:
+             checkpoints.extend([item.tinker_checkpoint for item in run.checkpoints.values()])
+         checkpoints.sort(key=lambda item: item.time, reverse=True)
+         return checkpoints
+
+     def set_visibility(
+         self, model_id: str, checkpoint_id: str, user_id: str, *, public: bool
+     ) -> None:
+         training_run = self.get_run_record(model_id=model_id, user_id=user_id)
+         target = training_run.checkpoints.get(checkpoint_id)
+         is_sampler = False
+         if target is None:
+             target = training_run.sampler_checkpoints.get(checkpoint_id)
+             is_sampler = True
+         if target is None:
+             raise CheckpointNotFoundException(checkpoint_id=checkpoint_id)
+         target.set_visibility(public)
+
+         if is_sampler:
+             self._save_sampler_checkpoint(model_id, checkpoint_id)
+         else:
+             self._save_checkpoint(model_id, checkpoint_id)
+
+     def build_archive_url(
+         self,
+         model_id: str,
+         user_id: str,
+         checkpoint_id: str,
+     ) -> types.CheckpointArchiveUrlResponse:
+         training_run = self.get_run_record(model_id, user_id)
+         checkpoint = training_run.checkpoints.get(
+             checkpoint_id
+         ) or training_run.sampler_checkpoints.get(checkpoint_id)
+         if checkpoint is None:
+             raise CheckpointNotFoundException(checkpoint_id=checkpoint_id)
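+         # The archive "URL" is a local file:// URI with a nominal 15-minute expiry.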
+         expires = _now() + timedelta(minutes=15)
+         return types.CheckpointArchiveUrlResponse(url=checkpoint.path.as_uri(), expires=expires)
+
+     def get_weights_info(self, model_id: str, user_id: str) -> types.WeightsInfoResponse:
+         training_run = self.get_run_record(model_id, user_id)
+         return types.WeightsInfoResponse(
+             base_model=training_run.base_model,
+             is_lora=True,
+             lora_rank=training_run.lora_rank,
+         )
+
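+     # Considers both training and sampler checkpoints; restore_from_checkpoint
+     # uses this to pick the most recent state to reload.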
+     def get_latest_checkpoint(self, model_id: str) -> CheckpointRecord | None:
+         record = self.training_runs.get(model_id)
+         if record is None:
+             return None
+         all_checkpoints = list(record.checkpoints.values()) + list(
+             record.sampler_checkpoints.values()
+         )
+         if not all_checkpoints:
+             return None
+         return max(all_checkpoints, key=lambda c: c.created_at)
+
+     async def restore_from_checkpoint(self, model_id: str) -> CheckpointRecord | None:
+         latest_ckpt = self.get_latest_checkpoint(model_id)
+         if latest_ckpt is None:
+             return None
+         record = self.training_runs.get(model_id)
+         if record is None or record.backend is None:
+             return None
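+         # Adapter creation is best-effort: a failure here (for instance, if the
+         # adapter already exists) is logged, and the restore still attempts to
+         # load the checkpoint state below.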
+         try:
+             await record.backend.create_adapter(model_id, types.LoraConfig(rank=record.lora_rank))
+         except Exception:
+             logger.exception("Failed to create adapter for model %s during restore", model_id)
+         try:
+             await record.backend.load_state(
+                 lora_id=model_id,
+                 checkpoint_record=latest_ckpt,
+                 optimizer=(latest_ckpt.checkpoint_type == "training"),
+             )
+         except Exception:  # pylint: disable=broad-except
+             # If loading fails, mark as corrupted
+             record.corrupted = True
+             self._save_training_run(model_id)
+             return None
+
+         return latest_ckpt