tuft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,7 @@
  import asyncio
+ import logging
  import shutil
- from typing import Dict
+ from typing import Callable, Dict

  import ray
  import torch
@@ -14,13 +15,12 @@ from transformers import AutoModelForCausalLM

  from tuft.checkpoints import CheckpointRecord
  from tuft.config import ModelConfig
- from tuft.loss_fn import get_loss_fn
+ from tuft.loss_fn import get_loss_fn, metrics_reduction
  from tuft.telemetry.tracing import extract_context, get_tracer


  _get_tracer = lambda: get_tracer("tuft.hf_training_model")  # noqa: E731

-
  MODULE_MAP = {
      "llama": {
          "attn": ["q_proj", "k_proj", "v_proj", "o_proj"],
@@ -58,6 +58,8 @@ class HFTrainingModel:
          self.model = self._init_peft_model(config)
          self.adapter_optimizer: Dict[str, torch.optim.AdamW] = {}
          self._lock = asyncio.Lock()
+         self.logger = logging.getLogger()
+         self.micro_batch_size = config.micro_batch_size

      async def async_init(self) -> None:
          """Do nothing for now. Just used to make sure the actor is ready."""
@@ -193,7 +195,9 @@ class HFTrainingModel:
          async with self._lock:
              if lora_id in self.adapter_optimizer:
                  self.model.delete_adapter(lora_id)
-                 self.adapter_optimizer.pop(lora_id)
+                 optimizer = self.adapter_optimizer.pop(lora_id)
+                 del optimizer
+                 torch.cuda.empty_cache()

      # --------------------------------
      # Training methods
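
Note on the adapter-teardown change above: `torch.cuda.empty_cache()` can only return blocks to the driver once no live tensor references remain, which is presumably why the popped optimizer (and its Adam state buffers) is deleted before the cache is emptied. A minimal illustration of that ordering, outside tuft and guarded so it only runs where CUDA is available:

import torch

if torch.cuda.is_available():
    t = torch.empty(1024, 1024, device="cuda")  # ~4 MB of allocated memory
    print(torch.cuda.memory_allocated())
    del t                                        # allocated memory drops once the last reference is gone
    print(torch.cuda.memory_allocated())         # back to (near) zero
    torch.cuda.empty_cache()                     # now the cached (reserved) blocks go back to the driver
    print(torch.cuda.memory_reserved())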
@@ -207,7 +211,7 @@ class HFTrainingModel:
          backward: bool = False,
          trace_context: dict[str, str] | None = None,
      ) -> types.ForwardBackwardOutput:
-         """Forward pass (and backward if specified).
+         """Forward pass with micro-batch gradient accumulation.

          Args:
              data: List of Datum objects containing input data.
@@ -222,73 +226,163 @@ class HFTrainingModel:
          """
          ctx = extract_context(trace_context or {})
          span_name = "hf_model.forward_backward" if backward else "hf_model.forward"
+
          with _get_tracer().start_as_current_span(span_name, context=ctx) as span:
              span.set_attribute("tuft.lora_id", lora_id)
              span.set_attribute("tuft.backward", backward)
              span.set_attribute("tuft.data_count", len(data))
-             # Prepare input tensors
-             input_ids = [
-                 torch.tensor(datum.model_input.to_ints(), dtype=torch.long) for datum in data
-             ]
-             input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
-             attention_mask = (input_ids_padded != 0).long()
-             position_ids = (
-                 torch.arange(input_ids_padded.size(1), dtype=torch.long)
-                 .unsqueeze(0)
-                 .expand(input_ids_padded.size(0), -1)
-             )
-             # Move tensors to model device
-             device = next(self.model.parameters()).device
-             input_ids_padded = input_ids_padded.to(device)
-             attention_mask = attention_mask.to(device)
-             position_ids = position_ids.to(device)
-
-             # Activate the correct adapter
-             async with self._lock:
-                 self._activate_adapter(lora_id)
-
-             # Forward pass
-             outputs = self.model(
-                 input_ids=input_ids_padded,
-                 attention_mask=attention_mask,
-                 position_ids=position_ids,
-                 return_dict=True,
-             )

-             # Compute loss
-             if loss_fn_config is None:
-                 loss_fn_config = {}
-             loss_fn_callable = get_loss_fn(loss_fn)
-             logits = outputs.logits
-             if "temperature" in loss_fn_config:
-                 temperature = loss_fn_config["temperature"]
-                 logits.div_(temperature)
+             batch_size = len(data)
+             micro_batch_size = self.config.micro_batch_size

-             loss_fn_inputs = self._prepare_loss_fn_inputs(data)
+             num_micro_batches = (batch_size + micro_batch_size - 1) // micro_batch_size
+             span.set_attribute("tuft.num_micro_batches", num_micro_batches)

-             ## compute target_logprobs from logits and target_tokens
-             target_tokens = loss_fn_inputs["target_tokens"]
-             target_logprobs = self._compute_logprobs_from_target_tokens(logits, target_tokens)
-             loss_fn_inputs["target_logprobs"] = target_logprobs
+             if num_micro_batches > 1:
+                 self.logger.debug(
+                     f"[MICRO_BATCH] Splitting batch_size={batch_size} into "
+                     f"{num_micro_batches} micro-batches of size {micro_batch_size}"
+                 )

-             loss, metric = loss_fn_callable(loss_fn_inputs, loss_fn_config)
+             loss_fn_callable = get_loss_fn(loss_fn)
+             all_loss_fn_outputs = []
+             micro_batch_weights = []
+             metric_list = []
+             total_loss = 0.0

-             # Backward pass if needed
-             if backward:
-                 loss.backward()
+             async with self._lock:
+                 self._activate_adapter(lora_id)

-             unpaded_logprobs = self._unpad_tensor(
-                 target_logprobs, [len(datum.model_input.to_ints()) for datum in data]
+             for micro_idx in range(num_micro_batches):
+                 start_idx = micro_idx * micro_batch_size
+                 end_idx = min(start_idx + micro_batch_size, batch_size)
+                 micro_data = data[start_idx:end_idx]
+
+                 torch.cuda.reset_peak_memory_stats()
+                 self.logger.debug(
+                     f"[GPU-micro_batch_{micro_idx}] before_forward: "
+                     f"allocated={torch.cuda.memory_allocated() / 1e9:.2f}GB, "
+                     f"reserved={torch.cuda.memory_reserved() / 1e9:.2f}GB"
+                 )
+
+                 micro_loss, micro_metrics, micro_outputs = await self._forward_micro_batch(
+                     micro_data,
+                     loss_fn_callable,
+                     loss_fn_config,
+                     backward=backward,
+                 )
+
+                 total_loss += micro_loss
+                 all_loss_fn_outputs.extend(micro_outputs)
+                 micro_batch_weights.append(len(micro_outputs))
+
+                 metric_list.append(micro_metrics)
+
+                 self.logger.debug(
+                     f"[GPU-micro_batch_{micro_idx}] after_forward: "
+                     f"allocated={torch.cuda.memory_allocated() / 1e9:.2f}GB, "
+                     f"reserved={torch.cuda.memory_reserved() / 1e9:.2f}GB, "
+                     f"max_allocated={torch.cuda.max_memory_allocated() / 1e9:.2f}GB"
+                 )
+
+                 torch.cuda.empty_cache()
+
+             avg_loss = total_loss / num_micro_batches
+             self.logger.debug(f"Average loss: {avg_loss}")
+             metric_list = metrics_reduction(metric_list, micro_batch_weights)
+
+             self.logger.debug(
+                 f"[GPU-after_micro_batches] allocated={torch.cuda.memory_allocated() / 1e9:.2f}GB"
+                 f", reserved={torch.cuda.memory_reserved() / 1e9:.2f}GB"
              )
+
              return types.ForwardBackwardOutput(
                  loss_fn_output_type=loss_fn,
-                 loss_fn_outputs=[
-                     {"logprobs": types.TensorData.from_torch(logprobs.detach())}
-                     for logprobs in unpaded_logprobs
-                 ],
-                 metrics=metric,
+                 loss_fn_outputs=all_loss_fn_outputs,
+                 metrics=metric_list or {},
              )

+     async def _forward_micro_batch(
+         self,
+         data: list[types.Datum],
+         loss_fn_callable: Callable,
+         loss_fn_config: dict[str, float] | None,
+         backward: bool,
+     ) -> tuple[float, dict[str, float], list[dict]]:
+         """Process a single micro-batch.
+
+         Returns:
+             tuple: (loss_value, metrics_dict, loss_fn_outputs_list)
+         """
+         # Prepare input tensors
+         input_ids = [torch.tensor(datum.model_input.to_ints(), dtype=torch.long) for datum in data]
+         input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
+         attention_mask = (input_ids_padded != 0).long()
+         position_ids = (
+             torch.arange(input_ids_padded.size(1), dtype=torch.long)
+             .unsqueeze(0)
+             .expand(input_ids_padded.size(0), -1)
+         )
+
+         device = next(self.model.parameters()).device
+         input_ids_padded = input_ids_padded.to(device)
+         attention_mask = attention_mask.to(device)
+         position_ids = position_ids.to(device)
+
+         # Forward pass
+         outputs = self.model(
+             input_ids=input_ids_padded,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             return_dict=True,
+         )
+
+         if loss_fn_config is None:
+             loss_fn_config = {}
+
+         logits = outputs.logits
+         del outputs
+         torch.cuda.empty_cache()
+
+         if "temperature" in loss_fn_config:
+             temperature = loss_fn_config["temperature"]
+             logits = logits / temperature
+
+         loss_fn_inputs = self._prepare_loss_fn_inputs(data)
+         target_tokens = loss_fn_inputs["target_tokens"]
+
+         target_logprobs = self._compute_logprobs_from_target_tokens(logits, target_tokens)
+         del logits
+         torch.cuda.empty_cache()
+
+         loss_fn_inputs["target_logprobs"] = target_logprobs
+         loss, metric = loss_fn_callable(loss_fn_inputs, loss_fn_config)
+
+         # Backward with gradient accumulation
+         if backward:
+             loss.backward(retain_graph=False)
+             torch.cuda.empty_cache()
+
+         unpaded_logprobs = self._unpad_tensor(
+             target_logprobs.detach(),
+             [len(datum.model_input.to_ints()) for datum in data],
+         )
+         loss_fn_outputs = [
+             {"logprobs": types.TensorData.from_torch(logprobs.cpu().clone())}
+             for logprobs in unpaded_logprobs
+         ]
+
+         loss_value = loss.detach().item()
+
+         del target_logprobs
+         del unpaded_logprobs
+         del loss_fn_inputs
+         del loss
+
+         torch.cuda.empty_cache()
+
+         return loss_value, metric, loss_fn_outputs
+
      async def optim_step(
          self,
          adam_params: types.AdamParams,
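
For readers skimming the hunk above: the old single-pass forward/backward is replaced by a loop that calls `_forward_micro_batch` once per slice of the batch; when `backward=True`, each `loss.backward()` adds into the same adapter gradients, and a later `optim_step` applies them in one step. A self-contained sketch of that accumulation pattern, using a toy linear model and random data rather than tuft's API:

import torch

# Toy stand-in for the pattern used in forward_backward above:
# run forward/backward per micro-batch, let gradients accumulate,
# then take a single optimizer step for the whole logical batch.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

data = torch.randn(8, 4)          # logical batch of 8
targets = torch.randn(8, 1)
micro_batch_size = 2
num_micro_batches = (len(data) + micro_batch_size - 1) // micro_batch_size

total_loss = 0.0
for i in range(num_micro_batches):
    xb = data[i * micro_batch_size:(i + 1) * micro_batch_size]
    yb = targets[i * micro_batch_size:(i + 1) * micro_batch_size]
    loss = torch.nn.functional.mse_loss(model(xb), yb)
    loss.backward()               # gradients add up across micro-batches
    total_loss += loss.item()

avg_loss = total_loss / num_micro_batches
optimizer.step()                  # applied once, like optim_step()
optimizer.zero_grad()

Since each micro-batch loss is typically a mean over its own samples, per-micro-batch metrics are not directly comparable when the last slice is smaller; presumably that is what `micro_batch_weights` and `metrics_reduction` account for.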
@@ -316,7 +410,9 @@ class HFTrainingModel:
                  param_group["weight_decay"] = adam_params.weight_decay
              optimizer.step()
              optimizer.zero_grad()
-             return types.OptimStepResponse()
+
+             torch.cuda.empty_cache()
+             return types.OptimStepResponse()

      # --------------------------------
      # Helper methods
@@ -362,11 +458,33 @@ class HFTrainingModel:
      def _compute_logprobs_from_target_tokens(
          self, logits: torch.Tensor, target_tokens: torch.Tensor
      ) -> torch.Tensor:
-         logits_labels = torch.gather(logits, dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(-1)
-         # loop to reduce peak mem consumption
-         logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits])
-         logprobs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
-         return logprobs_labels
+         """Compute log probabilities of target tokens from logits with low memory usage.
+         https://github.com/OpenRLHF/OpenRLHF/pull/718
+         """
+         if logits.dtype in [torch.float32, torch.float64]:
+             logits_labels = torch.gather(logits, dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(
+                 -1
+             )
+             logsumexp_values = torch.stack(
+                 [
+                     torch.logsumexp(logit, dim=-1) for logit in logits
+                 ]  # loop to reduce peak mem consumption
+             )
+             log_probs_labels = (
+                 logits_labels - logsumexp_values
+             )  # log_softmax(x_i) = x_i - logsumexp(x)
+         else:
+             log_probs_labels = []
+             for row_logits, row_labels in zip(
+                 logits, target_tokens, strict=True
+             ):  # loop to reduce peak mem consumption
+                 row_log_probs = torch.nn.functional.log_softmax(row_logits, dim=-1)
+                 row_log_probs_labels = row_log_probs.gather(
+                     dim=-1, index=row_labels.unsqueeze(-1)
+                 ).squeeze(-1)
+                 log_probs_labels.append(row_log_probs_labels)
+             log_probs_labels = torch.stack(log_probs_labels)
+         return log_probs_labels

      def _unpad_tensor(
          self, padded_tensor: torch.Tensor, original_lengths: list[int]
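
The rewritten helper leans on the identity log_softmax(x)_i = x_i - logsumexp(x): for float32/float64 logits it gathers only the target-token logits and subtracts a per-row logsumexp instead of materializing the full log-softmax over the vocabulary, and for lower-precision dtypes it falls back to a row-by-row log_softmax (the linked OpenRLHF pull request discusses the trade-off). A quick standalone check of the equivalence on toy shapes, independent of tuft:

import torch

# Verify: gather(logits, target) - logsumexp(logits) == log_softmax(logits).gather(target)
batch, seq, vocab = 2, 5, 11
logits = torch.randn(batch, seq, vocab)
targets = torch.randint(vocab, (batch, seq))

gathered = torch.gather(logits, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
memory_light = gathered - torch.logsumexp(logits, dim=-1)

reference = torch.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(memory_light, reference, atol=1e-5)

For bf16/fp16 logits the direct subtraction can lose precision, which presumably motivates the per-row log_softmax branch in the new code.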
@@ -383,6 +501,8 @@ class HFTrainingModel:
              dtype="auto",
              device_map="auto",
          )
+         model.enable_input_require_grads()
+         model.gradient_checkpointing_enable({"use_reentrant": False})
          peft_config = LoraConfig()
          peft_model = get_peft_model(model, peft_config=peft_config, adapter_name="default")
          return peft_model
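
The two added lines trade compute for activation memory: gradient checkpointing discards intermediate activations and recomputes them during backward, and `enable_input_require_grads()` keeps gradients flowing into checkpointed blocks when only the LoRA parameters are trainable. A hedged sketch of the same setup outside tuft (the model id is illustrative, and the keyword form of the checkpointing call assumes a recent transformers release):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Illustrative checkpoint; any causal LM works the same way.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
model.enable_input_require_grads()            # keep grads flowing into checkpointed blocks
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
peft_model = get_peft_model(model, LoraConfig())
peft_model.print_trainable_parameters()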
@@ -398,7 +518,7 @@ class HFTrainingModel:
              ray.remote(cls)
              .options(
                  name="training_model_" + config.model_name,
-                 num_gpus=1 if not config.colocate else 1 - config.sampling_memory_fraction,
+                 num_gpus=(1 if not config.colocate else 1 - config.sampling_memory_fraction),
              )
              .remote(config)
          )
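
The only change in this hunk is the clarifying parentheses, but the expression is worth a gloss: with colocation enabled, the training actor asks Ray for only the fraction of the GPU left over after sampling. A tiny illustration of fractional GPU requests with Ray actors (the actor body is made up; Ray treats GPU counts as logical resources):

import ray

ray.init(num_gpus=1)

@ray.remote
class Trainer:
    def ping(self) -> str:
        return "ready"

sampling_memory_fraction = 0.2
# Same pattern as above: request the remainder of the GPU for training.
trainer = Trainer.options(num_gpus=1 - sampling_memory_fraction).remote()
print(ray.get(trainer.ping.remote()))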
tuft/cli.py CHANGED
@@ -10,12 +10,21 @@ import typer
  import uvicorn

  from .config import AppConfig, load_yaml_config
+ from .exceptions import ConfigMismatchError
+ from .persistence import (
+     flush_all_data,
+     get_current_namespace,
+     get_redis_store,
+     validate_config_signature,
+ )
  from .server import create_root_app
  from .telemetry import init_telemetry
  from .telemetry.metrics import ResourceMetricsCollector


  app = typer.Typer(help="TuFT - Tenant-unified Fine-Tuning Server.", no_args_is_help=True)
+ clear_app = typer.Typer(help="Clear data commands.", no_args_is_help=True)
+ app.add_typer(clear_app, name="clear")


  # Required for Typer to recognize subcommands when using no_args_is_help=True
@@ -79,6 +88,113 @@ def _build_config(
      return config


+ _FORCE_OPTION = typer.Option(
+     False,
+     "--force",
+     "-f",
+     help="Skip confirmation prompts when clearing persistence data.",
+ )
+
+
+ @clear_app.command(name="persistence")
+ def clear_persistence(
+     config_path: Path | None = _CONFIG_OPTION,
+     force: bool = _FORCE_OPTION,
+ ) -> None:
+     """Clear all persistence data and start fresh.
+
+     This command clears all existing persistence data in the configured namespace.
+     Use this when the configuration has changed and you want to discard old data.
+     """
+     # Build config to get persistence settings
+     try:
+         resolved_config_path = _resolve_config_path(config_path)
+         config = load_yaml_config(resolved_config_path)
+     except typer.BadParameter as e:
+         typer.secho(f"Error: {e}", fg=typer.colors.RED)
+         raise typer.Exit(1) from e
+
+     if not config.persistence.enabled:
+         typer.secho(
+             "Persistence is disabled in the configuration. Nothing to clear.",
+             fg=typer.colors.YELLOW,
+         )
+         raise typer.Exit(0)
+
+     # Configure the store
+     store = get_redis_store()
+     store.configure(config.persistence)
+     namespace = get_current_namespace()
+
+     if not force:
+         typer.secho(
+             "\n🚨🚨🚨 CRITICAL WARNING 🚨🚨🚨\n",
+             fg=typer.colors.RED,
+             bold=True,
+         )
+         typer.secho(
+             "This command will PERMANENTLY DELETE ALL persistence data!\n",
+             fg=typer.colors.RED,
+             bold=True,
+         )
+         typer.secho(
+             f"📦 Target namespace: '{namespace}'\n",
+             fg=typer.colors.YELLOW,
+             bold=True,
+         )
+         typer.echo(
+             f"This IRREVERSIBLE action will destroy ALL data in namespace '{namespace}':\n"
+             " ❌ All saved sessions\n"
+             " ❌ All training run records and checkpoint metadata (NOT local checkpoint files)\n"
+             " ❌ All future records\n"
+             " ❌ All sampling session records\n"
+             " ❌ Configuration signature\n"
+             "\n"
+             "⚠️ The server will start fresh with NO previous state.\n"
+             "⚠️ This action CANNOT be undone!\n"
+             "⚠️ Local checkpoint files on disk are NOT affected.\n"
+             f"⚠️ Only data in namespace '{namespace}' will be affected.\n"
+         )
+         confirmed = typer.confirm(
+             f"Do you REALLY want to delete all data in namespace '{namespace}'?",
+             default=False,
+         )
+         if not confirmed:
+             typer.echo("Aborted. No data was cleared.")
+             raise typer.Exit(0)
+
+     deleted_count, cleared_namespace = flush_all_data()
+     typer.secho(
+         f"✅ Cleared {deleted_count} keys from namespace '{cleared_namespace}'.",
+         fg=typer.colors.GREEN,
+     )
+     typer.echo("Persistence data has been cleared. You can now start the server fresh.")
+
+
+ def _validate_persistence_config(config: AppConfig) -> None:
+     """Validate that persistence config matches stored config.
+
+     If config mismatch is detected, exits with an error message.
+     """
+     if not config.persistence.enabled:
+         return
+
+     # Configure the Redis store first
+     store = get_redis_store()
+     store.configure(config.persistence)
+
+     try:
+         validate_config_signature(config)
+     except ConfigMismatchError as e:
+         typer.secho(
+             "\n 🚫 FATAL ERROR: Configuration Mismatch Detected 🚫",
+             fg=typer.colors.RED,
+             bold=True,
+         )
+         typer.echo(f"\n{e}\n")
+         raise typer.Exit(1) from e
+
+
  def _init_telemetry(config: AppConfig, log_level: str) -> None:
      """Initialize OpenTelemetry if enabled."""
      # Configure root logger level to ensure logs flow to OTel
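
With the `clear` sub-app registered on the main Typer app, the command above should be reachable as `tuft clear persistence` (add `--force`/`-f` to skip the confirmation prompt). A minimal, illustrative sketch of the same sub-command wiring, exercised with Typer's test runner; the names here are stand-ins, not tuft's:

import typer
from typer.testing import CliRunner

app = typer.Typer(no_args_is_help=True)
clear_app = typer.Typer(help="Clear data commands.", no_args_is_help=True)
app.add_typer(clear_app, name="clear")

@clear_app.command(name="persistence")
def clear_persistence(force: bool = typer.Option(False, "--force", "-f")) -> None:
    if not force and not typer.confirm("Really clear?", default=False):
        typer.echo("Aborted.")
        raise typer.Exit(0)
    typer.echo("Cleared.")

runner = CliRunner()
result = runner.invoke(app, ["clear", "persistence", "--force"])
print(result.output)  # -> "Cleared."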
@@ -104,6 +220,10 @@ def launch(
  ) -> None:
      """Launch the TuFT server."""
      app_config = _build_config(config_path, checkpoint_dir)
+
+     # Validate persistence configuration before starting
+     _validate_persistence_config(app_config)
+
      # Initialize telemetry before starting the server
      _init_telemetry(app_config, log_level)
      logging.getLogger("tuft").info("Server starting on %s:%s", host, port)
tuft/config.py CHANGED
@@ -2,9 +2,10 @@

  from __future__ import annotations

- from dataclasses import dataclass, field
  from pathlib import Path
- from typing import Dict, Iterable, List
+ from typing import Any, Iterable
+
+ from pydantic import BaseModel, Field, model_validator

  from .persistence import PersistenceConfig

@@ -14,12 +15,7 @@ def _default_checkpoint_dir() -> Path | None:
      return None


- def _default_persistence_config() -> PersistenceConfig:
-     return PersistenceConfig()
-
-
- @dataclass
- class TelemetryConfig:
+ class TelemetryConfig(BaseModel):
      """Configuration for OpenTelemetry integration.

      Attributes:
@@ -32,26 +28,61 @@ class TelemetryConfig:
      enabled: bool = False
      service_name: str = "tuft"
      otlp_endpoint: str | None = None
-     resource_attributes: Dict[str, str] = field(default_factory=dict)
+     resource_attributes: dict[str, str] = Field(default_factory=dict)
+
+
+ class ModelConfig(BaseModel):
+     model_config = {"arbitrary_types_allowed": True}
+
+     model_name: str  # name used in APIs
+     model_path: Path  # path to model checkpoint
+     max_model_len: int  # maximum context length supported by the model
+     tensor_parallel_size: int = 1  # tensor parallel size
+
+     # default sampling parameters for this model
+     temperature: float = 1.0
+     top_p: float = 1.0
+     top_k: int = -1
+     logprobs: int = 0
+     seed: int = 42
+     min_response_tokens: int = 0
+
+     # default lora setting
+     max_lora_rank: int = 16  # maximum rank for LoRA adapters
+     max_loras: int = 1  # maximum number of LoRA adapters that can be applied simultaneously

+     # default training setting
+     micro_batch_size: int = 1  # micro-batch size for training

- def _default_telemetry_config() -> TelemetryConfig:
-     return TelemetryConfig()
+     # whether to colocate sampling and training on the same device
+     # only for local testing purposes
+     colocate: bool = False
+     sampling_memory_fraction: float = 0.2  # fraction of GPU memory for sampling
+
+     @model_validator(mode="after")
+     def validate_colocate(self) -> "ModelConfig":
+         if self.colocate and self.tensor_parallel_size != 1:
+             raise ValueError("Colocate option is only supported for tensor_parallel_size=1.")
+         return self


- @dataclass
- class AppConfig:
-     """Runtime configuration for the TuFT server."""
+ class AppConfig(BaseModel):
+     """Runtime configuration for the TuFT server.

- checkpoint_dir: Path | None = field(default_factory=_default_checkpoint_dir)
- supported_models: List[ModelConfig] = field(default_factory=list)
+     This is a Pydantic model that can be serialized/deserialized for persistence.
+     """
+
+     model_config = {"arbitrary_types_allowed": True}
+
+     checkpoint_dir: Path | None = Field(default_factory=_default_checkpoint_dir)
+     supported_models: list[ModelConfig] = Field(default_factory=list)
      model_owner: str = "local-user"
      toy_backend_seed: int = 0
      # TODO: Temporary implementation for user authorization,
      # replace with proper auth system later
-     authorized_users: Dict[str, str] = field(default_factory=dict)
-     persistence: PersistenceConfig = field(default_factory=_default_persistence_config)
-     telemetry: TelemetryConfig = field(default_factory=_default_telemetry_config)
+     authorized_users: dict[str, str] = Field(default_factory=dict)
+     persistence: PersistenceConfig = Field(default_factory=PersistenceConfig)
+     telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig)

      def ensure_directories(self) -> None:
          if self.checkpoint_dir is not None:
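
Moving ModelConfig to Pydantic means the colocate check now runs through `model_validator(mode="after")` at construction/validation time rather than in `__post_init__`. A small sketch of the behavior this enables (field values are illustrative; it assumes tuft is installed so `tuft.config.ModelConfig` is importable):

from pydantic import ValidationError

from tuft.config import ModelConfig

ok = ModelConfig(model_name="demo", model_path="/models/demo", max_model_len=2048)
print(ok.micro_batch_size)  # 1, the new training default

try:
    ModelConfig(
        model_name="demo",
        model_path="/models/demo",
        max_model_len=2048,
        colocate=True,
        tensor_parallel_size=2,   # conflicts with colocate
    )
except ValidationError as e:
    print(e)  # includes "Colocate option is only supported for tensor_parallel_size=1."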
@@ -74,50 +105,21 @@ class AppConfig:
          self.supported_models = updated
          return self

-
- @dataclass
- class ModelConfig:
-     """Configuration for a specific model."""
-
-     model_name: str  # name used in APIs
-     model_path: Path  # path to model checkpoint
-     max_model_len: int  # maximum context length supported by the model
-     tensor_parallel_size: int = 1  # tensor parallel size
-
-     # default sampling parameters for this model
-     temperature: float = 1.0
-     top_p: float = 1.0
-     top_k: int = -1
-     logprobs: int = 0
-     seed: int = 42
-     min_response_tokens: int = 0
-
-     # default lora setting
-     max_lora_rank: int = 16  # maximum rank for LoRA adapters
-     max_loras: int = 1  # maximum number of LoRA adapters that can be applied simultaneously
-
-     # whether to colocate sampling and training on the same device
-     # only for local testing purposes
-     colocate: bool = False
-     sampling_memory_fraction: float = 0.2  # fraction of GPU memory for sampling
-
-     def __post_init__(self) -> None:
-         if self.colocate and self.tensor_parallel_size != 1:
-             raise ValueError("Colocate option is only supported for tensor_parallel_size=1.")
+     def get_config_for_persistence(self) -> dict[str, Any]:
+         """Get config fields for persistence signature (excludes persistence config itself)."""
+         return self.model_dump(mode="json", exclude={"persistence"})


  def load_yaml_config(config_path: Path) -> AppConfig:
      """Loads an AppConfig from a YAML file."""
      from omegaconf import OmegaConf

-     schema = OmegaConf.structured(AppConfig)
      loaded = OmegaConf.load(config_path)
      try:
-         config = OmegaConf.merge(schema, loaded)
-         app_config = OmegaConf.to_object(config)
-         assert isinstance(app_config, AppConfig), (
-             "Loaded config is not of type AppConfig, which should not happen."
-         )
-         return app_config
+         # Convert OmegaConf to plain dict for Pydantic
+         config_dict = OmegaConf.to_container(loaded, resolve=True)
+         if not isinstance(config_dict, dict):
+             raise ValueError("Config file must contain a dictionary at root level")
+         return AppConfig.model_validate(config_dict)
      except Exception as e:
          raise ValueError(f"Failed to load config from {config_path}: {e}") from e