tuft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__init__.py +5 -2
- tuft/__main__.py +7 -0
- tuft/auth.py +35 -0
- tuft/backend.py +254 -0
- tuft/backends/__init__.py +10 -0
- tuft/backends/base_backend.py +112 -0
- tuft/backends/hf_training_model.py +404 -0
- tuft/backends/sampling_backend.py +253 -0
- tuft/backends/training_backend.py +327 -0
- tuft/checkpoints.py +193 -0
- tuft/cli.py +124 -0
- tuft/config.py +123 -0
- tuft/exceptions.py +138 -0
- tuft/futures.py +431 -0
- tuft/loss_fn/__init__.py +48 -0
- tuft/loss_fn/cispo.py +40 -0
- tuft/loss_fn/cross_entropy.py +26 -0
- tuft/loss_fn/dro.py +37 -0
- tuft/loss_fn/importance_sampling.py +33 -0
- tuft/loss_fn/ppo.py +43 -0
- tuft/persistence/__init__.py +32 -0
- tuft/persistence/file_redis.py +268 -0
- tuft/persistence/redis_store.py +488 -0
- tuft/sampling_controller.py +368 -0
- tuft/server.py +720 -0
- tuft/state.py +352 -0
- tuft/telemetry/__init__.py +17 -0
- tuft/telemetry/metrics.py +335 -0
- tuft/telemetry/provider.py +198 -0
- tuft/telemetry/tracing.py +43 -0
- tuft/training_controller.py +728 -0
- tuft-0.1.2.dist-info/METADATA +633 -0
- tuft-0.1.2.dist-info/RECORD +36 -0
- {tuft-0.1.0.dist-info → tuft-0.1.2.dist-info}/WHEEL +1 -2
- tuft-0.1.2.dist-info/entry_points.txt +2 -0
- {tuft-0.1.0.dist-info → tuft-0.1.2.dist-info}/licenses/LICENSE +2 -2
- tuft-0.1.0.dist-info/METADATA +0 -77
- tuft-0.1.0.dist-info/RECORD +0 -6
- tuft-0.1.0.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,368 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import hashlib
|
|
5
|
+
import logging
|
|
6
|
+
import time
|
|
7
|
+
import uuid
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Dict, List, Tuple
|
|
11
|
+
|
|
12
|
+
from opentelemetry.trace import StatusCode
|
|
13
|
+
from pydantic import BaseModel, Field
|
|
14
|
+
from tinker import types
|
|
15
|
+
|
|
16
|
+
from .backends import BaseSamplingBackend
|
|
17
|
+
from .checkpoints import CheckpointRecord
|
|
18
|
+
from .config import AppConfig, ModelConfig
|
|
19
|
+
from .exceptions import (
|
|
20
|
+
CheckpointAccessDeniedException,
|
|
21
|
+
CheckpointNotFoundException,
|
|
22
|
+
MissingSequenceIDException,
|
|
23
|
+
SequenceConflictException,
|
|
24
|
+
SessionNotFoundException,
|
|
25
|
+
UnknownModelException,
|
|
26
|
+
UserMismatchException,
|
|
27
|
+
)
|
|
28
|
+
from .persistence import get_redis_store, is_persistence_enabled, load_record, save_record
|
|
29
|
+
from .telemetry.metrics import get_metrics
|
|
30
|
+
from .telemetry.tracing import get_tracer
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _get_tracer():
    """Return the module tracer, looked up lazily.

    Deferred lookup (instead of a module-level constant) lets telemetry be
    configured after this module is imported.
    """
    return get_tracer("tuft.sampling_controller")


logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _now() -> datetime:
|
|
40
|
+
return datetime.now(timezone.utc)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SamplingHistoryEntry(BaseModel):
    """Entry in the sampling history.

    One entry is appended per accepted sample request.  The prompt itself is
    not stored — only its token count and a short hash, enough to detect
    duplicates or mismatches on replay.
    """

    # Client-supplied sequence number; strictly increasing within a session.
    seq_id: int
    # Number of tokens in the prompt.
    prompt_token_count: int
    # Truncated SHA-1 digest of the prompt's token ids (see _hash_prompt).
    prompt_hash: str
    # UTC timestamp of when the entry was recorded.
    created_at: datetime = Field(default_factory=_now)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class SamplingSessionRecord(BaseModel):
    """Sampling session record with persistence support.

    Sessions are permanent records (no TTL) as they represent active
    sampling sessions that users may access at any time.
    """

    # Server-generated UUID identifying this sampling session; also used as
    # the LoRA adapter id when a checkpoint-backed model is sampled.
    sampling_session_id: str
    # Parent (service-level) session this sampler belongs to.
    session_id: str
    # Model identifier; set to sampling_session_id on creation.
    model_id: str
    # Name of the base model backend that serves this session.
    base_model: str
    # Owner of the session; enforced on every request.
    user_id: str
    # Filesystem path to the LoRA adapter, or None when sampling the bare
    # base model.
    model_path: str | None = None
    # Sequence id of the parent session at creation time.
    session_seq_id: int
    # Highest seq_id accepted so far; -1 means no requests yet.
    last_seq_id: int = -1
    # Per-request audit trail (see SamplingHistoryEntry).
    history: list[SamplingHistoryEntry] = Field(default_factory=list)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class SamplingController:
    """Manages sampling sessions and connects them to the correct training or base-model backend."""

    # Namespace prefix for all Redis keys written by this controller.
    REDIS_KEY_PREFIX = "sampling_session"

    def __init__(self, config: AppConfig) -> None:
        """Create one backend per supported model and restore persisted sessions.

        Args:
            config: Application configuration; ``supported_models`` determines
                which base backends exist.
        """
        self.config = config
        # In-memory session table; Redis (when enabled) mirrors it.
        self.sampling_sessions: Dict[str, SamplingSessionRecord] = {}
        self._base_backends: Dict[str, BaseSamplingBackend] = self._create_backends(
            config.supported_models
        )
        # Synchronous restore only; adapters are re-attached later in async_init().
        self._restore_from_redis()

    def _build_key(self, session_id: str) -> str:
        """Return the Redis key for a sampling session id."""
        return get_redis_store().build_key(self.REDIS_KEY_PREFIX, session_id)

    def _restore_from_redis(self) -> None:
        """Restore sampling sessions from Redis on startup.

        Sessions whose base model is no longer configured are dropped and
        their keys deleted.
        """
        if not is_persistence_enabled():
            return
        store = get_redis_store()
        pattern = store.build_key(self.REDIS_KEY_PREFIX, "*")
        invalid_sessions: list[str] = []
        for key in store.keys(pattern):
            # Match only top-level sessions (3 parts)
            if len(key.split("::")) != 3:
                continue
            record = load_record(key, SamplingSessionRecord)
            if record is None:
                # Unreadable/mismatched payload; skip rather than fail startup.
                continue
            if record.base_model and record.base_model not in self._base_backends:
                # Backend was removed from config; session can no longer be served.
                invalid_sessions.append(record.sampling_session_id)
                continue
            self.sampling_sessions[record.sampling_session_id] = record
        # Delete after the scan so we don't mutate while iterating keys().
        for session_id in invalid_sessions:
            store.delete(self._build_key(session_id))

    def _save_session(self, session_id: str) -> None:
        """Save session to Redis (no TTL - permanent record)."""
        if not is_persistence_enabled():
            return
        record = self.sampling_sessions.get(session_id)
        if record is not None:
            save_record(self._build_key(session_id), record)

    def _delete_session(self, session_id: str) -> None:
        """Remove the session's Redis key (no-op when persistence is off)."""
        if not is_persistence_enabled():
            return
        get_redis_store().delete(self._build_key(session_id))

    async def async_init(self) -> None:
        """Perform any async initialization here, including adapter reloading."""
        # Initialize all backends concurrently.
        init_tasks = [backend.async_init() for backend in self._base_backends.values()]
        await asyncio.gather(*init_tasks)

        # Re-add adapters for restored sessions
        await self._rebuild_sampling_backends()

    async def _rebuild_sampling_backends(self) -> None:
        """Rebuild sampling backends for restored sessions.

        Re-attaches each restored session's LoRA adapter to its base backend.
        Sessions whose adapter path is missing, whose backend is gone, or
        whose add_adapter call fails are evicted from memory and Redis.
        """
        invalid_sessions: list[str] = []
        for session_id, record in list(self.sampling_sessions.items()):
            # Only checkpoint-backed sessions carry an adapter; bare
            # base-model sessions need no rebuild.
            if record.model_path and record.base_model:
                adapter_path = Path(record.model_path)
                if not adapter_path.exists():
                    invalid_sessions.append(session_id)
                    continue
                if record.base_model not in self._base_backends:
                    invalid_sessions.append(session_id)
                    continue
                try:
                    backend = self._base_backends[record.base_model]
                    await backend.add_adapter(
                        lora_id=record.sampling_session_id, adapter_path=adapter_path
                    )
                except Exception:
                    # Best-effort restore: log and evict rather than abort startup.
                    logger.exception(
                        "Failed to rebuild adapter for sampling session %s", session_id
                    )
                    invalid_sessions.append(session_id)
        for session_id in invalid_sessions:
            del self.sampling_sessions[session_id]
            self._delete_session(session_id)

    def _create_backends(self, model_configs: List[ModelConfig]) -> Dict[str, BaseSamplingBackend]:
        """Instantiate one sampling backend per configured model, keyed by model name."""
        backends: Dict[str, BaseSamplingBackend] = {}
        for config in model_configs:
            backends[config.model_name] = BaseSamplingBackend.create_backend(config)
        return backends

    async def create_sampling_session(
        self,
        *,
        session_id: str,
        user_id: str,
        base_model: str | None,
        model_path: str | None,
        session_seq_id: int,
    ) -> str:
        """Create a sampling session for either a checkpoint or a bare base model.

        Args:
            session_id: Parent session the new sampler belongs to.
            user_id: Owner; checked against the checkpoint owner for private
                checkpoints.
            base_model: Base model name; ignored when ``model_path`` is given.
            model_path: Tinker checkpoint path; takes priority over
                ``base_model``.
            session_seq_id: Sequence id of the parent session at creation time.

        Returns:
            The new sampling session id (a UUID string).

        Raises:
            CheckpointNotFoundException: checkpoint path cannot be resolved or
                does not exist on disk.
            CheckpointAccessDeniedException: private checkpoint owned by
                another user.
            UnknownModelException: the resolved base model has no backend, or
                neither ``model_path`` nor ``base_model`` was provided.
        """
        base_model_ref: str | None = None
        adapter_path: Path | None = None
        sampling_session_id = str(uuid.uuid4())

        with _get_tracer().start_as_current_span(
            "sampling_controller.create_sampling_session"
        ) as span:
            span.set_attribute("tuft.session_id", session_id)
            span.set_attribute("tuft.sampling_session_id", sampling_session_id)
            if base_model:
                span.set_attribute("tuft.base_model", base_model)
            try:
                if model_path:
                    # model_path should have higher priority than base_model
                    try:
                        assert self.config.checkpoint_dir is not None
                        parsed_checkpoint = CheckpointRecord.from_tinker_path(
                            model_path,
                            self.config.checkpoint_dir,
                        )
                    except FileNotFoundError as exc:
                        raise CheckpointNotFoundException(checkpoint_id=model_path) from exc
                    if not parsed_checkpoint.path.exists():
                        raise CheckpointNotFoundException(
                            checkpoint_id=parsed_checkpoint.checkpoint_id,
                        )
                    # The checkpoint's metadata decides which backend serves it.
                    metadata = parsed_checkpoint.metadata
                    base_model_ref = metadata.base_model
                    is_public = parsed_checkpoint.public
                    model_owner = parsed_checkpoint.owner_name
                    if not is_public and model_owner != user_id:
                        raise CheckpointAccessDeniedException(
                            checkpoint_id=parsed_checkpoint.checkpoint_id,
                        )
                    if base_model_ref not in self._base_backends:
                        raise UnknownModelException(model_name=base_model_ref)
                    adapter_path = parsed_checkpoint.adapter_path
                    sampling_backend = self._base_backends[base_model_ref]
                    # The session id doubles as the LoRA id on the backend.
                    await sampling_backend.add_adapter(
                        lora_id=sampling_session_id, adapter_path=adapter_path
                    )
                    # TODO: remove adapter when session is deleted
                elif base_model:
                    base_model_ref = base_model
                    if base_model_ref not in self._base_backends:
                        raise UnknownModelException(model_name=base_model_ref)
                else:
                    # Neither a checkpoint nor a base model was requested.
                    raise UnknownModelException(model_name="None")
                self.sampling_sessions[sampling_session_id] = SamplingSessionRecord(
                    sampling_session_id=sampling_session_id,
                    session_id=session_id,
                    user_id=user_id,
                    model_id=sampling_session_id,
                    base_model=base_model_ref,
                    model_path=str(adapter_path) if adapter_path else None,
                    session_seq_id=session_seq_id,
                )
                self._save_session(sampling_session_id)

                # Update metrics
                get_metrics().sampling_sessions_active.add(1, {"base_model": base_model_ref or ""})
                logger.info("Sampling session created: %s", sampling_session_id)
                return sampling_session_id
            except Exception as e:
                # Mark the span failed, then let the caller see the original error.
                span.record_exception(e)
                span.set_status(StatusCode.ERROR)
                raise

    def _hash_prompt(self, prompt: types.ModelInput) -> str:
        """Return a short, stable digest of the prompt's token ids.

        SHA-1 is used as a fingerprint for history/audit only, not for
        security, and is truncated to 16 hex chars.
        """
        tokens = ",".join(str(token) for token in prompt.to_ints())
        return hashlib.sha1(tokens.encode("utf-8")).hexdigest()[:16]

    def _record_sequence(
        self, record: SamplingSessionRecord, seq_id: int, prompt: types.ModelInput
    ) -> None:
        """Validate and record a request's sequence number, then persist.

        Raises:
            SequenceConflictException: ``seq_id`` is not greater than the last
                accepted one (gaps forward are allowed).
        """
        if seq_id <= record.last_seq_id:
            raise SequenceConflictException(expected=record.last_seq_id + 1, got=seq_id)
        record.last_seq_id = seq_id
        entry = SamplingHistoryEntry(
            seq_id=seq_id,
            prompt_token_count=len(prompt.to_ints()),
            prompt_hash=self._hash_prompt(prompt),
        )
        record.history.append(entry)
        # Persist the advanced seq_id so restarts don't re-accept old requests.
        self._save_session(record.sampling_session_id)

    def _resolve_backend(
        self, request: types.SampleRequest, user_id: str
    ) -> Tuple[BaseSamplingBackend, str | None]:
        """Resolve the appropriate backend for the sampling request.

        Also enforces ownership and sequence ordering, recording the request
        in the session history as a side effect.

        Args:
            request: The sampling request.
            user_id: Caller identity; must match the session owner.

        Returns:
            A tuple of the resolved backend and the LoRA ID if applicable.

        Raises:
            SessionNotFoundException: no/unknown sampling_session_id.
            UserMismatchException: session belongs to another user.
            MissingSequenceIDException: request carries no seq_id.
            UnknownModelException: session's base model has no backend.
        """
        if request.sampling_session_id:
            record = self.sampling_sessions.get(request.sampling_session_id)
            if record is None:
                raise SessionNotFoundException(session_id=request.sampling_session_id)
            if record.user_id != user_id:
                raise UserMismatchException()
            if request.seq_id is None:
                raise MissingSequenceIDException()
            self._record_sequence(record, request.seq_id, request.prompt)
            if record.base_model not in self._base_backends:
                raise UnknownModelException(model_name=record.base_model)
            # Only adapter-backed sessions pass a LoRA id to the backend.
            if record.model_path is None:
                lora_id = None
            else:
                lora_id = record.sampling_session_id
            return self._base_backends[record.base_model], lora_id
        raise SessionNotFoundException(session_id="None")

    async def run_sample(
        self,
        request: types.SampleRequest,
        user_id: str,
    ) -> types.SampleResponse:
        """Execute a sampling request against its session's backend.

        Validates the session via ``_resolve_backend``, forwards the prompt
        and sampling parameters, and records request/duration/token metrics.

        Args:
            request: The sampling request (prompt, params, session id, ...).
            user_id: Caller identity; must own the session.

        Returns:
            The backend's SampleResponse.
        """
        with _get_tracer().start_as_current_span("sampling_controller.run_sample") as span:
            sampling_session_id = request.sampling_session_id or ""
            span.set_attribute("tuft.sampling_session_id", sampling_session_id)
            # Get session_id from sampling session record if available
            if request.sampling_session_id:
                record = self.sampling_sessions.get(request.sampling_session_id)
                if record:
                    span.set_attribute("tuft.session_id", record.session_id)
            span.set_attribute("tuft.num_samples", request.num_samples)

            logger.info("Sampling begin for %s", sampling_session_id)
            start_time = time.perf_counter()

            backend, lora_id = self._resolve_backend(request, user_id=user_id)
            prompt = request.prompt
            sampling_params = request.sampling_params
            num_samples = request.num_samples
            include_prompt_logprobs = bool(request.prompt_logprobs)
            topk_prompt_logprobs = request.topk_prompt_logprobs or 0

            response = await backend.sample(
                prompt=prompt,
                num_samples=num_samples,
                sampling_params=sampling_params,
                include_prompt_logprobs=include_prompt_logprobs,
                topk_prompt_logprobs=topk_prompt_logprobs,
                lora_id=lora_id,
            )

            duration = time.perf_counter() - start_time
            logger.info("Sampling completed for %s", sampling_session_id)

            # Get base_model for metrics
            record = self.sampling_sessions.get(request.sampling_session_id or "")
            base_model = record.base_model if record else ""

            # Update metrics
            metrics = get_metrics()
            metrics.sampling_requests.add(1, {"base_model": base_model})
            metrics.sampling_duration.record(duration, {"base_model": base_model})

            # Record output tokens for each sequence
            total_output_tokens = 0
            for seq in response.sequences:
                if seq.tokens:
                    metrics.sampling_output_tokens.record(len(seq.tokens))
                    total_output_tokens += len(seq.tokens)

            # Record tokens per second if we have output tokens and positive duration
            if total_output_tokens > 0 and duration > 0:
                tokens_per_second = total_output_tokens / duration
                metrics.sampling_tokens_per_second.record(
                    tokens_per_second, {"base_model": base_model}
                )

            return response

    async def evict_model(self, model_id: str, user_id: str) -> None:
        """Remove all of a user's sampling sessions for the given model id.

        Deletes matching sessions from memory and Redis and decrements the
        active-sessions gauge for each one removed.
        """
        # Snapshot items() because we delete from the dict while iterating.
        for sampling_id, record in list(self.sampling_sessions.items()):
            if record.model_id == model_id and record.user_id == user_id:
                base_model = record.base_model
                del self.sampling_sessions[sampling_id]
                self._delete_session(sampling_id)
                # Update metrics
                get_metrics().sampling_sessions_active.add(-1, {"base_model": base_model or ""})

    def get_sampler_info(
        self, sampler_id: str, user_id: str, default_base_model: str
    ) -> types.GetSamplerResponse:
        """Return public info about a sampling session the caller owns.

        Args:
            sampler_id: Sampling session id to look up.
            user_id: Caller identity; must match the session owner.
            default_base_model: Fallback base-model name if the record's is
                falsy.

        Raises:
            SessionNotFoundException: unknown sampler id.
            UserMismatchException: session belongs to another user.
        """
        record = self.sampling_sessions.get(sampler_id)
        if record is None:
            raise SessionNotFoundException(session_id=sampler_id)
        if record.user_id != user_id:
            raise UserMismatchException()
        base = record.base_model
        return types.GetSamplerResponse(
            sampler_id=sampler_id,
            base_model=base or default_base_model,
            model_path=record.model_path,
        )
|