tuft-0.1.0-py3-none-any.whl → tuft-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__init__.py +5 -2
- tuft/auth.py +35 -0
- tuft/backend.py +254 -0
- tuft/backends/__init__.py +10 -0
- tuft/backends/base_backend.py +112 -0
- tuft/backends/hf_training_model.py +404 -0
- tuft/backends/sampling_backend.py +253 -0
- tuft/backends/training_backend.py +327 -0
- tuft/checkpoints.py +193 -0
- tuft/cli.py +91 -0
- tuft/config.py +121 -0
- tuft/exceptions.py +138 -0
- tuft/futures.py +431 -0
- tuft/loss_fn/__init__.py +48 -0
- tuft/loss_fn/cispo.py +40 -0
- tuft/loss_fn/cross_entropy.py +26 -0
- tuft/loss_fn/dro.py +37 -0
- tuft/loss_fn/importance_sampling.py +33 -0
- tuft/loss_fn/ppo.py +43 -0
- tuft/persistence/__init__.py +32 -0
- tuft/persistence/file_redis.py +268 -0
- tuft/persistence/redis_store.py +488 -0
- tuft/sampling_controller.py +366 -0
- tuft/server.py +720 -0
- tuft/state.py +352 -0
- tuft/telemetry/__init__.py +17 -0
- tuft/telemetry/metrics.py +335 -0
- tuft/telemetry/provider.py +198 -0
- tuft/telemetry/tracing.py +43 -0
- tuft/training_controller.py +723 -0
- tuft-0.1.1.dist-info/METADATA +633 -0
- tuft-0.1.1.dist-info/RECORD +35 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/WHEEL +1 -2
- tuft-0.1.1.dist-info/entry_points.txt +2 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/licenses/LICENSE +2 -2
- tuft-0.1.0.dist-info/METADATA +0 -77
- tuft-0.1.0.dist-info/RECORD +0 -6
- tuft-0.1.0.dist-info/top_level.txt +0 -1
@@ -0,0 +1,327 @@

import asyncio
import logging
from dataclasses import dataclass, field
from typing import Sequence

import numpy as np
from tinker import types

from tuft.backends.base_backend import BaseTrainingBackend
from tuft.checkpoints import CheckpointRecord
from tuft.config import ModelConfig
from tuft.telemetry.tracing import get_tracer, inject_context


_get_tracer = lambda: get_tracer("tuft.training_backend")  # noqa: E731

logger = logging.getLogger(__name__)


class HFTrainingBackend(BaseTrainingBackend):
    """A training backend using Hugging Face transformers."""

    def __init__(self, config: ModelConfig) -> None:
        from .hf_training_model import HFTrainingModel

        self.config = config
        logger.info("Ray actor created: HFTrainingModel(%s)", config.model_name)
        self.model = HFTrainingModel.get_actor(config)

    async def async_init(self) -> None:
        await self.model.async_init.remote()

    async def create_adapter(self, lora_id: str, lora_config: types.LoraConfig) -> None:
        """Create a LoRA adapter with the given ID and configuration."""
        with _get_tracer().start_as_current_span("training_backend.create_adapter") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            span.set_attribute("tuft.lora_rank", lora_config.rank)
            # Inject trace context for Ray actor
            trace_context: dict[str, str] = {}
            inject_context(trace_context)
            await self.model.create_adapter.remote(lora_id, lora_config, trace_context)

    async def remove_adapter(self, lora_id: str) -> None:
        with _get_tracer().start_as_current_span("training_backend.remove_adapter") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            await self.model.remove_adapter.remote(lora_id)

    async def forward(
        self,
        data: list[types.Datum],
        lora_id: str,
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None,
        backward: bool = False,
    ) -> types.ForwardBackwardOutput:
        """Forward pass (and backward if specified).

        Args:
            data: List of Datum objects containing input data.
            lora_id: The LoRA adapter ID to use.
            loss_fn: The loss function to apply.
            loss_fn_config: Optional configuration for the loss function.
            backward: Whether to perform backward pass.

        Returns:
            ForwardBackwardOutput: The output of the forward (and backward) pass.
        """
        span_name = "training_backend.forward_backward" if backward else "training_backend.forward"
        with _get_tracer().start_as_current_span(span_name) as span:
            span.set_attribute("tuft.lora_id", lora_id)
            span.set_attribute("tuft.backward", backward)
            span.set_attribute("tuft.data_count", len(data))
            # Inject trace context for Ray actor
            trace_context: dict[str, str] = {}
            inject_context(trace_context)
            return await self.model.forward.remote(
                data=data,
                lora_id=lora_id,
                loss_fn=loss_fn,
                loss_fn_config=loss_fn_config,
                backward=backward,
                trace_context=trace_context,
            )

    async def optim_step(
        self, adam_params: types.AdamParams, lora_id: str
    ) -> types.OptimStepResponse:
        """Perform an optimization step using Adam optimizer.

        Args:
            adam_params: Parameters for the Adam optimizer.

        Returns:
            OptimStepResponse: The response containing optimization metrics.
        """
        with _get_tracer().start_as_current_span("training_backend.optim_step") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            # Inject trace context for Ray actor
            trace_context: dict[str, str] = {}
            inject_context(trace_context)
            return await self.model.optim_step.remote(adam_params, lora_id, trace_context)

    async def save_state(
        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
    ) -> None:
        """Save the state of the specified LoRA adapter."""
        with _get_tracer().start_as_current_span("training_backend.save_state") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            span.set_attribute("tuft.optimizer", optimizer)
            # Inject trace context for Ray actor
            trace_context: dict[str, str] = {}
            inject_context(trace_context)
            await self.model.save_state.remote(
                lora_id=lora_id,
                checkpoint_record=checkpoint_record,
                optimizer=optimizer,
                trace_context=trace_context,
            )

    async def load_state(
        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
    ) -> None:
        """Load the state of the specified LoRA adapter from the given path."""
        with _get_tracer().start_as_current_span("training_backend.load_state") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            span.set_attribute("tuft.optimizer", optimizer)
            # Inject trace context for Ray actor
            trace_context: dict[str, str] = {}
            inject_context(trace_context)
            await self.model.load_state.remote(
                lora_id=lora_id,
                checkpoint_record=checkpoint_record,
                optimizer=optimizer,
                trace_context=trace_context,
            )
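Every remote call in HFTrainingBackend follows the same tracing pattern: open a span, record tuft.* attributes, serialize the active trace context into a plain dict with inject_context, and pass that dict along so the Ray actor can resume the trace on its side. The standalone sketch below shows that handoff with the stock OpenTelemetry propagation API; it is not part of this diff, and the actor-side function is illustrative rather than the actual HFTrainingModel code.

# Illustrative only -- not part of this diff. Shows the span + context-dict handoff
# pattern used by HFTrainingBackend, expressed with the standard OpenTelemetry API.
from opentelemetry import propagate, trace

tracer = trace.get_tracer("example.training_backend")


def caller_side() -> dict[str, str]:
    with tracer.start_as_current_span("training_backend.create_adapter") as span:
        span.set_attribute("tuft.lora_id", "demo-lora")
        carrier: dict[str, str] = {}
        propagate.inject(carrier)  # same role as inject_context(trace_context) above
        return carrier  # shipped to the remote actor as a plain, serializable dict


def actor_side(carrier: dict[str, str]) -> None:
    ctx = propagate.extract(carrier)  # rebuild the caller's context in the actor
    with tracer.start_as_current_span("hf_training_model.create_adapter", context=ctx):
        pass  # the real actor would create the LoRA adapter here


actor_side(caller_side())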

@dataclass
class DummyTrainingBackend(BaseTrainingBackend):
    """A dummy training backend for testing purposes."""

    _lock: asyncio.Lock = field(init=False, repr=False)
    _weights: np.ndarray = field(init=False, repr=False)
    _adam_m: np.ndarray = field(init=False, repr=False)
    _adam_v: np.ndarray = field(init=False, repr=False)
    _beta1_power: float = field(init=False, default=1.0, repr=False)
    _beta2_power: float = field(init=False, default=1.0, repr=False)
    _pending_grad: np.ndarray | None = field(init=False, default=None, repr=False)
    _pending_examples: int = field(init=False, default=0, repr=False)
    _embedding_cache: dict[int, np.ndarray] = field(init=False, default_factory=dict, repr=False)
    step: int = field(init=False, default=0)

    def __init__(self, config: ModelConfig) -> None:
        self.config = config
        self.seed = config.seed
        self.hidden_dim = 16
        rng = np.random.default_rng(self.seed or 0)
        self._lock = asyncio.Lock()
        self._weights = rng.standard_normal(self.hidden_dim, dtype=np.float32)
        self._adam_m = np.zeros_like(self._weights)
        self._adam_v = np.zeros_like(self._weights)
        self._embedding_cache = {}
        self._adapters = dict()

    async def async_init(self) -> None:
        pass

    # ------------------------------------------------------------------
    # Forward / backward helpers
    # ------------------------------------------------------------------
    async def forward(
        self,
        data: list[types.Datum],
        lora_id: str,
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None,
        backward: bool = False,
    ) -> types.ForwardBackwardOutput:
        return await self._run_step(data, backward=backward)

    async def _run_step(
        self, data: list[types.Datum], *, backward: bool
    ) -> types.ForwardBackwardOutput:
        outputs: list[types.LossFnOutput] = []
        total_loss = 0.0
        grad_accum = np.zeros_like(self._weights)
        for datum in data:
            prompt_tokens = datum.model_input.to_ints()
            target_tokens = self._target_tokens(datum)
            prompt_vec = self._vectorize(prompt_tokens)
            target_scalar = self._target_scalar(target_tokens)
            prediction = float(np.dot(self._weights, prompt_vec))
            loss = (prediction - target_scalar) ** 2
            total_loss += loss
            if backward:
                grad = 2 * (prediction - target_scalar) * prompt_vec
                grad_accum += grad
            logprob_tensor = types.TensorData(
                data=[float(-abs(prediction - target_scalar))] * max(len(target_tokens), 1),
                dtype="float32",
                shape=[max(len(target_tokens), 1)],
            )
            outputs.append({"logprobs": logprob_tensor})

        metrics = {
            "loss:sum": total_loss,
            "step:max": float(self.step),
        }
        if backward:
            grad_norm = float(np.linalg.norm(grad_accum) / max(len(data), 1))
            metrics["grad_norm:mean"] = grad_norm
            async with self._lock:
                if self._pending_grad is None:
                    self._pending_grad = grad_accum
                else:
                    self._pending_grad += grad_accum
                self._pending_examples += len(data)

        return types.ForwardBackwardOutput(
            loss_fn_output_type="ToyLoss",
            loss_fn_outputs=outputs,
            metrics=metrics,
        )

    # ------------------------------------------------------------------
    # Optimizer
    # ------------------------------------------------------------------
    async def optim_step(
        self, adam_params: types.AdamParams, lora_id: str
    ) -> types.OptimStepResponse:
        async with self._lock:
            grad = self._pending_grad
            examples = self._pending_examples
            self._pending_grad = None
            self._pending_examples = 0

        if grad is None or not np.any(grad):
            return types.OptimStepResponse(
                metrics={
                    "learning_rate:mean": adam_params.learning_rate,
                    "step:max": float(self.step),
                }
            )

        grad = grad / max(examples, 1)
        if adam_params.grad_clip_norm > 0:
            norm = np.linalg.norm(grad)
            if norm > adam_params.grad_clip_norm:
                grad *= adam_params.grad_clip_norm / max(norm, 1e-12)

        if adam_params.weight_decay:
            grad += adam_params.weight_decay * self._weights

        beta1 = adam_params.beta1
        beta2 = adam_params.beta2

        self._adam_m = beta1 * self._adam_m + (1 - beta1) * grad
        self._adam_v = beta2 * self._adam_v + (1 - beta2) * (grad**2)
        self._beta1_power *= beta1
        self._beta2_power *= beta2
        m_hat = self._adam_m / (1 - self._beta1_power + 1e-12)
        v_hat = self._adam_v / (1 - self._beta2_power + 1e-12)

        update = adam_params.learning_rate * m_hat / (np.sqrt(v_hat) + adam_params.eps)
        self._weights -= update
        self.step += 1

        metrics = {
            "learning_rate:mean": adam_params.learning_rate,
            "step:max": float(self.step),
            "update_norm:mean": float(np.linalg.norm(update)),
        }
        return types.OptimStepResponse(metrics=metrics)

    async def save_state(
        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
    ) -> None:
        if lora_id not in self._adapters:
            raise ValueError(f"Adapter {lora_id} does not exist.")
        # dummy save

    async def load_state(
        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
    ) -> None:
        # create a dummy adapter on load
        self._adapters[lora_id] = types.LoraConfig(rank=4)

    async def create_adapter(self, lora_id: str, lora_config: types.LoraConfig) -> None:
        self._adapters[lora_id] = lora_config

    async def remove_adapter(self, lora_id: str) -> None:
        self._adapters.pop(lora_id, None)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _target_tokens(self, datum: types.Datum) -> list[int]:
        if not datum.loss_fn_inputs:
            return datum.model_input.to_ints()
        tensor = datum.loss_fn_inputs.get("target_tokens")
        if tensor is None:
            return datum.model_input.to_ints()
        return [int(value) for value in tensor.data]

    def _vectorize(self, tokens: Sequence[int]) -> np.ndarray:
        if not tokens:
            return np.zeros(self.hidden_dim, dtype=np.float32)
        vecs = [self._token_embedding(token) for token in tokens]
        return np.mean(vecs, axis=0)

    def _token_embedding(self, token_id: int) -> np.ndarray:
        cached = self._embedding_cache.get(token_id)
        if cached is None:
            rng = np.random.default_rng(self.seed + token_id)
            cached = rng.standard_normal(self.hidden_dim, dtype=np.float32)
            self._embedding_cache[token_id] = cached
        return cached

    def _safe_mean(self, values: Sequence[int]) -> float:
        if not values:
            return 0.0
        return float(sum(values) / len(values))

    def _target_scalar(self, tokens: Sequence[int]) -> float:
        if not tokens:
            return 0.0
        return np.tanh(self._safe_mean(tokens) / 100.0)
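DummyTrainingBackend keeps the optimizer math explicit: optim_step averages the accumulated gradient over the pending examples, clips it by global norm, folds weight decay into the gradient, updates the first and second moments, bias-corrects them with the running beta powers, and applies the scaled step. The same arithmetic in a standalone form (not part of this diff; the function name and defaults are illustrative, not part of the tuft API):

# Illustrative only -- a single Adam step mirroring DummyTrainingBackend.optim_step.
import numpy as np


def adam_step(w, grad, m, v, step, *, lr=1e-2, beta1=0.9, beta2=0.95,
              eps=1e-8, grad_clip_norm=0.0, weight_decay=0.0):
    """Return (w, m, v) after one update; `step` is the 1-based step count."""
    if grad_clip_norm > 0:
        norm = np.linalg.norm(grad)
        if norm > grad_clip_norm:
            grad = grad * (grad_clip_norm / max(norm, 1e-12))
    if weight_decay:
        grad = grad + weight_decay * w  # decay folded into the gradient
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    m_hat = m / (1 - beta1**step + 1e-12)  # beta1**step matches the running _beta1_power
    v_hat = v / (1 - beta2**step + 1e-12)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v


w = np.zeros(16, dtype=np.float32)
m, v = np.zeros_like(w), np.zeros_like(w)
w, m, v = adam_step(w, np.ones_like(w), m, v, step=1, lr=0.1, grad_clip_norm=1.0)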
tuft/checkpoints.py
ADDED
@@ -0,0 +1,193 @@

"""Module for managing checkpoints on disk."""

import contextlib
import shutil
from datetime import datetime, timezone
from pathlib import Path

from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_serializer
from tinker import types

from .exceptions import CheckpointMetadataReadException


class CheckpointMetadata(BaseModel):
    """A representation of checkpoint metadata."""

    model_id: str
    name: str
    base_model: str
    checkpoint_type: types.CheckpointType
    created_at: str
    session_id: str
    tinker_path: str
    owner_name: str
    size_bytes: int = 0
    lora_rank: int | None = None
    public: bool = False
    future_id: int = 0
    seq_id: int | None = None


class CheckpointRecord(BaseModel):
    """A record representing a checkpoint on disk."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    checkpoint_id: str
    owner_name: str
    checkpoint_type: types.CheckpointType
    training_run_id: str
    path: Path
    size_bytes: int = 0
    public: bool = False
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    future_id: int = 0
    seq_id: int | None = None

    @field_serializer("path")
    def serialize_path(self, path: Path) -> str:
        """Serialize Path to string for JSON."""
        return str(path)

    @property
    def tinker_checkpoint(self) -> types.Checkpoint:
        """Get a Tinker Checkpoint instance representing this record."""
        return types.Checkpoint(
            checkpoint_id=self.checkpoint_id,
            checkpoint_type=self.checkpoint_type,
            time=self.created_at,
            tinker_path=self.tinker_path,
            size_bytes=self.size_bytes,
            public=self.public,
        )

    @property
    def metadata(self) -> CheckpointMetadata:
        """Get the checkpoint metadata.

        Raises:
            CheckpointMetadataReadException: If the metadata file does not
                exist or is invalid.
        """
        try:
            return CheckpointMetadata.model_validate_json(
                self.metadata_path.read_text(encoding="utf-8")
            )
        except FileNotFoundError as exc:
            raise CheckpointMetadataReadException(checkpoint_id=self.checkpoint_id) from exc
        except ValidationError as exc:
            raise CheckpointMetadataReadException(checkpoint_id=self.checkpoint_id) from exc

    @property
    def tinker_path(self) -> str:
        """Get the tinker style path for this checkpoint."""
        folder = "weights" if self.checkpoint_type == "training" else "sampler_weights"
        return f"tinker://{self.training_run_id}/{folder}/{self.checkpoint_id}"

    @property
    def adapter_path(self) -> Path:
        """Get the path to the adapter weights file."""
        return self.path / "adapter"

    @property
    def optimizer_path(self) -> Path:
        """Get the path to the optimizer state file."""
        return self.path / "optimizer"

    @property
    def metadata_path(self) -> Path:
        """Get the path to the metadata JSON file."""
        return self.path / "metadata.json"

    def set_visibility(self, public: bool) -> None:
        """Set the visibility of the checkpoint."""
        self.public = public
        metadata = self.metadata
        metadata.public = public
        self.save_metadata(
            base_model=metadata.base_model,
            session_id=metadata.session_id,
            lora_rank=metadata.lora_rank,
        )

    def save_metadata(self, base_model: str, session_id: str, lora_rank: int | None) -> None:
        """Save the checkpoint metadata to disk."""
        # check the format of metadata
        try:
            metadata = CheckpointMetadata(
                model_id=self.training_run_id,
                name=self.checkpoint_id,
                base_model=base_model,
                checkpoint_type=self.checkpoint_type,
                created_at=self.created_at.isoformat(),
                session_id=session_id,
                tinker_path=self.tinker_path,
                owner_name=self.owner_name,
                lora_rank=lora_rank,
                public=self.public,
                size_bytes=self.size_bytes,
                future_id=self.future_id,
                seq_id=self.seq_id,
            )
        except Exception as e:
            raise ValueError(f"Invalid checkpoint metadata: {e}") from e
        self.metadata_path.write_text(metadata.model_dump_json(indent=2), encoding="utf-8")

    @classmethod
    def from_tinker_path(cls, path: str, checkpoint_root_dir: Path) -> "CheckpointRecord":
        """Create a CheckpointRecord from a Tinker path.

        Raises:
            FileNotFoundError: If the checkpoint directory or metadata.json is missing.
            json.JSONDecodeError: If metadata.json cannot be parsed as JSON.
        """
        parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(path)
        checkpoint_path = (
            checkpoint_root_dir / parsed.training_run_id / parsed.checkpoint_id.split("/", 1)[-1]
        )
        record = cls(
            checkpoint_id=parsed.checkpoint_id.split("/", 1)[-1],
            checkpoint_type=parsed.checkpoint_type,
            training_run_id=parsed.training_run_id,
            path=checkpoint_path,
            owner_name="",  # Will be filled from metadata later
            size_bytes=0,  # Will be filled from metadata later
        )
        metadata = record.metadata  # This may raise FileNotFoundError or JSONDecodeError
        record.owner_name = metadata.owner_name
        record.size_bytes = metadata.size_bytes
        record.public = metadata.public
        record.created_at = datetime.fromisoformat(metadata.created_at)
        record.future_id = metadata.future_id
        record.seq_id = metadata.seq_id
        return record

    def delete(self) -> None:
        """Delete the checkpoint from disk."""
        with contextlib.suppress(FileNotFoundError):
            shutil.rmtree(self.path)

    @classmethod
    def from_training_run(
        cls,
        training_run_id: str,
        checkpoint_name: str,
        owner_name: str,
        checkpoint_type: types.CheckpointType,
        checkpoint_root_dir: Path,
        exist_ok: bool = True,
    ) -> "CheckpointRecord":
        """Create a CheckpointRecord from a training run."""
        checkpoint_dir = checkpoint_root_dir / training_run_id / checkpoint_name
        if not exist_ok and checkpoint_dir.exists():
            raise FileExistsError(f"Checkpoint directory already exists: {checkpoint_dir}")
        checkpoint_dir.mkdir(parents=True, exist_ok=exist_ok)
        return cls(
            checkpoint_id=checkpoint_name,
            owner_name=owner_name,
            checkpoint_type=checkpoint_type,
            training_run_id=training_run_id,
            path=checkpoint_dir,
            size_bytes=0,
        )
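CheckpointRecord ties the on-disk layout (checkpoint_root_dir/training_run_id/checkpoint_id/ containing adapter, optimizer, and metadata.json) to the external tinker://training_run_id/{weights|sampler_weights}/checkpoint_id addressing scheme. A usage sketch follows; it is not part of this diff, and the run/checkpoint IDs, owner, base model name, and the literal "training" checkpoint type are example values.

# Illustrative only -- creating a record, writing metadata.json, and reading back paths.
from pathlib import Path

from tuft.checkpoints import CheckpointRecord

record = CheckpointRecord.from_training_run(
    training_run_id="run-123",
    checkpoint_name="ckpt-000010",
    owner_name="alice",
    checkpoint_type="training",  # assumed to be a valid types.CheckpointType value
    checkpoint_root_dir=Path("~/.cache/tuft/checkpoints").expanduser(),
)
record.save_metadata(
    base_model="example-org/example-base-model",  # example value
    session_id="sess-1",
    lora_rank=8,
)

print(record.tinker_path)          # tinker://run-123/weights/ckpt-000010
print(record.adapter_path)         # .../run-123/ckpt-000010/adapter
print(record.metadata_path)        # .../run-123/ckpt-000010/metadata.json
print(record.metadata.owner_name)  # "alice", read back from metadata.json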
tuft/cli.py
ADDED
@@ -0,0 +1,91 @@

"""Command line utilities for the local TuFT server."""

from __future__ import annotations

import logging
from pathlib import Path

import typer
import uvicorn

from .config import AppConfig, load_yaml_config
from .server import create_root_app
from .telemetry import init_telemetry
from .telemetry.metrics import ResourceMetricsCollector


app = typer.Typer(help="TuFT - Tenant-unified Fine-Tuning Server.")

_HOST_OPTION = typer.Option("127.0.0.1", "--host", help="Interface to bind", envvar="TUFT_HOST")
_PORT_OPTION = typer.Option(10610, "--port", "-p", help="Port to bind", envvar="TUFT_PORT")
_LOG_LEVEL_OPTION = typer.Option("info", "--log-level", help="Uvicorn log level")
_RELOAD_OPTION = typer.Option(False, "--reload", help="Enable auto-reload (development only)")
_CONFIG_OPTION = typer.Option(
    None,
    "--config",
    "-c",
    help="Path to a TuFT configuration file (YAML)",
)
_CHECKPOINT_DIR_OPTION = typer.Option(
    None,
    "--checkpoint-dir",
    help="Override checkpoint_dir from config file. Defaults to ~/.cache/tuft/checkpoints.",
)


def _build_config(
    config_path: Path | None,
    checkpoint_dir: Path | None,
) -> AppConfig:
    if config_path is None:
        raise typer.BadParameter("Configuration file must be provided via --config")
    config = load_yaml_config(config_path)
    if checkpoint_dir is not None:
        config.checkpoint_dir = checkpoint_dir.expanduser()
    config.ensure_directories()
    return config


def _init_telemetry(config: AppConfig, log_level: str) -> None:
    """Initialize OpenTelemetry if enabled."""
    # Configure root logger level to ensure logs flow to OTel
    numeric_level = getattr(logging, log_level.upper(), logging.INFO)

    if not config.telemetry.enabled:
        logging.basicConfig(level=numeric_level)
        return

    init_telemetry(config.telemetry)
    # Start resource metrics collection
    ResourceMetricsCollector.start(str(config.checkpoint_dir))


@app.command()
def launch(
    host: str = _HOST_OPTION,
    port: int = _PORT_OPTION,
    log_level: str = _LOG_LEVEL_OPTION,
    reload: bool = _RELOAD_OPTION,
    config_path: Path | None = _CONFIG_OPTION,
    checkpoint_dir: Path | None = _CHECKPOINT_DIR_OPTION,
) -> None:
    """Launch the TuFT server."""
    app_config = _build_config(config_path, checkpoint_dir)
    # Initialize telemetry before starting the server
    _init_telemetry(app_config, log_level)
    logging.getLogger("tuft").info("Server starting on %s:%s", host, port)
    uvicorn.run(
        create_root_app(app_config),
        host=host,
        port=port,
        log_level=log_level,
        reload=reload,
    )


def main() -> None:
    app()


if __name__ == "__main__":
    main()
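The launch command is a thin wrapper: it resolves the YAML config, optionally overrides checkpoint_dir, sets up telemetry, and hands the ASGI app from create_root_app to uvicorn. A programmatic sketch of the same flow (not part of this diff), assuming a valid config file at ./tuft.yaml whose schema lives in tuft/config.py and is not shown in this hunk, and skipping the telemetry setup:

# Illustrative only -- the programmatic equivalent of `launch --config ./tuft.yaml`.
from pathlib import Path

import uvicorn

from tuft.config import load_yaml_config
from tuft.server import create_root_app

config = load_yaml_config(Path("./tuft.yaml"))  # assumed to exist; schema is in tuft/config.py
config.ensure_directories()
uvicorn.run(create_root_app(config), host="127.0.0.1", port=10610, log_level="info")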