tuft 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__main__.py +7 -0
- tuft/backends/hf_training_model.py +184 -64
- tuft/cli.py +161 -8
- tuft/config.py +63 -59
- tuft/exceptions.py +66 -0
- tuft/futures.py +22 -2
- tuft/loss_fn/__init__.py +33 -0
- tuft/persistence/__init__.py +10 -2
- tuft/persistence/redis_store.py +352 -31
- tuft/sampling_controller.py +37 -11
- tuft/sequence_executor.py +72 -0
- tuft/server.py +9 -2
- tuft/state.py +3 -0
- tuft/training_controller.py +20 -5
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/METADATA +10 -66
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/RECORD +19 -17
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/WHEEL +0 -0
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/entry_points.txt +0 -0
- {tuft-0.1.1.dist-info → tuft-0.1.3.dist-info}/licenses/LICENSE +0 -0
tuft/__main__.py
ADDED

tuft/backends/hf_training_model.py
CHANGED

@@ -1,6 +1,7 @@
 import asyncio
+import logging
 import shutil
-from typing import Dict
+from typing import Callable, Dict
 
 import ray
 import torch
@@ -14,13 +15,12 @@ from transformers import AutoModelForCausalLM
 
 from tuft.checkpoints import CheckpointRecord
 from tuft.config import ModelConfig
-from tuft.loss_fn import get_loss_fn
+from tuft.loss_fn import get_loss_fn, metrics_reduction
 from tuft.telemetry.tracing import extract_context, get_tracer
 
 
 _get_tracer = lambda: get_tracer("tuft.hf_training_model")  # noqa: E731
 
-
 MODULE_MAP = {
     "llama": {
         "attn": ["q_proj", "k_proj", "v_proj", "o_proj"],
@@ -58,6 +58,8 @@ class HFTrainingModel:
         self.model = self._init_peft_model(config)
         self.adapter_optimizer: Dict[str, torch.optim.AdamW] = {}
         self._lock = asyncio.Lock()
+        self.logger = logging.getLogger()
+        self.micro_batch_size = config.micro_batch_size
 
     async def async_init(self) -> None:
         """Do nothing for now. Just used to make sure the actor is ready."""
@@ -193,7 +195,9 @@ class HFTrainingModel:
         async with self._lock:
             if lora_id in self.adapter_optimizer:
                 self.model.delete_adapter(lora_id)
-                self.adapter_optimizer.pop(lora_id)
+                optimizer = self.adapter_optimizer.pop(lora_id)
+                del optimizer
+                torch.cuda.empty_cache()
 
     # --------------------------------
     # Training methods
@@ -207,7 +211,7 @@ class HFTrainingModel:
         backward: bool = False,
         trace_context: dict[str, str] | None = None,
     ) -> types.ForwardBackwardOutput:
-        """Forward pass
+        """Forward pass with micro-batch gradient accumulation.
 
         Args:
             data: List of Datum objects containing input data.
@@ -222,73 +226,163 @@ class HFTrainingModel:
         """
         ctx = extract_context(trace_context or {})
         span_name = "hf_model.forward_backward" if backward else "hf_model.forward"
+
         with _get_tracer().start_as_current_span(span_name, context=ctx) as span:
             span.set_attribute("tuft.lora_id", lora_id)
             span.set_attribute("tuft.backward", backward)
             span.set_attribute("tuft.data_count", len(data))
-            # Prepare input tensors
-            input_ids = [
-                torch.tensor(datum.model_input.to_ints(), dtype=torch.long) for datum in data
-            ]
-            input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
-            attention_mask = (input_ids_padded != 0).long()
-            position_ids = (
-                torch.arange(input_ids_padded.size(1), dtype=torch.long)
-                .unsqueeze(0)
-                .expand(input_ids_padded.size(0), -1)
-            )
-            # Move tensors to model device
-            device = next(self.model.parameters()).device
-            input_ids_padded = input_ids_padded.to(device)
-            attention_mask = attention_mask.to(device)
-            position_ids = position_ids.to(device)
-
-            # Activate the correct adapter
-            async with self._lock:
-                self._activate_adapter(lora_id)
-
-            # Forward pass
-            outputs = self.model(
-                input_ids=input_ids_padded,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                return_dict=True,
-            )
 
-
-
-            loss_fn_config = {}
-            loss_fn_callable = get_loss_fn(loss_fn)
-            logits = outputs.logits
-            if "temperature" in loss_fn_config:
-                temperature = loss_fn_config["temperature"]
-                logits.div_(temperature)
+            batch_size = len(data)
+            micro_batch_size = self.config.micro_batch_size
 
-
+            num_micro_batches = (batch_size + micro_batch_size - 1) // micro_batch_size
+            span.set_attribute("tuft.num_micro_batches", num_micro_batches)
 
-
-
-
-
+            if num_micro_batches > 1:
+                self.logger.debug(
+                    f"[MICRO_BATCH] Splitting batch_size={batch_size} into "
+                    f"{num_micro_batches} micro-batches of size {micro_batch_size}"
+                )
 
-
+            loss_fn_callable = get_loss_fn(loss_fn)
+            all_loss_fn_outputs = []
+            micro_batch_weights = []
+            metric_list = []
+            total_loss = 0.0
 
-
-
-            loss.backward()
+            async with self._lock:
+                self._activate_adapter(lora_id)
 
-
-
+            for micro_idx in range(num_micro_batches):
+                start_idx = micro_idx * micro_batch_size
+                end_idx = min(start_idx + micro_batch_size, batch_size)
+                micro_data = data[start_idx:end_idx]
+
+                torch.cuda.reset_peak_memory_stats()
+                self.logger.debug(
+                    f"[GPU-micro_batch_{micro_idx}] before_forward: "
+                    f"allocated={torch.cuda.memory_allocated() / 1e9:.2f}GB, "
+                    f"reserved={torch.cuda.memory_reserved() / 1e9:.2f}GB"
+                )
+
+                micro_loss, micro_metrics, micro_outputs = await self._forward_micro_batch(
+                    micro_data,
+                    loss_fn_callable,
+                    loss_fn_config,
+                    backward=backward,
+                )
+
+                total_loss += micro_loss
+                all_loss_fn_outputs.extend(micro_outputs)
+                micro_batch_weights.append(len(micro_outputs))
+
+                metric_list.append(micro_metrics)
+
+                self.logger.debug(
+                    f"[GPU-micro_batch_{micro_idx}] after_forward: "
+                    f"allocated={torch.cuda.memory_allocated() / 1e9:.2f}GB, "
+                    f"reserved={torch.cuda.memory_reserved() / 1e9:.2f}GB, "
+                    f"max_allocated={torch.cuda.max_memory_allocated() / 1e9:.2f}GB"
+                )
+
+                torch.cuda.empty_cache()
+
+            avg_loss = total_loss / num_micro_batches
+            self.logger.debug(f"Average loss: {avg_loss}")
+            metric_list = metrics_reduction(metric_list, micro_batch_weights)
+
+            self.logger.debug(
+                f"[GPU-after_micro_batches] allocated={torch.cuda.memory_allocated() / 1e9:.2f}GB"
+                f", reserved={torch.cuda.memory_reserved() / 1e9:.2f}GB"
             )
+
             return types.ForwardBackwardOutput(
                 loss_fn_output_type=loss_fn,
-                loss_fn_outputs=
-
-                    for logprobs in unpaded_logprobs
-                ],
-                metrics=metric,
+                loss_fn_outputs=all_loss_fn_outputs,
+                metrics=metric_list or {},
             )
 
+    async def _forward_micro_batch(
+        self,
+        data: list[types.Datum],
+        loss_fn_callable: Callable,
+        loss_fn_config: dict[str, float] | None,
+        backward: bool,
+    ) -> tuple[float, dict[str, float], list[dict]]:
+        """Process a single micro-batch.
+
+        Returns:
+            tuple: (loss_value, metrics_dict, loss_fn_outputs_list)
+        """
+        # Prepare input tensors
+        input_ids = [torch.tensor(datum.model_input.to_ints(), dtype=torch.long) for datum in data]
+        input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
+        attention_mask = (input_ids_padded != 0).long()
+        position_ids = (
+            torch.arange(input_ids_padded.size(1), dtype=torch.long)
+            .unsqueeze(0)
+            .expand(input_ids_padded.size(0), -1)
+        )
+
+        device = next(self.model.parameters()).device
+        input_ids_padded = input_ids_padded.to(device)
+        attention_mask = attention_mask.to(device)
+        position_ids = position_ids.to(device)
+
+        # Forward pass
+        outputs = self.model(
+            input_ids=input_ids_padded,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            return_dict=True,
+        )
+
+        if loss_fn_config is None:
+            loss_fn_config = {}
+
+        logits = outputs.logits
+        del outputs
+        torch.cuda.empty_cache()
+
+        if "temperature" in loss_fn_config:
+            temperature = loss_fn_config["temperature"]
+            logits = logits / temperature
+
+        loss_fn_inputs = self._prepare_loss_fn_inputs(data)
+        target_tokens = loss_fn_inputs["target_tokens"]
+
+        target_logprobs = self._compute_logprobs_from_target_tokens(logits, target_tokens)
+        del logits
+        torch.cuda.empty_cache()
+
+        loss_fn_inputs["target_logprobs"] = target_logprobs
+        loss, metric = loss_fn_callable(loss_fn_inputs, loss_fn_config)
+
+        # Backward with gradient accumulation
+        if backward:
+            loss.backward(retain_graph=False)
+            torch.cuda.empty_cache()
+
+        unpaded_logprobs = self._unpad_tensor(
+            target_logprobs.detach(),
+            [len(datum.model_input.to_ints()) for datum in data],
+        )
+        loss_fn_outputs = [
+            {"logprobs": types.TensorData.from_torch(logprobs.cpu().clone())}
+            for logprobs in unpaded_logprobs
+        ]
+
+        loss_value = loss.detach().item()
+
+        del target_logprobs
+        del unpaded_logprobs
+        del loss_fn_inputs
+        del loss
+
+        torch.cuda.empty_cache()
+
+        return loss_value, metric, loss_fn_outputs
+
     async def optim_step(
         self,
         adam_params: types.AdamParams,
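
The rewritten forward path splits each request into ceil(batch_size / micro_batch_size) micro-batches, accumulates gradients across them when `backward` is set, averages the per-micro-batch losses, and reduces metrics with per-micro-batch weights via `metrics_reduction`. Below is a minimal, self-contained sketch of that bookkeeping; `weighted_metrics_reduction` is an assumed stand-in for `tuft.loss_fn.metrics_reduction`, whose actual implementation is not shown in this diff.

# Standalone sketch of the micro-batch bookkeeping used above.
# `weighted_metrics_reduction` is a hypothetical stand-in for
# tuft.loss_fn.metrics_reduction (not shown in this diff).
def split_into_micro_batches(data: list, micro_batch_size: int) -> list[list]:
    """Split `data` into ceil(len(data) / micro_batch_size) chunks."""
    num_micro_batches = (len(data) + micro_batch_size - 1) // micro_batch_size
    return [
        data[i * micro_batch_size : (i + 1) * micro_batch_size]
        for i in range(num_micro_batches)
    ]

def weighted_metrics_reduction(
    metric_list: list[dict[str, float]], weights: list[int]
) -> dict[str, float]:
    """Weight each micro-batch metric by its number of examples (assumed behaviour)."""
    total = sum(weights)
    keys = metric_list[0].keys()
    return {
        k: sum(m[k] * w for m, w in zip(metric_list, weights)) / total for k in keys
    }

if __name__ == "__main__":
    batches = split_into_micro_batches(list(range(10)), micro_batch_size=4)
    print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    print(weighted_metrics_reduction(
        [{"loss": 1.0}, {"loss": 2.0}, {"loss": 4.0}],
        [len(b) for b in batches],
    ))  # {'loss': 2.0}
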
@@ -316,7 +410,9 @@ class HFTrainingModel:
             param_group["weight_decay"] = adam_params.weight_decay
         optimizer.step()
         optimizer.zero_grad()
-
+
+        torch.cuda.empty_cache()
+        return types.OptimStepResponse()
 
     # --------------------------------
     # Helper methods
@@ -362,11 +458,33 @@ class HFTrainingModel:
     def _compute_logprobs_from_target_tokens(
         self, logits: torch.Tensor, target_tokens: torch.Tensor
     ) -> torch.Tensor:
-
-
-
-
-
+        """Compute log probabilities of target tokens from logits with low memory usage.
+        https://github.com/OpenRLHF/OpenRLHF/pull/718
+        """
+        if logits.dtype in [torch.float32, torch.float64]:
+            logits_labels = torch.gather(logits, dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(
+                -1
+            )
+            logsumexp_values = torch.stack(
+                [
+                    torch.logsumexp(logit, dim=-1) for logit in logits
+                ]  # loop to reduce peak mem consumption
+            )
+            log_probs_labels = (
+                logits_labels - logsumexp_values
+            )  # log_softmax(x_i) = x_i - logsumexp(x)
+        else:
+            log_probs_labels = []
+            for row_logits, row_labels in zip(
+                logits, target_tokens, strict=True
+            ):  # loop to reduce peak mem consumption
+                row_log_probs = torch.nn.functional.log_softmax(row_logits, dim=-1)
+                row_log_probs_labels = row_log_probs.gather(
+                    dim=-1, index=row_labels.unsqueeze(-1)
+                ).squeeze(-1)
+                log_probs_labels.append(row_log_probs_labels)
+            log_probs_labels = torch.stack(log_probs_labels)
+        return log_probs_labels
 
     def _unpad_tensor(
         self, padded_tensor: torch.Tensor, original_lengths: list[int]
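
The low-memory helper added above avoids materialising a full log-softmax over the vocabulary by gathering the target-token logits first and subtracting `logsumexp`, using the identity log_softmax(x)_i = x_i - logsumexp(x) (see the linked OpenRLHF PR). A small sanity-check sketch of that identity follows; the tensor shapes are illustrative, not taken from TuFT.

import torch

# Sanity check of the identity used above: log_softmax(x)[i] == x[i] - logsumexp(x).
# Shapes are illustrative only (batch of 2 sequences, length 5, vocab 11).
logits = torch.randn(2, 5, 11)
targets = torch.randint(0, 11, (2, 5))

# Full log-softmax, then gather the target-token column.
reference = torch.nn.functional.log_softmax(logits, dim=-1).gather(
    dim=-1, index=targets.unsqueeze(-1)
).squeeze(-1)

# Gather first, then subtract logsumexp -- avoids materialising the full log-softmax.
gathered = torch.gather(logits, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
low_memory = gathered - torch.logsumexp(logits, dim=-1)

assert torch.allclose(reference, low_memory, atol=1e-6)
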
@@ -383,6 +501,8 @@ class HFTrainingModel:
             dtype="auto",
             device_map="auto",
         )
+        model.enable_input_require_grads()
+        model.gradient_checkpointing_enable({"use_reentrant": False})
         peft_config = LoraConfig()
         peft_model = get_peft_model(model, peft_config=peft_config, adapter_name="default")
         return peft_model
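
The two lines added to `_init_peft_model` turn on non-reentrant gradient checkpointing and call `enable_input_require_grads()`, the usual companion step when the base weights are frozen under a LoRA adapter, so checkpointed segments still have an input that requires grad to recompute through. A hedged sketch of the same pattern outside TuFT; the model id is a placeholder.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder model id; any causal LM follows the same pattern.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

# Make the (frozen) embedding outputs require grad so checkpointed segments
# keep a gradient path, then enable non-reentrant activation checkpointing.
model.enable_input_require_grads()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

peft_model = get_peft_model(model, peft_config=LoraConfig(), adapter_name="default")
peft_model.print_trainable_parameters()
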
@@ -398,7 +518,7 @@ class HFTrainingModel:
             ray.remote(cls)
             .options(
                 name="training_model_" + config.model_name,
-                num_gpus=1 if not config.colocate else 1 - config.sampling_memory_fraction,
+                num_gpus=(1 if not config.colocate else 1 - config.sampling_memory_fraction),
             )
             .remote(config)
         )
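
The `num_gpus` change above keeps the same fractional-GPU request for the training actor when sampling and training are colocated, reserving `1 - sampling_memory_fraction` of the device and leaving the rest for the sampler. A hedged sketch of that Ray pattern follows; the actor, the fraction value, and the `ray.init(num_gpus=1)` logical-GPU declaration are illustrative, not TuFT code.

import ray

# Declare one logical GPU so the example runs even without physical GPUs.
ray.init(num_gpus=1)

@ray.remote
class TrainingActor:
    def ping(self) -> str:
        return "ok"

sampling_memory_fraction = 0.4  # illustrative value
colocate = True

# When colocated, leave `sampling_memory_fraction` of the GPU for the sampler.
actor = TrainingActor.options(
    num_gpus=(1 if not colocate else 1 - sampling_memory_fraction)
).remote()
print(ray.get(actor.ping.remote()))  # "ok"
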
tuft/cli.py
CHANGED
@@ -3,49 +3,198 @@
 from __future__ import annotations
 
 import logging
+import os
 from pathlib import Path
 
 import typer
 import uvicorn
 
 from .config import AppConfig, load_yaml_config
+from .exceptions import ConfigMismatchError
+from .persistence import (
+    flush_all_data,
+    get_current_namespace,
+    get_redis_store,
+    validate_config_signature,
+)
 from .server import create_root_app
 from .telemetry import init_telemetry
 from .telemetry.metrics import ResourceMetricsCollector
 
 
-app = typer.Typer(help="TuFT - Tenant-unified Fine-Tuning Server.")
+app = typer.Typer(help="TuFT - Tenant-unified Fine-Tuning Server.", no_args_is_help=True)
+clear_app = typer.Typer(help="Clear data commands.", no_args_is_help=True)
+app.add_typer(clear_app, name="clear")
+
+
+# Required for Typer to recognize subcommands when using no_args_is_help=True
+@app.callback()
+def callback() -> None:
+    """TuFT - Tenant-unified Fine-Tuning Server."""
+
+
+# Default paths based on TUFT_HOME
+_TUFT_HOME = Path(os.environ.get("TUFT_HOME", Path.home() / ".tuft"))
+_DEFAULT_CONFIG_PATH = _TUFT_HOME / "configs" / "tuft_config.yaml"
+_DEFAULT_CHECKPOINT_DIR = _TUFT_HOME / "checkpoints"
 
 _HOST_OPTION = typer.Option("127.0.0.1", "--host", help="Interface to bind", envvar="TUFT_HOST")
 _PORT_OPTION = typer.Option(10610, "--port", "-p", help="Port to bind", envvar="TUFT_PORT")
-_LOG_LEVEL_OPTION = typer.Option(
+_LOG_LEVEL_OPTION = typer.Option(
+    "info", "--log-level", help="Uvicorn log level", envvar="TUFT_LOG_LEVEL"
+)
 _RELOAD_OPTION = typer.Option(False, "--reload", help="Enable auto-reload (development only)")
 _CONFIG_OPTION = typer.Option(
     None,
     "--config",
     "-c",
-    help="Path to a TuFT configuration file (YAML)",
+    help=f"Path to a TuFT configuration file (YAML). Defaults to {_DEFAULT_CONFIG_PATH}",
+    envvar="TUFT_CONFIG",
 )
 _CHECKPOINT_DIR_OPTION = typer.Option(
     None,
     "--checkpoint-dir",
-    help="Override checkpoint_dir from config file. Defaults to
+    help=f"Override checkpoint_dir from config file. Defaults to {_DEFAULT_CHECKPOINT_DIR}",
+    envvar="TUFT_CHECKPOINT_DIR",
 )
 
 
+def _resolve_config_path(config_path: Path | None) -> Path:
+    """Resolve the config path, falling back to default if not provided."""
+    if config_path is not None:
+        return config_path
+    if _DEFAULT_CONFIG_PATH.exists():
+        return _DEFAULT_CONFIG_PATH
+    raise typer.BadParameter(
+        f"Configuration file must be provided via --config or TUFT_CONFIG, "
+        f"or create a default config at {_DEFAULT_CONFIG_PATH}"
+    )
+
+
 def _build_config(
     config_path: Path | None,
     checkpoint_dir: Path | None,
 ) -> AppConfig:
-
-
-
+    resolved_config_path = _resolve_config_path(config_path)
+    config = load_yaml_config(resolved_config_path)
+    # Apply checkpoint_dir override, or use default if not in config
     if checkpoint_dir is not None:
         config.checkpoint_dir = checkpoint_dir.expanduser()
+    elif config.checkpoint_dir is None:
+        config.checkpoint_dir = _DEFAULT_CHECKPOINT_DIR
+    # Guarantee checkpoint_dir is set after resolution
+    assert config.checkpoint_dir is not None, "checkpoint_dir must be set after config resolution"
     config.ensure_directories()
     return config
 
 
+_FORCE_OPTION = typer.Option(
+    False,
+    "--force",
+    "-f",
+    help="Skip confirmation prompts when clearing persistence data.",
+)
+
+
+@clear_app.command(name="persistence")
+def clear_persistence(
+    config_path: Path | None = _CONFIG_OPTION,
+    force: bool = _FORCE_OPTION,
+) -> None:
+    """Clear all persistence data and start fresh.
+
+    This command clears all existing persistence data in the configured namespace.
+    Use this when the configuration has changed and you want to discard old data.
+    """
+    # Build config to get persistence settings
+    try:
+        resolved_config_path = _resolve_config_path(config_path)
+        config = load_yaml_config(resolved_config_path)
+    except typer.BadParameter as e:
+        typer.secho(f"Error: {e}", fg=typer.colors.RED)
+        raise typer.Exit(1) from e
+
+    if not config.persistence.enabled:
+        typer.secho(
+            "Persistence is disabled in the configuration. Nothing to clear.",
+            fg=typer.colors.YELLOW,
+        )
+        raise typer.Exit(0)
+
+    # Configure the store
+    store = get_redis_store()
+    store.configure(config.persistence)
+    namespace = get_current_namespace()
+
+    if not force:
+        typer.secho(
+            "\n🚨🚨🚨 CRITICAL WARNING 🚨🚨🚨\n",
+            fg=typer.colors.RED,
+            bold=True,
+        )
+        typer.secho(
+            "This command will PERMANENTLY DELETE ALL persistence data!\n",
+            fg=typer.colors.RED,
+            bold=True,
+        )
+        typer.secho(
+            f"📦 Target namespace: '{namespace}'\n",
+            fg=typer.colors.YELLOW,
+            bold=True,
+        )
+        typer.echo(
+            f"This IRREVERSIBLE action will destroy ALL data in namespace '{namespace}':\n"
+            " ❌ All saved sessions\n"
+            " ❌ All training run records and checkpoint metadata (NOT local checkpoint files)\n"
+            " ❌ All future records\n"
+            " ❌ All sampling session records\n"
+            " ❌ Configuration signature\n"
+            "\n"
+            "⚠️ The server will start fresh with NO previous state.\n"
+            "⚠️ This action CANNOT be undone!\n"
+            "⚠️ Local checkpoint files on disk are NOT affected.\n"
+            f"⚠️ Only data in namespace '{namespace}' will be affected.\n"
+        )
+        confirmed = typer.confirm(
+            f"Do you REALLY want to delete all data in namespace '{namespace}'?",
+            default=False,
+        )
+        if not confirmed:
+            typer.echo("Aborted. No data was cleared.")
+            raise typer.Exit(0)
+
+    deleted_count, cleared_namespace = flush_all_data()
+    typer.secho(
+        f"✅ Cleared {deleted_count} keys from namespace '{cleared_namespace}'.",
+        fg=typer.colors.GREEN,
+    )
+    typer.echo("Persistence data has been cleared. You can now start the server fresh.")
+
+
+def _validate_persistence_config(config: AppConfig) -> None:
+    """Validate that persistence config matches stored config.
+
+    If config mismatch is detected, exits with an error message.
+    """
+    if not config.persistence.enabled:
+        return
+
+    # Configure the Redis store first
+    store = get_redis_store()
+    store.configure(config.persistence)
+
+    try:
+        validate_config_signature(config)
+    except ConfigMismatchError as e:
+        typer.secho(
+            "\n 🚫 FATAL ERROR: Configuration Mismatch Detected 🚫",
+            fg=typer.colors.RED,
+            bold=True,
+        )
+        typer.echo(f"\n{e}\n")
+        raise typer.Exit(1) from e
+
+
 def _init_telemetry(config: AppConfig, log_level: str) -> None:
     """Initialize OpenTelemetry if enabled."""
     # Configure root logger level to ensure logs flow to OTel
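
With the `clear` sub-app registered above, the new maintenance command is reachable as `clear persistence` and honours `--config`/`-c` and `--force`/`-f`. A minimal invocation sketch using Typer's test runner; the YAML path is a placeholder and must point at a config with persistence enabled.

from typer.testing import CliRunner

from tuft.cli import app

runner = CliRunner()

# Skip the interactive confirmation with --force; the YAML path is a placeholder.
result = runner.invoke(
    app,
    ["clear", "persistence", "--config", "configs/tuft_config.yaml", "--force"],
)
print(result.exit_code)
print(result.output)
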
@@ -71,6 +220,10 @@ def launch(
 ) -> None:
     """Launch the TuFT server."""
     app_config = _build_config(config_path, checkpoint_dir)
+
+    # Validate persistence configuration before starting
+    _validate_persistence_config(app_config)
+
     # Initialize telemetry before starting the server
     _init_telemetry(app_config, log_level)
     logging.getLogger("tuft").info("Server starting on %s:%s", host, port)
@@ -84,7 +237,7 @@ def launch(
 
 
 def main() -> None:
-    app()
+    app(prog_name="tuft")
 
 
 if __name__ == "__main__":