tuft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,327 @@
1
+ import asyncio
2
+ import logging
3
+ from dataclasses import dataclass, field
4
+ from typing import Sequence
5
+
6
+ import numpy as np
7
+ from tinker import types
8
+
9
+ from tuft.backends.base_backend import BaseTrainingBackend
10
+ from tuft.checkpoints import CheckpointRecord
11
+ from tuft.config import ModelConfig
12
+ from tuft.telemetry.tracing import get_tracer, inject_context
13
+
14
+
15
def _get_tracer():
    """Return the tracer for this module, resolved lazily at each call.

    Deliberately not a module-level constant — presumably so the tracer
    provider can be (re)configured after import; confirm against telemetry
    init order.  A plain ``def`` replaces the former lambda assignment
    (PEP 8 / E731).
    """
    return get_tracer("tuft.training_backend")


logger = logging.getLogger(__name__)
18
+
19
+
20
class HFTrainingBackend(BaseTrainingBackend):
    """A training backend using Hugging Face transformers.

    Thin async facade over the ``HFTrainingModel`` Ray actor: each public
    method opens a tracing span, captures the current trace context, and
    forwards the call to the actor so spans link up across the process
    boundary.
    """

    def __init__(self, config: ModelConfig) -> None:
        # Imported lazily so this module can be imported without pulling in
        # the heavy HF/Ray dependency chain.
        from .hf_training_model import HFTrainingModel

        self.config = config
        logger.info("Ray actor created: HFTrainingModel(%s)", config.model_name)
        self.model = HFTrainingModel.get_actor(config)

    @staticmethod
    def _trace_context() -> dict[str, str]:
        """Capture the current OpenTelemetry context for the Ray actor.

        Factored out of every public method, which previously repeated the
        same two-line inject pattern.
        """
        trace_context: dict[str, str] = {}
        inject_context(trace_context)
        return trace_context

    async def async_init(self) -> None:
        """Complete asynchronous initialization of the underlying actor."""
        await self.model.async_init.remote()

    async def create_adapter(self, lora_id: str, lora_config: types.LoraConfig) -> None:
        """Create a LoRA adapter with the given ID and configuration."""
        with _get_tracer().start_as_current_span("training_backend.create_adapter") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            span.set_attribute("tuft.lora_rank", lora_config.rank)
            await self.model.create_adapter.remote(lora_id, lora_config, self._trace_context())

    async def remove_adapter(self, lora_id: str) -> None:
        """Remove the LoRA adapter with the given ID from the actor."""
        with _get_tracer().start_as_current_span("training_backend.remove_adapter") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            await self.model.remove_adapter.remote(lora_id)

    async def forward(
        self,
        data: list[types.Datum],
        lora_id: str,
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None,
        backward: bool = False,
    ) -> types.ForwardBackwardOutput:
        """Forward pass (and backward if specified).

        Args:
            data: List of Datum objects containing input data.
            lora_id: The LoRA adapter ID to use.
            loss_fn: The loss function to apply.
            loss_fn_config: Optional configuration for the loss function.
            backward: Whether to perform backward pass.

        Returns:
            ForwardBackwardOutput: The output of the forward (and backward) pass.
        """
        span_name = "training_backend.forward_backward" if backward else "training_backend.forward"
        with _get_tracer().start_as_current_span(span_name) as span:
            span.set_attribute("tuft.lora_id", lora_id)
            span.set_attribute("tuft.backward", backward)
            span.set_attribute("tuft.data_count", len(data))
            return await self.model.forward.remote(
                data=data,
                lora_id=lora_id,
                loss_fn=loss_fn,
                loss_fn_config=loss_fn_config,
                backward=backward,
                trace_context=self._trace_context(),
            )

    async def optim_step(
        self, adam_params: types.AdamParams, lora_id: str
    ) -> types.OptimStepResponse:
        """Perform an optimization step using Adam optimizer.

        Args:
            adam_params: Parameters for the Adam optimizer.
            lora_id: The LoRA adapter ID whose weights are updated.

        Returns:
            OptimStepResponse: The response containing optimization metrics.
        """
        with _get_tracer().start_as_current_span("training_backend.optim_step") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            return await self.model.optim_step.remote(adam_params, lora_id, self._trace_context())

    async def save_state(
        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
    ) -> None:
        """Save the state of the specified LoRA adapter."""
        with _get_tracer().start_as_current_span("training_backend.save_state") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            span.set_attribute("tuft.optimizer", optimizer)
            await self.model.save_state.remote(
                lora_id=lora_id,
                checkpoint_record=checkpoint_record,
                optimizer=optimizer,
                trace_context=self._trace_context(),
            )

    async def load_state(
        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
    ) -> None:
        """Load the state of the specified LoRA adapter from the given path."""
        with _get_tracer().start_as_current_span("training_backend.load_state") as span:
            span.set_attribute("tuft.lora_id", lora_id)
            span.set_attribute("tuft.optimizer", optimizer)
            await self.model.load_state.remote(
                lora_id=lora_id,
                checkpoint_record=checkpoint_record,
                optimizer=optimizer,
                trace_context=self._trace_context(),
            )
136
+
137
+
138
@dataclass
class DummyTrainingBackend(BaseTrainingBackend):
    """A dummy training backend for testing purposes.

    "Training" is simulated by fitting a single linear probe: each prompt is
    embedded into a fixed-size vector (mean of per-token pseudo-embeddings),
    the prediction is the dot product with ``_weights``, and the loss is the
    squared error against a scalar derived from the target tokens.  Gradients
    accumulate under ``_lock`` until ``optim_step`` applies a hand-rolled
    Adam update.
    """

    _lock: asyncio.Lock = field(init=False, repr=False)  # guards pending gradient state
    _weights: np.ndarray = field(init=False, repr=False)  # the linear probe being "trained"
    _adam_m: np.ndarray = field(init=False, repr=False)  # Adam first-moment estimate
    _adam_v: np.ndarray = field(init=False, repr=False)  # Adam second-moment estimate
    _beta1_power: float = field(init=False, default=1.0, repr=False)  # beta1**step (bias correction)
    _beta2_power: float = field(init=False, default=1.0, repr=False)  # beta2**step (bias correction)
    _pending_grad: np.ndarray | None = field(init=False, default=None, repr=False)
    _pending_examples: int = field(init=False, default=0, repr=False)
    _embedding_cache: dict[int, np.ndarray] = field(init=False, default_factory=dict, repr=False)
    step: int = field(init=False, default=0)  # number of optimizer steps applied

    def __init__(self, config: ModelConfig) -> None:
        # This hand-written __init__ replaces the dataclass-generated one
        # (dataclass does not overwrite an explicit __init__); the field
        # declarations above remain for typing/defaults.
        self.config = config
        self.seed = config.seed  # may be None; helpers must fall back to 0
        self.hidden_dim = 16
        rng = np.random.default_rng(self.seed or 0)
        self._lock = asyncio.Lock()
        self._weights = rng.standard_normal(self.hidden_dim, dtype=np.float32)
        self._adam_m = np.zeros_like(self._weights)
        self._adam_v = np.zeros_like(self._weights)
        self._embedding_cache = {}
        self._adapters: dict[str, types.LoraConfig] = {}

    async def async_init(self) -> None:
        """No asynchronous setup is needed for the dummy backend."""

    # ------------------------------------------------------------------
    # Forward / backward helpers
    # ------------------------------------------------------------------
    async def forward(
        self,
        data: list[types.Datum],
        lora_id: str,
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None,
        backward: bool = False,
    ) -> types.ForwardBackwardOutput:
        """Run the toy forward (and optional backward) pass.

        ``lora_id``, ``loss_fn`` and ``loss_fn_config`` are accepted for
        interface compatibility but ignored by the dummy implementation.
        """
        return await self._run_step(data, backward=backward)

    async def _run_step(
        self, data: list[types.Datum], *, backward: bool
    ) -> types.ForwardBackwardOutput:
        """Compute per-datum squared-error losses; accumulate grads if asked."""
        outputs: list[types.LossFnOutput] = []
        total_loss = 0.0
        grad_accum = np.zeros_like(self._weights)
        for datum in data:
            prompt_tokens = datum.model_input.to_ints()
            target_tokens = self._target_tokens(datum)
            prompt_vec = self._vectorize(prompt_tokens)
            target_scalar = self._target_scalar(target_tokens)
            prediction = float(np.dot(self._weights, prompt_vec))
            loss = (prediction - target_scalar) ** 2
            total_loss += loss
            if backward:
                # d(loss)/d(weights) for the squared error of a dot product.
                grad = 2 * (prediction - target_scalar) * prompt_vec
                grad_accum += grad
            logprob_tensor = types.TensorData(
                data=[float(-abs(prediction - target_scalar))] * max(len(target_tokens), 1),
                dtype="float32",
                shape=[max(len(target_tokens), 1)],
            )
            outputs.append({"logprobs": logprob_tensor})

        metrics = {
            "loss:sum": total_loss,
            "step:max": float(self.step),
        }
        if backward:
            grad_norm = float(np.linalg.norm(grad_accum) / max(len(data), 1))
            metrics["grad_norm:mean"] = grad_norm
            async with self._lock:
                # Accumulate into pending state; consumed by optim_step().
                if self._pending_grad is None:
                    self._pending_grad = grad_accum
                else:
                    self._pending_grad += grad_accum
                self._pending_examples += len(data)

        return types.ForwardBackwardOutput(
            loss_fn_output_type="ToyLoss",
            loss_fn_outputs=outputs,
            metrics=metrics,
        )

    # ------------------------------------------------------------------
    # Optimizer
    # ------------------------------------------------------------------
    async def optim_step(
        self, adam_params: types.AdamParams, lora_id: str
    ) -> types.OptimStepResponse:
        """Apply one Adam update using the accumulated pending gradient.

        Args:
            adam_params: Parameters for the Adam optimizer.
            lora_id: Accepted for interface compatibility; unused here.

        Returns:
            OptimStepResponse: Metrics for the (possibly no-op) step.
        """
        async with self._lock:
            # Take ownership of the pending gradient atomically.
            grad = self._pending_grad
            examples = self._pending_examples
            self._pending_grad = None
            self._pending_examples = 0

        if grad is None or not np.any(grad):
            # Nothing accumulated (or an all-zero gradient): no-op step.
            return types.OptimStepResponse(
                metrics={
                    "learning_rate:mean": adam_params.learning_rate,
                    "step:max": float(self.step),
                }
            )

        grad = grad / max(examples, 1)  # average over contributing examples
        if adam_params.grad_clip_norm > 0:
            norm = np.linalg.norm(grad)
            if norm > adam_params.grad_clip_norm:
                grad *= adam_params.grad_clip_norm / max(norm, 1e-12)

        if adam_params.weight_decay:
            # L2-style weight decay folded into the gradient.
            grad += adam_params.weight_decay * self._weights

        beta1 = adam_params.beta1
        beta2 = adam_params.beta2

        self._adam_m = beta1 * self._adam_m + (1 - beta1) * grad
        self._adam_v = beta2 * self._adam_v + (1 - beta2) * (grad**2)
        self._beta1_power *= beta1
        self._beta2_power *= beta2
        # Bias-corrected moment estimates; the 1e-12 avoids divide-by-zero.
        m_hat = self._adam_m / (1 - self._beta1_power + 1e-12)
        v_hat = self._adam_v / (1 - self._beta2_power + 1e-12)

        update = adam_params.learning_rate * m_hat / (np.sqrt(v_hat) + adam_params.eps)
        self._weights -= update
        self.step += 1

        metrics = {
            "learning_rate:mean": adam_params.learning_rate,
            "step:max": float(self.step),
            "update_norm:mean": float(np.linalg.norm(update)),
        }
        return types.OptimStepResponse(metrics=metrics)

    async def save_state(
        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
    ) -> None:
        """Validate the adapter exists; the dummy backend persists nothing."""
        if lora_id not in self._adapters:
            raise ValueError(f"Adapter {lora_id} does not exist.")
        # dummy save

    async def load_state(
        self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
    ) -> None:
        """Simulate loading by registering a dummy adapter."""
        # create a dummy adapter on load
        self._adapters[lora_id] = types.LoraConfig(rank=4)

    async def create_adapter(self, lora_id: str, lora_config: types.LoraConfig) -> None:
        """Register a LoRA adapter configuration under ``lora_id``."""
        self._adapters[lora_id] = lora_config

    async def remove_adapter(self, lora_id: str) -> None:
        """Forget the adapter; removing an unknown ID is a no-op."""
        self._adapters.pop(lora_id, None)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _target_tokens(self, datum: types.Datum) -> list[int]:
        """Extract target tokens, falling back to the prompt tokens."""
        if not datum.loss_fn_inputs:
            return datum.model_input.to_ints()
        tensor = datum.loss_fn_inputs.get("target_tokens")
        if tensor is None:
            return datum.model_input.to_ints()
        return [int(value) for value in tensor.data]

    def _vectorize(self, tokens: Sequence[int]) -> np.ndarray:
        """Mean-pool per-token pseudo-embeddings; zeros for an empty prompt."""
        if not tokens:
            return np.zeros(self.hidden_dim, dtype=np.float32)
        vecs = [self._token_embedding(token) for token in tokens]
        return np.mean(vecs, axis=0)

    def _token_embedding(self, token_id: int) -> np.ndarray:
        """Return a deterministic pseudo-embedding for ``token_id`` (cached)."""
        cached = self._embedding_cache.get(token_id)
        if cached is None:
            # Fall back to seed 0 when config.seed is None, as __init__ does;
            # the previous ``self.seed + token_id`` raised TypeError for None.
            rng = np.random.default_rng((self.seed or 0) + token_id)
            cached = rng.standard_normal(self.hidden_dim, dtype=np.float32)
            self._embedding_cache[token_id] = cached
        return cached

    def _safe_mean(self, values: Sequence[int]) -> float:
        """Arithmetic mean that returns 0.0 for an empty sequence."""
        if not values:
            return 0.0
        return float(sum(values) / len(values))

    def _target_scalar(self, tokens: Sequence[int]) -> float:
        """Squash the mean target token id into (-1, 1) via tanh.

        Wrapped in ``float`` so the annotated return type holds (previously
        returned a NumPy scalar).
        """
        if not tokens:
            return 0.0
        return float(np.tanh(self._safe_mean(tokens) / 100.0))
tuft/checkpoints.py ADDED
@@ -0,0 +1,193 @@
1
+ """Module for managing checkpoints on disk."""
2
+
3
+ import contextlib
4
+ import shutil
5
+ from datetime import datetime, timezone
6
+ from pathlib import Path
7
+
8
+ from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_serializer
9
+ from tinker import types
10
+
11
+ from .exceptions import CheckpointMetadataReadException
12
+
13
+
14
class CheckpointMetadata(BaseModel):
    """A representation of checkpoint metadata.

    Persisted as ``metadata.json`` inside a checkpoint directory by
    ``CheckpointRecord.save_metadata`` and read back by
    ``CheckpointRecord.metadata``.
    """

    model_id: str  # training run this checkpoint belongs to
    name: str  # checkpoint identifier within the run
    base_model: str  # name of the base model that was fine-tuned
    checkpoint_type: types.CheckpointType
    created_at: str  # ISO-8601 string (written via datetime.isoformat())
    session_id: str
    tinker_path: str  # tinker:// style path to this checkpoint
    owner_name: str
    size_bytes: int = 0
    lora_rank: int | None = None  # presumably None for non-LoRA checkpoints — TODO confirm
    public: bool = False  # visibility flag; toggled by CheckpointRecord.set_visibility
    future_id: int = 0
    seq_id: int | None = None
30
+
31
+
32
class CheckpointRecord(BaseModel):
    """A record representing a checkpoint on disk.

    Wraps the on-disk layout ``<root>/<training_run_id>/<checkpoint_id>/``
    containing ``adapter``, ``optimizer`` and ``metadata.json`` entries.
    """

    # ``path`` is a pathlib.Path, which pydantic treats as arbitrary here.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    checkpoint_id: str
    owner_name: str
    checkpoint_type: types.CheckpointType
    training_run_id: str
    path: Path  # directory that holds this checkpoint's files
    size_bytes: int = 0
    public: bool = False
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    future_id: int = 0
    seq_id: int | None = None

    @field_serializer("path")
    def serialize_path(self, path: Path) -> str:
        """Serialize Path to string for JSON."""
        return str(path)

    @property
    def tinker_checkpoint(self) -> types.Checkpoint:
        """Get a Tinker Checkpoint instance representing this record."""
        return types.Checkpoint(
            checkpoint_id=self.checkpoint_id,
            checkpoint_type=self.checkpoint_type,
            time=self.created_at,
            tinker_path=self.tinker_path,
            size_bytes=self.size_bytes,
            public=self.public,
        )

    @property
    def metadata(self) -> CheckpointMetadata:
        """Get the checkpoint metadata by reading ``metadata.json``.

        Raises:
            CheckpointMetadataReadException: If the metadata file does not
                exist or fails pydantic validation.
        """
        try:
            return CheckpointMetadata.model_validate_json(
                self.metadata_path.read_text(encoding="utf-8")
            )
        except FileNotFoundError as exc:
            raise CheckpointMetadataReadException(checkpoint_id=self.checkpoint_id) from exc
        except ValidationError as exc:
            raise CheckpointMetadataReadException(checkpoint_id=self.checkpoint_id) from exc

    @property
    def tinker_path(self) -> str:
        """Get the tinker style path for this checkpoint.

        Training checkpoints live under ``weights``; all other types map to
        ``sampler_weights``.
        """
        folder = "weights" if self.checkpoint_type == "training" else "sampler_weights"
        return f"tinker://{self.training_run_id}/{folder}/{self.checkpoint_id}"

    @property
    def adapter_path(self) -> Path:
        """Get the path to the adapter weights file."""
        return self.path / "adapter"

    @property
    def optimizer_path(self) -> Path:
        """Get the path to the optimizer state file."""
        return self.path / "optimizer"

    @property
    def metadata_path(self) -> Path:
        """Get the path to the metadata JSON file."""
        return self.path / "metadata.json"

    def set_visibility(self, public: bool) -> None:
        """Set the visibility of the checkpoint, persisting it to metadata.

        Reads the existing metadata first, so this raises
        CheckpointMetadataReadException if metadata.json is missing/invalid.
        """
        self.public = public
        metadata = self.metadata
        metadata.public = public
        # Rewrite metadata.json; save_metadata picks up self.public.
        self.save_metadata(
            base_model=metadata.base_model,
            session_id=metadata.session_id,
            lora_rank=metadata.lora_rank,
        )

    def save_metadata(self, base_model: str, session_id: str, lora_rank: int | None) -> None:
        """Save the checkpoint metadata to disk.

        Raises:
            ValueError: If the assembled metadata fails validation (the
                underlying exception is chained as the cause).
        """
        # check the format of metadata
        try:
            metadata = CheckpointMetadata(
                model_id=self.training_run_id,
                name=self.checkpoint_id,
                base_model=base_model,
                checkpoint_type=self.checkpoint_type,
                created_at=self.created_at.isoformat(),
                session_id=session_id,
                tinker_path=self.tinker_path,
                owner_name=self.owner_name,
                lora_rank=lora_rank,
                public=self.public,
                size_bytes=self.size_bytes,
                future_id=self.future_id,
                seq_id=self.seq_id,
            )
        except Exception as e:
            raise ValueError(f"Invalid checkpoint metadata: {e}") from e
        self.metadata_path.write_text(metadata.model_dump_json(indent=2), encoding="utf-8")

    @classmethod
    def from_tinker_path(cls, path: str, checkpoint_root_dir: Path) -> "CheckpointRecord":
        """Create a CheckpointRecord from a Tinker path.

        Raises:
            CheckpointMetadataReadException: If ``metadata.json`` is missing
                or invalid (raised by the ``metadata`` property; it wraps
                both FileNotFoundError and validation failures).
        """
        parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(path)
        checkpoint_path = (
            checkpoint_root_dir / parsed.training_run_id / parsed.checkpoint_id.split("/", 1)[-1]
        )
        record = cls(
            checkpoint_id=parsed.checkpoint_id.split("/", 1)[-1],
            checkpoint_type=parsed.checkpoint_type,
            training_run_id=parsed.training_run_id,
            path=checkpoint_path,
            owner_name="",  # Will be filled from metadata later
            size_bytes=0,  # Will be filled from metadata later
        )
        # Raises CheckpointMetadataReadException when the file is absent/bad.
        metadata = record.metadata
        record.owner_name = metadata.owner_name
        record.size_bytes = metadata.size_bytes
        record.public = metadata.public
        # created_at was written with isoformat() (see save_metadata).
        record.created_at = datetime.fromisoformat(metadata.created_at)
        record.future_id = metadata.future_id
        record.seq_id = metadata.seq_id
        return record

    def delete(self) -> None:
        """Delete the checkpoint from disk (already-missing dirs are fine)."""
        with contextlib.suppress(FileNotFoundError):
            shutil.rmtree(self.path)

    @classmethod
    def from_training_run(
        cls,
        training_run_id: str,
        checkpoint_name: str,
        owner_name: str,
        checkpoint_type: types.CheckpointType,
        checkpoint_root_dir: Path,
        exist_ok: bool = True,
    ) -> "CheckpointRecord":
        """Create a CheckpointRecord (and its directory) for a training run.

        Raises:
            FileExistsError: If ``exist_ok`` is False and the checkpoint
                directory already exists.
        """
        checkpoint_dir = checkpoint_root_dir / training_run_id / checkpoint_name
        if not exist_ok and checkpoint_dir.exists():
            raise FileExistsError(f"Checkpoint directory already exists: {checkpoint_dir}")
        checkpoint_dir.mkdir(parents=True, exist_ok=exist_ok)
        return cls(
            checkpoint_id=checkpoint_name,
            owner_name=owner_name,
            checkpoint_type=checkpoint_type,
            training_run_id=training_run_id,
            path=checkpoint_dir,
            size_bytes=0,
        )
tuft/cli.py ADDED
@@ -0,0 +1,124 @@
1
+ """Command line utilities for the local TuFT server."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ from pathlib import Path
8
+
9
+ import typer
10
+ import uvicorn
11
+
12
+ from .config import AppConfig, load_yaml_config
13
+ from .server import create_root_app
14
+ from .telemetry import init_telemetry
15
+ from .telemetry.metrics import ResourceMetricsCollector
16
+
17
+
18
# The Typer application all CLI subcommands attach to.
app = typer.Typer(help="TuFT - Tenant-unified Fine-Tuning Server.", no_args_is_help=True)


# Required for Typer to recognize subcommands when using no_args_is_help=True
@app.callback()
def callback() -> None:
    """TuFT - Tenant-unified Fine-Tuning Server."""
25
+
26
+
27
# Default paths based on TUFT_HOME
_TUFT_HOME = Path(os.environ.get("TUFT_HOME", Path.home() / ".tuft"))
_DEFAULT_CONFIG_PATH = _TUFT_HOME / "configs" / "tuft_config.yaml"
_DEFAULT_CHECKPOINT_DIR = _TUFT_HOME / "checkpoints"

# Reusable Typer option declarations shared by the CLI commands; each one can
# also be supplied via the corresponding TUFT_* environment variable.
_HOST_OPTION = typer.Option("127.0.0.1", "--host", help="Interface to bind", envvar="TUFT_HOST")
_PORT_OPTION = typer.Option(10610, "--port", "-p", help="Port to bind", envvar="TUFT_PORT")
_LOG_LEVEL_OPTION = typer.Option(
    "info", "--log-level", help="Uvicorn log level", envvar="TUFT_LOG_LEVEL"
)
_RELOAD_OPTION = typer.Option(False, "--reload", help="Enable auto-reload (development only)")
# None means "fall back to _DEFAULT_CONFIG_PATH" (see _resolve_config_path).
_CONFIG_OPTION = typer.Option(
    None,
    "--config",
    "-c",
    help=f"Path to a TuFT configuration file (YAML). Defaults to {_DEFAULT_CONFIG_PATH}",
    envvar="TUFT_CONFIG",
)
# None means "use the config file's value, else the default" (see _build_config).
_CHECKPOINT_DIR_OPTION = typer.Option(
    None,
    "--checkpoint-dir",
    help=f"Override checkpoint_dir from config file. Defaults to {_DEFAULT_CHECKPOINT_DIR}",
    envvar="TUFT_CHECKPOINT_DIR",
)
52
+
53
def _resolve_config_path(config_path: Path | None) -> Path:
    """Return the explicit config path, or the default one when it exists.

    Raises:
        typer.BadParameter: If no path was given and the default config
            file is absent.
    """
    if config_path is None:
        if not _DEFAULT_CONFIG_PATH.exists():
            raise typer.BadParameter(
                f"Configuration file must be provided via --config or TUFT_CONFIG, "
                f"or create a default config at {_DEFAULT_CONFIG_PATH}"
            )
        return _DEFAULT_CONFIG_PATH
    return config_path
63
+
64
+
65
def _build_config(
    config_path: Path | None,
    checkpoint_dir: Path | None,
) -> AppConfig:
    """Load the YAML config and resolve the effective checkpoint directory.

    Precedence for ``checkpoint_dir``: explicit CLI/env override, then the
    value in the config file, then the built-in TUFT_HOME default.
    """
    config = load_yaml_config(_resolve_config_path(config_path))
    if checkpoint_dir is not None:
        # An explicit override wins over whatever the config file says.
        config.checkpoint_dir = checkpoint_dir.expanduser()
    elif config.checkpoint_dir is None:
        config.checkpoint_dir = _DEFAULT_CHECKPOINT_DIR
    # After the fallback chain the directory must be known.
    assert config.checkpoint_dir is not None, "checkpoint_dir must be set after config resolution"
    config.ensure_directories()
    return config
80
+
81
+
82
def _init_telemetry(config: AppConfig, log_level: str) -> None:
    """Configure logging, and OpenTelemetry when enabled in the config."""
    # Map the uvicorn-style level name onto a logging constant; unknown
    # names fall back to INFO.
    level = getattr(logging, log_level.upper(), logging.INFO)

    if config.telemetry.enabled:
        init_telemetry(config.telemetry)
        # Start sampling resource metrics for the checkpoint volume.
        ResourceMetricsCollector.start(str(config.checkpoint_dir))
    else:
        logging.basicConfig(level=level)
94
+
95
+
96
@app.command()
def launch(
    host: str = _HOST_OPTION,
    port: int = _PORT_OPTION,
    log_level: str = _LOG_LEVEL_OPTION,
    reload: bool = _RELOAD_OPTION,
    config_path: Path | None = _CONFIG_OPTION,
    checkpoint_dir: Path | None = _CHECKPOINT_DIR_OPTION,
) -> None:
    """Launch the TuFT server."""
    app_config = _build_config(config_path, checkpoint_dir)
    # Telemetry must be wired up before uvicorn starts serving requests.
    _init_telemetry(app_config, log_level)
    logging.getLogger("tuft").info("Server starting on %s:%s", host, port)
    root_app = create_root_app(app_config)
    uvicorn.run(root_app, host=host, port=port, log_level=log_level, reload=reload)
117
+
118
+
119
def main() -> None:
    """Console-script entry point: run the Typer application."""
    app(prog_name="tuft")


if __name__ == "__main__":
    main()