tuft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tuft/exceptions.py CHANGED
@@ -1,5 +1,7 @@
1
1
  """Some custom exceptions."""
2
2
 
3
+ from typing import Any
4
+
3
5
 
4
6
  class TuFTException(Exception):
5
7
  """Base exception for TuFT errors."""
@@ -79,6 +81,15 @@ class SequenceConflictException(FutureException):
79
81
  self.got = got
80
82
 
81
83
 
84
+ class SequenceTimeoutException(FutureException):
85
+ """Timeout waiting for the expected sequence ID."""
86
+
87
+ def __init__(self, expected_sequence_id: int):
88
+ detail = f"Timeout when waiting for sequence ID {expected_sequence_id}."
89
+ super().__init__(detail)
90
+ self.sequence_id = expected_sequence_id
91
+
92
+
82
93
  class MissingSequenceIDException(FutureException):
83
94
  """Missing sequence ID in the request."""
84
95
 
@@ -136,3 +147,58 @@ class LossFunctionInputShapeMismatchException(LossFunctionException):
136
147
  detail = f"Input tensors must have the same shape. Got shapes: {shapes}"
137
148
  super().__init__(detail)
138
149
  self.shapes = shapes
150
+
151
+
152
+ class LossFunctionUnknownMetricReductionException(LossFunctionException):
153
+ def __init__(self, reduction_type: str):
154
+ detail = f"Unknown metric reduction type: {reduction_type}"
155
+ super().__init__(detail)
156
+ self.reduction_type = reduction_type
157
+
158
+
159
+ class PersistenceException(TuFTException):
160
+ """Base exception for Persistence related errors."""
161
+
162
+
163
+ class ConfigMismatchError(PersistenceException):
164
+ """Raised when current config doesn't match the stored config in Redis.
165
+
166
+ This error occurs during server startup when persistence is enabled and
167
+ the configuration has changed since the last run. This can cause data
168
+ corruption when restoring persisted state.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ diff: dict[str, dict[str, Any]],
174
+ ):
175
+ self.diff = diff
176
+
177
+ # Build detailed diff message
178
+ diff_parts = []
179
+ for field_name, field_diff in diff.items():
180
+ # Handle scalar fields (current/stored)
181
+ current = field_diff.get("current")
182
+ stored = field_diff.get("stored")
183
+
184
+ parts = []
185
+ if current is not None or stored is not None:
186
+ parts.append(f"current: {current}, stored: {stored}")
187
+
188
+ if parts:
189
+ diff_parts.append(f"{field_name} ({', '.join(parts)})")
190
+
191
+ diff_str = "; ".join(diff_parts) if diff_parts else "unknown difference"
192
+
193
+ message = (
194
+ f"Configuration mismatch detected: {diff_str}.\n"
195
+ "The current configuration does not match the stored configuration in Redis.\n"
196
+ "This can cause data corruption when restoring persisted state.\n\n"
197
+ "Options:\n"
198
+ " 1. Use a different Redis database (change redis_url in config)\n"
199
+ " 2. Run `tuft clear persistence -c <config_path>` to clear existing data\n"
200
+ " Use `--force` or `-f` to skip confirmation prompt.\n"
201
+ " (WARNING: This will delete all persisted sessions, training runs, etc.)\n"
202
+ " 3. Restore the original configuration that matches the stored data"
203
+ )
204
+ super().__init__(message)
tuft/futures.py CHANGED
@@ -189,6 +189,24 @@ class FutureStore:
189
189
  count += 1
190
190
  return count
191
191
 
192
+ def mark_pending_sample_futures_failed(
193
+ self,
194
+ error_message: str = "Server restarted while sample request was pending. Please retry.",
195
+ ) -> int:
196
+ """Mark all pending sample futures as failed."""
197
+ count = 0
198
+ for record in self._records.values():
199
+ if record.status == "pending" and record.operation_type == "sample":
200
+ record.status = "failed"
201
+ record.error = types.RequestFailedResponse(
202
+ error=error_message,
203
+ category=types.RequestErrorCategory.Server,
204
+ )
205
+ record.event.set()
206
+ self._save_future(record.request_id)
207
+ count += 1
208
+ return count
209
+
192
210
  def _store_record(self, record: FutureRecord) -> None:
193
211
  self._records[record.request_id] = record
194
212
  self._save_future(record.request_id)
@@ -327,8 +345,9 @@ class FutureStore:
327
345
  record.payload = payload
328
346
  record.status = "ready"
329
347
  record.error = None
348
+ loop = asyncio.get_event_loop()
349
+ await loop.run_in_executor(None, self._save_future, request_id)
330
350
  record.event.set()
331
- self._save_future(request_id)
332
351
 
333
352
  # Update metrics
334
353
  get_metrics().futures_completed.add(
@@ -352,8 +371,9 @@ class FutureStore:
352
371
  return
353
372
  record.status = "failed"
354
373
  record.error = failure
374
+ loop = asyncio.get_event_loop()
375
+ await loop.run_in_executor(None, self._save_future, request_id)
355
376
  record.event.set()
356
- self._save_future(request_id)
357
377
 
358
378
  # Update metrics
359
379
  get_metrics().futures_completed.add(
tuft/loss_fn/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from typing import Callable, Dict, Tuple
2
2
 
3
+ from tinker.lib.chunked_fwdbwd_helpers import REDUCE_MAP
3
4
  from torch import Tensor
4
5
  from typing_extensions import TypeAlias
5
6
 
@@ -7,6 +8,7 @@ from ..exceptions import (
7
8
  LossFunctionInputShapeMismatchException,
8
9
  LossFunctionMissingInputException,
9
10
  LossFunctionNotFoundException,
11
+ LossFunctionUnknownMetricReductionException,
10
12
  )
11
13
 
12
14
 
@@ -46,3 +48,34 @@ def _check_loss_fn_inputs(
46
48
  shapes = [loss_fn_inputs[key].shape for key in required_keys]
47
49
  if not all(shape == shapes[0] for shape in shapes):
48
50
  raise LossFunctionInputShapeMismatchException(shapes)
51
+
52
+
53
+ def metrics_reduction(
54
+ metric_list: list[dict[str, float]],
55
+ weights: list[float],
56
+ ) -> dict[str, float]:
57
+ """Aggregate metrics from multiple batches.
58
+
59
+ Modified from tinker.lib.chunked_fwdbwd_helpers._metrics_reduction
60
+ """
61
+ if not metric_list:
62
+ return {}
63
+ keys = metric_list[0].keys()
64
+ result = {}
65
+ for key in keys:
66
+ _, reduction = key.split(":")
67
+ if reduction not in REDUCE_MAP:
68
+ raise LossFunctionUnknownMetricReductionException(reduction)
69
+ if not all(key in m for m in metric_list):
70
+ continue
71
+ reduce_fn = REDUCE_MAP[reduction]
72
+ values = [m[key] for m in metric_list]
73
+
74
+ if reduction in ["mean", "slack"]:
75
+ result[key] = reduce_fn(values, weights)
76
+ elif reduction in ["unique"]:
77
+ result[key] = values[0]
78
+ result.update({f"{key}_{i + 1}": v for i, v in enumerate(values[1:])})
79
+ else:
80
+ result[key] = reduce_fn(values)
81
+ return result
tuft/persistence/__init__.py CHANGED (NOTE: the filename header is missing from this diff section; the name above is inferred from the `.redis_store` relative imports in the hunk below — verify against the package contents)
@@ -3,30 +3,38 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from .redis_store import (
6
- DEFAULT_FUTURE_TTL_SECONDS,
6
+ ConfigCheckField,
7
7
  PersistenceConfig,
8
8
  PersistenceMode,
9
9
  RedisPipeline,
10
10
  RedisStore,
11
11
  delete_record,
12
+ flush_all_data,
13
+ get_current_namespace,
12
14
  get_redis_store,
13
15
  is_persistence_enabled,
14
16
  load_record,
17
+ save_config_signature,
15
18
  save_record,
16
19
  save_records_atomic,
20
+ validate_config_signature,
17
21
  )
18
22
 
19
23
 
20
24
  __all__ = [
21
- "DEFAULT_FUTURE_TTL_SECONDS",
25
+ "ConfigCheckField",
22
26
  "PersistenceConfig",
23
27
  "PersistenceMode",
24
28
  "RedisPipeline",
25
29
  "RedisStore",
26
30
  "delete_record",
31
+ "flush_all_data",
32
+ "get_current_namespace",
27
33
  "get_redis_store",
28
34
  "is_persistence_enabled",
29
35
  "load_record",
36
+ "save_config_signature",
30
37
  "save_record",
31
38
  "save_records_atomic",
39
+ "validate_config_signature",
32
40
  ]