tuft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,368 @@
+ from __future__ import annotations
+
+ import asyncio
+ import hashlib
+ import logging
+ import time
+ import uuid
+ from datetime import datetime, timezone
+ from pathlib import Path
+ from typing import Dict, List, Tuple
+
+ from opentelemetry.trace import StatusCode
+ from pydantic import BaseModel, Field
+ from tinker import types
+
+ from .backends import BaseSamplingBackend
+ from .checkpoints import CheckpointRecord
+ from .config import AppConfig, ModelConfig
+ from .exceptions import (
+     CheckpointAccessDeniedException,
+     CheckpointNotFoundException,
+     MissingSequenceIDException,
+     SequenceConflictException,
+     SessionNotFoundException,
+     UnknownModelException,
+     UserMismatchException,
+ )
+ from .persistence import get_redis_store, is_persistence_enabled, load_record, save_record
+ from .telemetry.metrics import get_metrics
+ from .telemetry.tracing import get_tracer
+
+
+ _get_tracer = lambda: get_tracer("tuft.sampling_controller")  # noqa: E731
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def _now() -> datetime:
+     return datetime.now(timezone.utc)
+
+
+ class SamplingHistoryEntry(BaseModel):
+     """Entry in the sampling history."""
+
+     seq_id: int
+     prompt_token_count: int
+     prompt_hash: str
+     created_at: datetime = Field(default_factory=_now)
+
+
+ class SamplingSessionRecord(BaseModel):
+     """Sampling session record with persistence support.
+
+     Sessions are permanent records (no TTL) as they represent active
+     sampling sessions that users may access at any time.
+     """
+
+     sampling_session_id: str
+     session_id: str
+     model_id: str
+     base_model: str
+     user_id: str
+     model_path: str | None = None
+     session_seq_id: int
+     last_seq_id: int = -1
+     history: list[SamplingHistoryEntry] = Field(default_factory=list)
+
+
+ class SamplingController:
+     """Manages sampling sessions and connects them to the correct training or base-model backend."""
+
+     REDIS_KEY_PREFIX = "sampling_session"
+
+     def __init__(self, config: AppConfig) -> None:
+         self.config = config
+         self.sampling_sessions: Dict[str, SamplingSessionRecord] = {}
+         self._base_backends: Dict[str, BaseSamplingBackend] = self._create_backends(
+             config.supported_models
+         )
+         self._restore_from_redis()
+
+     def _build_key(self, session_id: str) -> str:
+         return get_redis_store().build_key(self.REDIS_KEY_PREFIX, session_id)
+
+     def _restore_from_redis(self) -> None:
+         """Restore sampling sessions from Redis on startup."""
+         if not is_persistence_enabled():
+             return
+         store = get_redis_store()
+         pattern = store.build_key(self.REDIS_KEY_PREFIX, "*")
+         invalid_sessions = []
+         for key in store.keys(pattern):
+             # Match only top-level sessions (3 parts)
+             if len(key.split("::")) != 3:
+                 continue
+             record = load_record(key, SamplingSessionRecord)
+             if record is None:
+                 continue
+             if record.base_model and record.base_model not in self._base_backends:
+                 invalid_sessions.append(record.sampling_session_id)
+                 continue
+             self.sampling_sessions[record.sampling_session_id] = record
+         for session_id in invalid_sessions:
+             store.delete(self._build_key(session_id))
+
+     def _save_session(self, session_id: str) -> None:
+         """Save session to Redis (no TTL - permanent record)."""
+         if not is_persistence_enabled():
+             return
+         record = self.sampling_sessions.get(session_id)
+         if record is not None:
+             save_record(self._build_key(session_id), record)
+
+     def _delete_session(self, session_id: str) -> None:
+         if not is_persistence_enabled():
+             return
+         get_redis_store().delete(self._build_key(session_id))
+
+     async def async_init(self) -> None:
+         """Perform any async initialization here, including adapter reloading."""
+         init_tasks = [backend.async_init() for backend in self._base_backends.values()]
+         await asyncio.gather(*init_tasks)
+
+         # Re-add adapters for restored sessions
+         await self._rebuild_sampling_backends()
+
+     async def _rebuild_sampling_backends(self) -> None:
+         """Rebuild sampling backends for restored sessions."""
+         invalid_sessions = []
+         for session_id, record in list(self.sampling_sessions.items()):
+             if record.model_path and record.base_model:
+                 adapter_path = Path(record.model_path)
+                 if not adapter_path.exists():
+                     invalid_sessions.append(session_id)
+                     continue
+                 if record.base_model not in self._base_backends:
+                     invalid_sessions.append(session_id)
+                     continue
+                 try:
+                     backend = self._base_backends[record.base_model]
+                     await backend.add_adapter(
+                         lora_id=record.sampling_session_id, adapter_path=adapter_path
+                     )
+                 except Exception:
+                     logger.exception(
+                         "Failed to rebuild adapter for sampling session %s", session_id
+                     )
+                     invalid_sessions.append(session_id)
+         for session_id in invalid_sessions:
+             del self.sampling_sessions[session_id]
+             self._delete_session(session_id)
+
+     def _create_backends(self, model_configs: List[ModelConfig]) -> Dict[str, BaseSamplingBackend]:
+         backends: Dict[str, BaseSamplingBackend] = {}
+         for config in model_configs:
+             backends[config.model_name] = BaseSamplingBackend.create_backend(config)
+         return backends
+
+     async def create_sampling_session(
+         self,
+         *,
+         session_id: str,
+         user_id: str,
+         base_model: str | None,
+         model_path: str | None,
+         session_seq_id: int,
+     ) -> str:
+         base_model_ref: str | None = None
+         adapter_path: Path | None = None
+         sampling_session_id = str(uuid.uuid4())
+
+         with _get_tracer().start_as_current_span(
+             "sampling_controller.create_sampling_session"
+         ) as span:
+             span.set_attribute("tuft.session_id", session_id)
+             span.set_attribute("tuft.sampling_session_id", sampling_session_id)
+             if base_model:
+                 span.set_attribute("tuft.base_model", base_model)
+             try:
+                 if model_path:
+                     # model_path should have higher priority than base_model
+                     try:
+                         assert self.config.checkpoint_dir is not None
+                         parsed_checkpoint = CheckpointRecord.from_tinker_path(
+                             model_path,
+                             self.config.checkpoint_dir,
+                         )
+                     except FileNotFoundError as exc:
+                         raise CheckpointNotFoundException(checkpoint_id=model_path) from exc
+                     if not parsed_checkpoint.path.exists():
+                         raise CheckpointNotFoundException(
+                             checkpoint_id=parsed_checkpoint.checkpoint_id,
+                         )
+                     metadata = parsed_checkpoint.metadata
+                     base_model_ref = metadata.base_model
+                     is_public = parsed_checkpoint.public
+                     model_owner = parsed_checkpoint.owner_name
+                     if not is_public and model_owner != user_id:
+                         raise CheckpointAccessDeniedException(
+                             checkpoint_id=parsed_checkpoint.checkpoint_id,
+                         )
+                     if base_model_ref not in self._base_backends:
+                         raise UnknownModelException(model_name=base_model_ref)
+                     adapter_path = parsed_checkpoint.adapter_path
+                     sampling_backend = self._base_backends[base_model_ref]
+                     await sampling_backend.add_adapter(
+                         lora_id=sampling_session_id, adapter_path=adapter_path
+                     )
+                     # TODO: remove adapter when session is deleted
+                 elif base_model:
+                     base_model_ref = base_model
+                     if base_model_ref not in self._base_backends:
+                         raise UnknownModelException(model_name=base_model_ref)
+                 else:
+                     raise UnknownModelException(model_name="None")
+                 self.sampling_sessions[sampling_session_id] = SamplingSessionRecord(
+                     sampling_session_id=sampling_session_id,
+                     session_id=session_id,
+                     user_id=user_id,
+                     model_id=sampling_session_id,
+                     base_model=base_model_ref,
+                     model_path=str(adapter_path) if adapter_path else None,
+                     session_seq_id=session_seq_id,
+                 )
+                 self._save_session(sampling_session_id)
+
+                 # Update metrics
+                 get_metrics().sampling_sessions_active.add(1, {"base_model": base_model_ref or ""})
+                 logger.info("Sampling session created: %s", sampling_session_id)
+                 return sampling_session_id
+             except Exception as e:
+                 span.record_exception(e)
+                 span.set_status(StatusCode.ERROR)
+                 raise
+
+     def _hash_prompt(self, prompt: types.ModelInput) -> str:
+         tokens = ",".join(str(token) for token in prompt.to_ints())
+         return hashlib.sha1(tokens.encode("utf-8")).hexdigest()[:16]
+
+     def _record_sequence(
+         self, record: SamplingSessionRecord, seq_id: int, prompt: types.ModelInput
+     ) -> None:
+         if seq_id <= record.last_seq_id:
+             raise SequenceConflictException(expected=record.last_seq_id + 1, got=seq_id)
+         record.last_seq_id = seq_id
+         entry = SamplingHistoryEntry(
+             seq_id=seq_id,
+             prompt_token_count=len(prompt.to_ints()),
+             prompt_hash=self._hash_prompt(prompt),
+         )
+         record.history.append(entry)
+         self._save_session(record.sampling_session_id)
+
+     def _resolve_backend(
+         self, request: types.SampleRequest, user_id: str
+     ) -> Tuple[BaseSamplingBackend, str | None]:
+         """Resolve the appropriate backend for the sampling request.
+
+         Args:
+             request: The sampling request.
+
+         Returns:
+             A tuple of the resolved backend and the LoRA ID if applicable.
+         """
+         if request.sampling_session_id:
+             record = self.sampling_sessions.get(request.sampling_session_id)
+             if record is None:
+                 raise SessionNotFoundException(session_id=request.sampling_session_id)
+             if record.user_id != user_id:
+                 raise UserMismatchException()
+             if request.seq_id is None:
+                 raise MissingSequenceIDException()
+             self._record_sequence(record, request.seq_id, request.prompt)
+             if record.base_model not in self._base_backends:
+                 raise UnknownModelException(model_name=record.base_model)
+             if record.model_path is None:
+                 lora_id = None
+             else:
+                 lora_id = record.sampling_session_id
+             return self._base_backends[record.base_model], lora_id
+         raise SessionNotFoundException(session_id="None")
+
+     async def run_sample(
+         self,
+         request: types.SampleRequest,
+         user_id: str,
+     ) -> types.SampleResponse:
+         with _get_tracer().start_as_current_span("sampling_controller.run_sample") as span:
+             sampling_session_id = request.sampling_session_id or ""
+             span.set_attribute("tuft.sampling_session_id", sampling_session_id)
+             # Get session_id from sampling session record if available
+             if request.sampling_session_id:
+                 record = self.sampling_sessions.get(request.sampling_session_id)
+                 if record:
+                     span.set_attribute("tuft.session_id", record.session_id)
+             span.set_attribute("tuft.num_samples", request.num_samples)
+
+             logger.info("Sampling begin for %s", sampling_session_id)
+             start_time = time.perf_counter()
+
+             backend, lora_id = self._resolve_backend(request, user_id=user_id)
+             prompt = request.prompt
+             sampling_params = request.sampling_params
+             num_samples = request.num_samples
+             include_prompt_logprobs = bool(request.prompt_logprobs)
+             topk_prompt_logprobs = request.topk_prompt_logprobs or 0
+
+             response = await backend.sample(
+                 prompt=prompt,
+                 num_samples=num_samples,
+                 sampling_params=sampling_params,
+                 include_prompt_logprobs=include_prompt_logprobs,
+                 topk_prompt_logprobs=topk_prompt_logprobs,
+                 lora_id=lora_id,
+             )
+
+             duration = time.perf_counter() - start_time
+             logger.info("Sampling completed for %s", sampling_session_id)
+
+             # Get base_model for metrics
+             record = self.sampling_sessions.get(request.sampling_session_id or "")
+             base_model = record.base_model if record else ""
+
+             # Update metrics
+             metrics = get_metrics()
+             metrics.sampling_requests.add(1, {"base_model": base_model})
+             metrics.sampling_duration.record(duration, {"base_model": base_model})
+
+             # Record output tokens for each sequence
+             total_output_tokens = 0
+             for seq in response.sequences:
+                 if seq.tokens:
+                     metrics.sampling_output_tokens.record(len(seq.tokens))
+                     total_output_tokens += len(seq.tokens)
+
+             # Record tokens per second if we have output tokens and positive duration
+             if total_output_tokens > 0 and duration > 0:
+                 tokens_per_second = total_output_tokens / duration
+                 metrics.sampling_tokens_per_second.record(
+                     tokens_per_second, {"base_model": base_model}
+                 )
+
+             return response
+
+     async def evict_model(self, model_id: str, user_id: str) -> None:
+         for sampling_id, record in list(self.sampling_sessions.items()):
+             if record.model_id == model_id and record.user_id == user_id:
+                 base_model = record.base_model
+                 del self.sampling_sessions[sampling_id]
+                 self._delete_session(sampling_id)
+                 # Update metrics
+                 get_metrics().sampling_sessions_active.add(-1, {"base_model": base_model or ""})
+
+     def get_sampler_info(
+         self, sampler_id: str, user_id: str, default_base_model: str
+     ) -> types.GetSamplerResponse:
+         record = self.sampling_sessions.get(sampler_id)
+         if record is None:
+             raise SessionNotFoundException(session_id=sampler_id)
+         if record.user_id != user_id:
+             raise UserMismatchException()
+         base = record.base_model
+         return types.GetSamplerResponse(
+             sampler_id=sampler_id,
+             base_model=base or default_base_model,
+             model_path=record.model_path,
+         )
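
For orientation, a minimal usage sketch of the new SamplingController follows. It is illustrative only and not part of the wheel contents: the tuft.sampling_controller and tuft.config module paths, the AppConfig/ModelConfig constructor arguments, and the tinker.types constructors (ModelInput.from_ints, SamplingParams, SampleRequest) are assumptions inferred from the attribute names used in the diff above; only the controller methods and their keyword arguments are taken directly from it.

import asyncio

from tinker import types

from tuft.config import AppConfig, ModelConfig            # module path assumed
from tuft.sampling_controller import SamplingController   # module path assumed


async def main() -> None:
    # supported_models, checkpoint_dir and model_name appear as attributes in the
    # diff; treating them as constructor keyword arguments is an assumption.
    config = AppConfig(
        supported_models=[ModelConfig(model_name="my-base-model")],
        checkpoint_dir=None,
    )
    controller = SamplingController(config)
    await controller.async_init()  # initializes backends and re-adds persisted adapters

    # Register a sampling session directly against a base model (no LoRA checkpoint).
    sampling_session_id = await controller.create_sampling_session(
        session_id="session-1",
        user_id="user-1",
        base_model="my-base-model",
        model_path=None,
        session_seq_id=0,
    )

    # seq_id must increase monotonically per session, or SequenceConflictException is raised.
    request = types.SampleRequest(  # field names from the diff; constructor shape assumed
        sampling_session_id=sampling_session_id,
        seq_id=0,
        prompt=types.ModelInput.from_ints([1, 2, 3]),          # constructor assumed
        sampling_params=types.SamplingParams(max_tokens=16),   # constructor assumed
        num_samples=1,
    )
    response = await controller.run_sample(request, user_id="user-1")
    print(len(response.sequences))


asyncio.run(main())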