tuft 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tuft/config.py CHANGED
@@ -2,23 +2,20 @@
 
 from __future__ import annotations
 
-from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Dict, Iterable, List
-
-from .persistence import PersistenceConfig
+from typing import Any, Iterable
 
+from pydantic import BaseModel, Field, model_validator
 
-def _default_checkpoint_dir() -> Path:
-    return Path.home() / ".cache" / "tuft" / "checkpoints"
+from .persistence import PersistenceConfig
 
 
-def _default_persistence_config() -> PersistenceConfig:
-    return PersistenceConfig()
+def _default_checkpoint_dir() -> Path | None:
+    """Return None to let CLI set the default based on TUFT_HOME."""
+    return None
 
 
-@dataclass
-class TelemetryConfig:
+class TelemetryConfig(BaseModel):
     """Configuration for OpenTelemetry integration.
 
     Attributes:
@@ -31,29 +28,65 @@ class TelemetryConfig:
     enabled: bool = False
     service_name: str = "tuft"
     otlp_endpoint: str | None = None
-    resource_attributes: Dict[str, str] = field(default_factory=dict)
+    resource_attributes: dict[str, str] = Field(default_factory=dict)
+
+
+class ModelConfig(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
+
+    model_name: str  # name used in APIs
+    model_path: Path  # path to model checkpoint
+    max_model_len: int  # maximum context length supported by the model
+    tensor_parallel_size: int = 1  # tensor parallel size
+
+    # default sampling parameters for this model
+    temperature: float = 1.0
+    top_p: float = 1.0
+    top_k: int = -1
+    logprobs: int = 0
+    seed: int = 42
+    min_response_tokens: int = 0
+
+    # default lora setting
+    max_lora_rank: int = 16  # maximum rank for LoRA adapters
+    max_loras: int = 1  # maximum number of LoRA adapters that can be applied simultaneously
 
+    # default training setting
+    micro_batch_size: int = 1  # micro-batch size for training
 
-def _default_telemetry_config() -> TelemetryConfig:
-    return TelemetryConfig()
+    # whether to colocate sampling and training on the same device
+    # only for local testing purposes
+    colocate: bool = False
+    sampling_memory_fraction: float = 0.2  # fraction of GPU memory for sampling
+
+    @model_validator(mode="after")
+    def validate_colocate(self) -> "ModelConfig":
+        if self.colocate and self.tensor_parallel_size != 1:
+            raise ValueError("Colocate option is only supported for tensor_parallel_size=1.")
+        return self
 
 
-@dataclass
-class AppConfig:
-    """Runtime configuration for the TuFT server."""
+class AppConfig(BaseModel):
+    """Runtime configuration for the TuFT server.
 
-    checkpoint_dir: Path = field(default_factory=_default_checkpoint_dir)
-    supported_models: List[ModelConfig] = field(default_factory=list)
+    This is a Pydantic model that can be serialized/deserialized for persistence.
+    """
+
+    model_config = {"arbitrary_types_allowed": True}
+
+    checkpoint_dir: Path | None = Field(default_factory=_default_checkpoint_dir)
+    supported_models: list[ModelConfig] = Field(default_factory=list)
     model_owner: str = "local-user"
     toy_backend_seed: int = 0
     # TODO: Temporary implementation for user authorization,
     # replace with proper auth system later
-    authorized_users: Dict[str, str] = field(default_factory=dict)
-    persistence: PersistenceConfig = field(default_factory=_default_persistence_config)
-    telemetry: TelemetryConfig = field(default_factory=_default_telemetry_config)
+    authorized_users: dict[str, str] = Field(default_factory=dict)
+    persistence: PersistenceConfig = Field(default_factory=PersistenceConfig)
+    telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig)
 
     def ensure_directories(self) -> None:
-        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        if self.checkpoint_dir is not None:
+            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
 
     def check_validity(self) -> None:
         if not self.supported_models:
@@ -72,50 +105,21 @@ class AppConfig:
         self.supported_models = updated
         return self
 
-
-@dataclass
-class ModelConfig:
-    """Configuration for a specific model."""
-
-    model_name: str  # name used in APIs
-    model_path: Path  # path to model checkpoint
-    max_model_len: int  # maximum context length supported by the model
-    tensor_parallel_size: int = 1  # tensor parallel size
-
-    # default sampling parameters for this model
-    temperature: float = 1.0
-    top_p: float = 1.0
-    top_k: int = -1
-    logprobs: int = 0
-    seed: int = 42
-    min_response_tokens: int = 0
-
-    # default lora setting
-    max_lora_rank: int = 16  # maximum rank for LoRA adapters
-    max_loras: int = 1  # maximum number of LoRA adapters that can be applied simultaneously
-
-    # whether to colocate sampling and training on the same device
-    # only for local testing purposes
-    colocate: bool = False
-    sampling_memory_fraction: float = 0.2  # fraction of GPU memory for sampling
-
-    def __post_init__(self) -> None:
-        if self.colocate and self.tensor_parallel_size != 1:
-            raise ValueError("Colocate option is only supported for tensor_parallel_size=1.")
+    def get_config_for_persistence(self) -> dict[str, Any]:
+        """Get config fields for persistence signature (excludes persistence config itself)."""
+        return self.model_dump(mode="json", exclude={"persistence"})
 
 
 def load_yaml_config(config_path: Path) -> AppConfig:
     """Loads an AppConfig from a YAML file."""
     from omegaconf import OmegaConf
 
-    schema = OmegaConf.structured(AppConfig)
     loaded = OmegaConf.load(config_path)
     try:
-        config = OmegaConf.merge(schema, loaded)
-        app_config = OmegaConf.to_object(config)
-        assert isinstance(app_config, AppConfig), (
-            "Loaded config is not of type AppConfig, which should not happen."
-        )
-        return app_config
+        # Convert OmegaConf to plain dict for Pydantic
+        config_dict = OmegaConf.to_container(loaded, resolve=True)
+        if not isinstance(config_dict, dict):
+            raise ValueError("Config file must contain a dictionary at root level")
+        return AppConfig.model_validate(config_dict)
     except Exception as e:
        raise ValueError(f"Failed to load config from {config_path}: {e}") from e
tuft/exceptions.py CHANGED
@@ -1,5 +1,7 @@
 """Some custom exceptions."""
 
+from typing import Any
+
 
 class TuFTException(Exception):
     """Base exception for TuFT errors."""
@@ -79,6 +81,15 @@ class SequenceConflictException(FutureException):
         self.got = got
 
 
+class SequenceTimeoutException(FutureException):
+    """Timeout waiting for the expected sequence ID."""
+
+    def __init__(self, expected_sequence_id: int):
+        detail = f"Timeout when waiting for sequence ID {expected_sequence_id}."
+        super().__init__(detail)
+        self.sequence_id = expected_sequence_id
+
+
 class MissingSequenceIDException(FutureException):
     """Missing sequence ID in the request."""
 
@@ -136,3 +147,58 @@ class LossFunctionInputShapeMismatchException(LossFunctionException):
         detail = f"Input tensors must have the same shape. Got shapes: {shapes}"
         super().__init__(detail)
         self.shapes = shapes
+
+
+class LossFunctionUnknownMetricReductionException(LossFunctionException):
+    def __init__(self, reduction_type: str):
+        detail = f"Unknown metric reduction type: {reduction_type}"
+        super().__init__(detail)
+        self.reduction_type = reduction_type
+
+
+class PersistenceException(TuFTException):
+    """Base exception for Persistence related errors."""
+
+
+class ConfigMismatchError(PersistenceException):
+    """Raised when current config doesn't match the stored config in Redis.
+
+    This error occurs during server startup when persistence is enabled and
+    the configuration has changed since the last run. This can cause data
+    corruption when restoring persisted state.
+    """
+
+    def __init__(
+        self,
+        diff: dict[str, dict[str, Any]],
+    ):
+        self.diff = diff
+
+        # Build detailed diff message
+        diff_parts = []
+        for field_name, field_diff in diff.items():
+            # Handle scalar fields (current/stored)
+            current = field_diff.get("current")
+            stored = field_diff.get("stored")
+
+            parts = []
+            if current is not None or stored is not None:
+                parts.append(f"current: {current}, stored: {stored}")
+
+            if parts:
+                diff_parts.append(f"{field_name} ({', '.join(parts)})")
+
+        diff_str = "; ".join(diff_parts) if diff_parts else "unknown difference"
+
+        message = (
+            f"Configuration mismatch detected: {diff_str}.\n"
+            "The current configuration does not match the stored configuration in Redis.\n"
+            "This can cause data corruption when restoring persisted state.\n\n"
+            "Options:\n"
+            "  1. Use a different Redis database (change redis_url in config)\n"
+            "  2. Run `tuft clear persistence -c <config_path>` to clear existing data\n"
+            "     Use `--force` or `-f` to skip confirmation prompt.\n"
+            "     (WARNING: This will delete all persisted sessions, training runs, etc.)\n"
+            "  3. Restore the original configuration that matches the stored data"
+        )
+        super().__init__(message)
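
Note: the diff mapping that ConfigMismatchError expects can be read off its constructor: one entry per changed field, each carrying the current and stored values. A quick illustration with hypothetical field values:

    from tuft.exceptions import ConfigMismatchError

    err = ConfigMismatchError(
        diff={
            "model_owner": {"current": "alice", "stored": "local-user"},
            "toy_backend_seed": {"current": 1, "stored": 0},
        }
    )
    print(err)
    # "Configuration mismatch detected: model_owner (current: alice, stored: local-user);
    #  toy_backend_seed (current: 1, stored: 0). ..." followed by the recovery options.
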
tuft/futures.py CHANGED
@@ -189,6 +189,24 @@ class FutureStore:
                 count += 1
         return count
 
+    def mark_pending_sample_futures_failed(
+        self,
+        error_message: str = "Server restarted while sample request was pending. Please retry.",
+    ) -> int:
+        """Mark all pending sample futures as failed."""
+        count = 0
+        for record in self._records.values():
+            if record.status == "pending" and record.operation_type == "sample":
+                record.status = "failed"
+                record.error = types.RequestFailedResponse(
+                    error=error_message,
+                    category=types.RequestErrorCategory.Server,
+                )
+                record.event.set()
+                self._save_future(record.request_id)
+                count += 1
+        return count
+
     def _store_record(self, record: FutureRecord) -> None:
         self._records[record.request_id] = record
         self._save_future(record.request_id)
@@ -327,8 +345,9 @@ class FutureStore:
         record.payload = payload
         record.status = "ready"
         record.error = None
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, self._save_future, request_id)
         record.event.set()
-        self._save_future(request_id)
 
         # Update metrics
         get_metrics().futures_completed.add(
@@ -352,8 +371,9 @@ class FutureStore:
             return
         record.status = "failed"
         record.error = failure
+        loop = asyncio.get_event_loop()
+        await loop.run_in_executor(None, self._save_future, request_id)
         record.event.set()
-        self._save_future(request_id)
 
         # Update metrics
         get_metrics().futures_completed.add(
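
Note: both resolve paths change in the same way: the blocking _save_future call now runs in a thread-pool executor and completes before event.set(), so the event loop is not stalled by the write and a woken waiter never observes a resolved-but-unpersisted future. The new mark_pending_sample_futures_failed is the restart-recovery counterpart: it fails sample futures left pending by a crash so clients can retry. A self-contained sketch of the persist-then-wake pattern (not the tuft code itself):

    import asyncio
    import time

    def blocking_save(request_id: str) -> None:
        """Stand-in for a synchronous Redis write."""
        time.sleep(0.1)

    async def resolve(event: asyncio.Event, request_id: str) -> None:
        loop = asyncio.get_running_loop()
        # Persist in the default thread pool so the event loop stays responsive...
        await loop.run_in_executor(None, blocking_save, request_id)
        # ...and only then wake waiters, so they always see persisted state.
        event.set()

    async def main() -> None:
        event = asyncio.Event()
        await asyncio.gather(resolve(event, "req-1"), event.wait())

    asyncio.run(main())
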
tuft/loss_fn/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from typing import Callable, Dict, Tuple
 
+from tinker.lib.chunked_fwdbwd_helpers import REDUCE_MAP
 from torch import Tensor
 from typing_extensions import TypeAlias
 
@@ -7,6 +8,7 @@ from ..exceptions import (
     LossFunctionInputShapeMismatchException,
     LossFunctionMissingInputException,
     LossFunctionNotFoundException,
+    LossFunctionUnknownMetricReductionException,
 )
 
 
@@ -46,3 +48,34 @@ def _check_loss_fn_inputs(
     shapes = [loss_fn_inputs[key].shape for key in required_keys]
     if not all(shape == shapes[0] for shape in shapes):
         raise LossFunctionInputShapeMismatchException(shapes)
+
+
+def metrics_reduction(
+    metric_list: list[dict[str, float]],
+    weights: list[float],
+) -> dict[str, float]:
+    """Aggregate metrics from multiple batches.
+
+    Modified from tinker.lib.chunked_fwdbwd_helpers._metrics_reduction
+    """
+    if not metric_list:
+        return {}
+    keys = metric_list[0].keys()
+    result = {}
+    for key in keys:
+        _, reduction = key.split(":")
+        if reduction not in REDUCE_MAP:
+            raise LossFunctionUnknownMetricReductionException(reduction)
+        if not all(key in m for m in metric_list):
+            continue
+        reduce_fn = REDUCE_MAP[reduction]
+        values = [m[key] for m in metric_list]
+
+        if reduction in ["mean", "slack"]:
+            result[key] = reduce_fn(values, weights)
+        elif reduction in ["unique"]:
+            result[key] = values[0]
+            result.update({f"{key}_{i + 1}": v for i, v in enumerate(values[1:])})
+        else:
+            result[key] = reduce_fn(values)
+    return result
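
Note: metrics_reduction requires every metric key to carry its reduction as a name:reduction suffix. REDUCE_MAP comes from tinker and is not shown in this diff; assuming its "mean" entry computes a weights-weighted average (an assumption about tinker, not something this diff confirms), the aggregation behaves roughly as follows:

    from tuft.loss_fn import metrics_reduction

    batches = [
        {"loss:mean": 2.0, "step:unique": 7.0},
        {"loss:mean": 4.0, "step:unique": 7.0},
    ]
    weights = [1.0, 3.0]  # e.g. tokens per micro-batch

    print(metrics_reduction(batches, weights))
    # Expected under the weighted-mean assumption:
    #   loss:mean     -> (2.0*1 + 4.0*3) / 4 = 3.5
    #   step:unique   -> 7.0 (first value kept)
    #   step:unique_1 -> 7.0 (remaining values re-keyed per batch)
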
tuft/persistence/__init__.py CHANGED
@@ -3,30 +3,38 @@
 from __future__ import annotations
 
 from .redis_store import (
-    DEFAULT_FUTURE_TTL_SECONDS,
+    ConfigCheckField,
     PersistenceConfig,
     PersistenceMode,
     RedisPipeline,
     RedisStore,
     delete_record,
+    flush_all_data,
+    get_current_namespace,
     get_redis_store,
     is_persistence_enabled,
     load_record,
+    save_config_signature,
     save_record,
     save_records_atomic,
+    validate_config_signature,
 )
 
 
 __all__ = [
-    "DEFAULT_FUTURE_TTL_SECONDS",
+    "ConfigCheckField",
     "PersistenceConfig",
     "PersistenceMode",
     "RedisPipeline",
     "RedisStore",
     "delete_record",
+    "flush_all_data",
+    "get_current_namespace",
     "get_redis_store",
     "is_persistence_enabled",
     "load_record",
+    "save_config_signature",
     "save_record",
     "save_records_atomic",
+    "validate_config_signature",
 ]
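
Note: the swapped exports (ConfigCheckField and the config-signature helpers replacing DEFAULT_FUTURE_TTL_SECONDS) pair with AppConfig.get_config_for_persistence() and ConfigMismatchError to implement the startup config check. The helpers' signatures live in redis_store and are not part of this diff, so the sketch below hand-rolls the comparison they presumably encapsulate; it is an illustration only, not the package's implementation:

    from typing import Any

    from tuft.config import AppConfig
    from tuft.exceptions import ConfigMismatchError

    def check_against_stored(cfg: AppConfig, stored: dict[str, Any]) -> None:
        """Illustrative stand-in for validate_config_signature."""
        current = cfg.get_config_for_persistence()  # JSON dict, persistence excluded
        diff = {
            key: {"current": current.get(key), "stored": stored.get(key)}
            for key in current.keys() | stored.keys()
            if current.get(key) != stored.get(key)
        }
        if diff:
            raise ConfigMismatchError(diff)
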