tuft 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tuft/__init__.py +5 -2
- tuft/auth.py +35 -0
- tuft/backend.py +254 -0
- tuft/backends/__init__.py +10 -0
- tuft/backends/base_backend.py +112 -0
- tuft/backends/hf_training_model.py +404 -0
- tuft/backends/sampling_backend.py +253 -0
- tuft/backends/training_backend.py +327 -0
- tuft/checkpoints.py +193 -0
- tuft/cli.py +91 -0
- tuft/config.py +121 -0
- tuft/exceptions.py +138 -0
- tuft/futures.py +431 -0
- tuft/loss_fn/__init__.py +48 -0
- tuft/loss_fn/cispo.py +40 -0
- tuft/loss_fn/cross_entropy.py +26 -0
- tuft/loss_fn/dro.py +37 -0
- tuft/loss_fn/importance_sampling.py +33 -0
- tuft/loss_fn/ppo.py +43 -0
- tuft/persistence/__init__.py +32 -0
- tuft/persistence/file_redis.py +268 -0
- tuft/persistence/redis_store.py +488 -0
- tuft/sampling_controller.py +366 -0
- tuft/server.py +720 -0
- tuft/state.py +352 -0
- tuft/telemetry/__init__.py +17 -0
- tuft/telemetry/metrics.py +335 -0
- tuft/telemetry/provider.py +198 -0
- tuft/telemetry/tracing.py +43 -0
- tuft/training_controller.py +723 -0
- tuft-0.1.1.dist-info/METADATA +633 -0
- tuft-0.1.1.dist-info/RECORD +35 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/WHEEL +1 -2
- tuft-0.1.1.dist-info/entry_points.txt +2 -0
- {tuft-0.1.0.dist-info → tuft-0.1.1.dist-info}/licenses/LICENSE +2 -2
- tuft-0.1.0.dist-info/METADATA +0 -77
- tuft-0.1.0.dist-info/RECORD +0 -6
- tuft-0.1.0.dist-info/top_level.txt +0 -1
tuft/state.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""In-memory state containers backing the FastAPI endpoints."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import uuid
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from typing import Dict, TypeVar
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
from tinker import types
|
|
11
|
+
|
|
12
|
+
from .auth import AuthenticationDB, User
|
|
13
|
+
from .checkpoints import CheckpointRecord
|
|
14
|
+
from .config import AppConfig
|
|
15
|
+
from .exceptions import SessionNotFoundException, UserMismatchException
|
|
16
|
+
from .futures import FutureStore
|
|
17
|
+
from .persistence import get_redis_store, is_persistence_enabled, load_record, save_record
|
|
18
|
+
from .sampling_controller import SamplingController
|
|
19
|
+
from .training_controller import TrainingController, TrainingRunRecord
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
T = TypeVar("T")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _now() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SessionRecord(BaseModel):
    """Session record with persistence support.

    Sessions are permanent records (no TTL) as they represent user sessions
    that may need to be accessed at any time.
    """

    # Unique identifier (uuid4 string) assigned at creation time.
    session_id: str
    # Free-form tags supplied by the client in the create-session request.
    tags: list[str]
    # Optional client-supplied metadata; absent when the client sent none.
    user_metadata: dict[str, str] | None = None
    # Owner of the session; used by SessionManager for ownership checks.
    user_id: str
    # SDK version reported by the client at session creation.
    sdk_version: str
    # Both timestamps are timezone-aware UTC (see _now).
    created_at: datetime = Field(default_factory=_now)
    # Refreshed on every successful heartbeat call.
    last_heartbeat: datetime = Field(default_factory=_now)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class SessionManager:
    """Tracks session records and heartbeats so other controllers can enforce ownership."""

    REDIS_KEY_PREFIX = "session"

    def __init__(self) -> None:
        # In-memory index of sessions keyed by session_id, hydrated from
        # Redis at startup when persistence is enabled.
        self._sessions: Dict[str, SessionRecord] = {}
        self._restore_from_redis()

    def _build_key(self, session_id: str) -> str:
        """Compose the Redis key for a single session record."""
        store = get_redis_store()
        return store.build_key(self.REDIS_KEY_PREFIX, session_id)

    def _restore_from_redis(self) -> None:
        """Reload every persisted session into memory (no-op when persistence is off)."""
        if not is_persistence_enabled():
            return
        store = get_redis_store()
        wildcard = store.build_key(self.REDIS_KEY_PREFIX, "*")
        for redis_key in store.keys(wildcard):
            restored = load_record(redis_key, SessionRecord)
            if restored is None:
                continue
            self._sessions[restored.session_id] = restored

    def _save_session(self, session_id: str) -> None:
        """Save session to Redis (no TTL - permanent record)."""
        if not is_persistence_enabled():
            return
        existing = self._sessions.get(session_id)
        if existing is None:
            return
        save_record(self._build_key(session_id), existing)

    def _delete_session(self, session_id: str) -> None:
        """Drop the persisted copy of a session (no-op when persistence is off)."""
        if is_persistence_enabled():
            get_redis_store().delete(self._build_key(session_id))

    def create_session(self, request: types.CreateSessionRequest, user: User) -> SessionRecord:
        """Create a new session for the given user and request."""
        new_id = str(uuid.uuid4())
        created = SessionRecord(
            session_id=new_id,
            tags=request.tags,
            user_id=user.user_id,
            user_metadata=request.user_metadata,
            sdk_version=request.sdk_version,
        )
        self._sessions[new_id] = created
        self._save_session(new_id)
        return created

    def require(self, session_id: str) -> SessionRecord:
        """Return the session record for *session_id*, raising if it is unknown."""
        found = self._sessions.get(session_id)
        if found is None:
            raise SessionNotFoundException(session_id)
        return found

    def heartbeat(self, session_id: str, user_id: str) -> None:
        """Refresh a session's heartbeat timestamp after verifying ownership."""
        owned = self.require(session_id)
        if owned.user_id != user_id:
            raise UserMismatchException()
        owned.last_heartbeat = _now()
        self._save_session(session_id)

    def list_sessions(self, user_id: str) -> list[str]:
        """Return the ids of all sessions owned by *user_id*."""
        return [
            session_id
            for session_id, record in self._sessions.items()
            if record.user_id == user_id
        ]
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class ServerState:
    """Application-wide container that wires controllers together
    and exposes a simple façade to FastAPI.
    """

    def __init__(self, config: AppConfig | None = None) -> None:
        # Fall back to default settings when no config is supplied.
        self.config = config or AppConfig()
        self.config.ensure_directories()
        self.config.check_validity()
        self.sessions = SessionManager()
        self.training = TrainingController(self.config)
        self.sampling = SamplingController(self.config)
        self.auth_db = AuthenticationDB(self.config.authorized_users)
        self.future_store = FutureStore()

    async def async_init(self) -> None:
        """Put any async initialization logic here"""
        await self.sampling.async_init()
        # Runs after sampling init; recovers training state from checkpoints.
        await self._restore_from_checkpoints()

    async def _restore_from_checkpoints(self) -> None:
        """Restore server state from checkpoints after Redis restore.

        This method handles checkpoint-based recovery:
        1. For each training run restored from Redis, create adapter and load latest checkpoint
        2. Mark ALL futures created after checkpoint's future_id as failed
        3. For training runs without checkpoints, mark all futures as failed
        """
        # Restore training runs (adapter + checkpoint)
        for model_id, record in self.training.training_runs.items():
            # Skip runs with no backend attached or flagged as corrupted.
            if record.backend is None or record.corrupted:
                continue
            latest_ckpt = await self.training.restore_from_checkpoint(model_id)

            if latest_ckpt is None:
                # No checkpoint to roll back to: every pending future for
                # this model must be failed so the client retries.
                self.future_store.mark_futures_failed_after_checkpoint(
                    model_id=model_id,
                    checkpoint_future_id=None,
                    error_message=f"No checkpoint found for model {model_id}. Please retry.",
                )
            else:
                # Futures issued after the restored checkpoint are stale.
                self.future_store.mark_futures_failed_after_checkpoint(
                    model_id=model_id,
                    checkpoint_future_id=latest_ckpt.future_id,
                    error_message=(
                        f"Server restored from checkpoint {latest_ckpt.checkpoint_id}. "
                        "Operations after this checkpoint need to be retried."
                    ),
                )

    def create_session(self, request: types.CreateSessionRequest, user: User) -> SessionRecord:
        """Create a new session; delegates to the session manager."""
        return self.sessions.create_session(request, user)

    def heartbeat(self, session_id: str, user_id: str) -> None:
        """Refresh a session heartbeat; delegates to the session manager."""
        self.sessions.heartbeat(session_id, user_id)

    async def create_model(
        self,
        session_id: str,
        base_model: str,
        lora_config: types.LoraConfig,
        model_owner: str,
        user_metadata: dict[str, str] | None,
    ) -> TrainingRunRecord:
        """Create a training model after checking the session exists."""
        self.sessions.require(session_id)
        return await self.training.create_model(
            session_id=session_id,
            base_model=base_model,
            lora_config=lora_config,
            model_owner=model_owner,
            user_metadata=user_metadata,
        )

    def build_supported_models(self) -> list[types.SupportedModel]:
        """Return the supported-model list; delegates to the training controller."""
        return self.training.build_supported_models()

    def get_user(self, api_key: str) -> User | None:
        """Authenticate an API key; returns None when authentication fails."""
        return self.auth_db.authenticate(api_key)

    async def run_forward(
        self,
        model_id: str,
        user_id: str,
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None,
        seq_id: int | None,
        *,
        backward: bool,
    ) -> types.ForwardBackwardOutput:
        """Run a forward (and optionally backward) pass on the training controller."""
        return await self.training.run_forward(
            model_id=model_id,
            user_id=user_id,
            data=data,
            loss_fn=loss_fn,
            loss_fn_config=loss_fn_config,
            seq_id=seq_id,
            backward=backward,
        )

    async def run_optim_step(
        self, model_id: str, user_id: str, params: types.AdamParams, seq_id: int | None
    ) -> types.OptimStepResponse:
        """Run one optimizer step; delegates to the training controller."""
        return await self.training.run_optim_step(
            model_id=model_id, user_id=user_id, params=params, seq_id=seq_id
        )

    async def create_sampling_session(
        self,
        session_id: str,
        base_model: str | None,
        model_path: str | None,
        user_id: str,
        *,
        session_seq_id: int,
    ) -> str:
        """Create a sampling session after checking the parent session exists."""
        self.sessions.require(session_id)
        return await self.sampling.create_sampling_session(
            session_id=session_id,
            user_id=user_id,
            base_model=base_model,
            model_path=model_path,
            session_seq_id=session_seq_id,
        )

    async def run_sample(self, request: types.SampleRequest, user_id: str) -> types.SampleResponse:
        """Run a sampling request; delegates to the sampling controller."""
        return await self.sampling.run_sample(request, user_id=user_id)

    async def save_checkpoint(
        self,
        model_id: str,
        user_id: str,
        name: str | None,
        checkpoint_type: types.CheckpointType,
        seq_id: int | None = None,
    ) -> CheckpointRecord:
        """Save a checkpoint, recording the current future id for recovery ordering."""
        # The future id lets _restore_from_checkpoints decide which futures
        # were issued after this checkpoint.
        current_future_id = self.future_store.get_current_future_id()
        return await self.training.save_checkpoint(
            model_id=model_id,
            user_id=user_id,
            name=name,
            checkpoint_type=checkpoint_type,
            future_id=current_future_id,
            seq_id=seq_id,
        )

    async def load_checkpoint(
        self, model_id: str, user_id: str, path: str, optimizer: bool, seq_id: int | None = None
    ) -> None:
        """Load a checkpoint into a model; delegates to the training controller."""
        return await self.training.load_checkpoint(
            model_id=model_id,
            user_id=user_id,
            path=path,
            optimizer=optimizer,
            seq_id=seq_id,
        )

    def delete_checkpoint(self, model_id: str, user_id: str, checkpoint_id: str) -> None:
        """Delete a checkpoint; delegates to the training controller."""
        self.training.delete_checkpoint(model_id, user_id, checkpoint_id)

    def list_checkpoints(self, model_id: str, user_id: str) -> list[types.Checkpoint]:
        """List a model's checkpoints; delegates to the training controller."""
        return self.training.list_checkpoints(model_id, user_id)

    def list_user_checkpoints(self, user_id: str) -> list[types.Checkpoint]:
        """List every checkpoint owned by a user; delegates to the training controller."""
        return self.training.list_user_checkpoints(user_id)

    def set_checkpoint_visibility(
        self,
        model_id: str,
        user_id: str,
        checkpoint_id: str,
        *,
        public: bool,
    ) -> None:
        """Toggle a checkpoint's public visibility; delegates to the training controller."""
        self.training.set_visibility(
            model_id=model_id,
            user_id=user_id,
            checkpoint_id=checkpoint_id,
            public=public,
        )

    def get_weights_info(self, tinker_path: str, user_id: str) -> types.WeightsInfoResponse:
        """Resolve a tinker path to its training run and return the weights info."""
        parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
        return self.training.get_weights_info(parsed.training_run_id, user_id)

    def build_archive_url(
        self,
        model_id: str,
        user_id: str,
        checkpoint_id: str,
    ) -> types.CheckpointArchiveUrlResponse:
        """Build a download URL for a checkpoint archive; delegates to the training controller."""
        return self.training.build_archive_url(model_id, user_id, checkpoint_id)

    def list_training_runs(
        self, *, user_id: str, limit: int | None = None, offset: int = 0
    ) -> types.TrainingRunsResponse:
        """List training runs with pagination; delegates to the training controller."""
        return self.training.list_training_runs(user_id=user_id, limit=limit, offset=offset)

    def get_training_run_view(self, model_id: str, user_id: str) -> types.TrainingRun:
        """Return the client-facing view of a training run."""
        return self.training.get_training_run_view(model_id, user_id)

    def get_model_info(self, model_id: str, user_id: str) -> types.GetInfoResponse:
        """Return model info; delegates to the training controller."""
        return self.training.get_model_info(model_id, user_id=user_id)

    async def unload_model(self, model_id: str, user_id: str) -> None:
        """Unload a model from training and evict it from the sampling backend."""
        await self.training.unload_model(model_id, user_id=user_id)
        await self.sampling.evict_model(model_id, user_id=user_id)

    def get_session_overview(self, session_id: str, user_id: str) -> types.GetSessionResponse:
        """Return the training runs and samplers attached to a session.

        Raises:
            SessionNotFoundException: if the session does not exist.
            UserMismatchException: if *user_id* does not own the session.
        """
        record = self.sessions.require(session_id)
        if record.user_id != user_id:
            raise UserMismatchException()
        training_run_ids = [
            run_id
            for run_id, run in self.training.training_runs.items()
            if run.session_id == session_id
        ]
        sampler_ids = [
            sid
            for sid, record in self.sampling.sampling_sessions.items()
            if record.session_id == session_id
        ]
        return types.GetSessionResponse(training_run_ids=training_run_ids, sampler_ids=sampler_ids)

    def list_sessions(
        self, user_id: str, *, limit: int | None = None, offset: int = 0
    ) -> types.ListSessionsResponse:
        """Return a paginated list of the user's session ids."""
        sessions = self.sessions.list_sessions(user_id=user_id)
        total = len(sessions)
        # Clamp the window to the list bounds so out-of-range offsets yield [].
        start = min(offset, total)
        if limit is None:
            subset = sessions[start:]
        else:
            subset = sessions[start : min(start + limit, total)]
        return types.ListSessionsResponse(sessions=subset)

    def get_sampler_info(self, sampler_id: str, user_id: str) -> types.GetSamplerResponse:
        """Return sampler info, defaulting the base model to the first configured model."""
        return self.sampling.get_sampler_info(
            sampler_id=sampler_id,
            user_id=user_id,
            default_base_model=self.config.supported_models[0].model_name,
        )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""OpenTelemetry integration for TuFT.
|
|
2
|
+
|
|
3
|
+
This module provides observability through Traces, Metrics, and Logs.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from tuft.config import TelemetryConfig
|
|
9
|
+
|
|
10
|
+
from .provider import init_telemetry, shutdown_telemetry
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"TelemetryConfig",
|
|
15
|
+
"init_telemetry",
|
|
16
|
+
"shutdown_telemetry",
|
|
17
|
+
]
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""Metrics utilities for TuFT.
|
|
2
|
+
|
|
3
|
+
Provides meter access and predefined metric instruments.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import threading
|
|
10
|
+
from collections.abc import Iterable
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import psutil
|
|
14
|
+
import pynvml
|
|
15
|
+
from opentelemetry import metrics
|
|
16
|
+
from opentelemetry.metrics import CallbackOptions, Observation
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
# Module-level meter cache
|
|
22
|
+
_meters: dict[str, Any] = {}
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_meter(name: str = "tuft"):
    """Return a (cached) meter instance by name.

    Args:
        name: Name for the meter (typically module name).

    Returns:
        A Meter instance. When no MeterProvider is configured,
        OpenTelemetry automatically returns a NoOpMeter.
    """
    cached = _meters.get(name)
    if cached is None:
        cached = metrics.get_meter(name)
        _meters[name] = cached
    return cached
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def clear_meters() -> None:
    """Clear the meter cache. Used during shutdown."""
    # Subsequent get_meter() calls will fetch fresh meters from OpenTelemetry.
    _meters.clear()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class TuftMetrics:
    """Centralized metrics registry for TuFT.

    Provides access to all predefined metrics with proper typing.
    """

    # Process-wide singleton, guarded by _lock (see get_instance).
    _instance: "TuftMetrics | None" = None
    _lock = threading.Lock()

    def __init__(self):
        """Create every predefined instrument on the shared "tuft" meter."""
        meter = get_meter("tuft")

        # Training metrics
        self.training_models_active = meter.create_up_down_counter(
            "tuft.training.models.active",
            description="Number of active training models",
        )

        self.training_tokens_per_second = meter.create_histogram(
            "tuft.training.tokens_per_second",
            description="Training tokens per second",
            unit="tokens/s",
        )

        self.training_checkpoints_saved = meter.create_counter(
            "tuft.training.checkpoints.saved",
            description="Number of checkpoints saved",
        )

        self.training_checkpoint_size = meter.create_histogram(
            "tuft.training.checkpoint.size_bytes",
            description="Checkpoint size in bytes",
            unit="bytes",
        )

        # Sampling metrics
        self.sampling_sessions_active = meter.create_up_down_counter(
            "tuft.sampling.sessions.active",
            description="Number of active sampling sessions",
        )

        self.sampling_requests = meter.create_counter(
            "tuft.sampling.requests",
            description="Number of sampling requests",
        )

        self.sampling_duration = meter.create_histogram(
            "tuft.sampling.duration",
            description="Sampling request duration",
            unit="s",
        )

        self.sampling_tokens_per_second = meter.create_histogram(
            "tuft.sampling.tokens_per_second",
            description="Sampling tokens per second",
            unit="tokens/s",
        )

        self.sampling_output_tokens = meter.create_histogram(
            "tuft.sampling.output_tokens",
            description="Number of output tokens per sample",
            unit="tokens",
        )

        # Future queue metrics
        self.futures_queue_length = meter.create_up_down_counter(
            "tuft.futures.queue_length",
            description="Number of futures in queue",
        )

        self.futures_created = meter.create_counter(
            "tuft.futures.created",
            description="Number of futures created",
        )

        self.futures_completed = meter.create_counter(
            "tuft.futures.completed",
            description="Number of futures completed",
        )

        self.futures_wait_time = meter.create_histogram(
            "tuft.futures.wait_time",
            description="Time waiting for future completion",
            unit="s",
        )

        self.futures_execution_time = meter.create_histogram(
            "tuft.futures.execution_time",
            description="Future execution time",
            unit="s",
        )

        # Redis metrics
        self.redis_operation_duration = meter.create_histogram(
            "tuft.redis.operation.duration",
            description="Redis operation duration",
            unit="s",
        )

    @classmethod
    def get_instance(cls) -> "TuftMetrics":
        """Get the singleton metrics instance."""
        # Double-checked locking: the unlocked read avoids contention on the
        # common path; the locked re-check avoids duplicate construction.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def get_metrics() -> TuftMetrics:
    """Get the TuFT metrics instance (module-level shortcut for the singleton)."""
    return TuftMetrics.get_instance()
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class ResourceMetricsCollector:
    """Collects system resource metrics periodically.

    Registers OpenTelemetry observable gauges for CPU, memory, GPU, and
    process metrics; the SDK invokes the callbacks on its own export
    schedule, so no polling loop is run here.
    """

    # Process-wide singleton, guarded by _lock (see start/shutdown).
    _instance: "ResourceMetricsCollector | None" = None
    _lock = threading.Lock()

    def __init__(self, checkpoint_dir: str | None = None):
        # NOTE(review): _checkpoint_dir, _running and _thread are stored but
        # not read by any method in this class — presumably reserved for
        # disk metrics / a background thread; confirm before removing.
        self._checkpoint_dir = checkpoint_dir
        self._running = False
        self._thread: threading.Thread | None = None
        self._nvml_initialized = False
        self._gpu_available = self._check_and_init_gpu()
        self._setup_metrics()

    def _check_and_init_gpu(self) -> bool:
        """Check if GPU monitoring is available and initialize NVML once.

        Returns:
            True when NVML initialized successfully, False otherwise.
        """
        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError as e:
            logger.debug("GPU monitoring not available: %s", e)
            return False
        self._nvml_initialized = True
        return True

    def _shutdown_gpu(self) -> None:
        """Shutdown NVML if it was initialized."""
        if not self._nvml_initialized:
            return
        try:
            pynvml.nvmlShutdown()
        except pynvml.NVMLError as e:
            logger.warning("Failed to shutdown NVML: %s", e)
        else:
            # Only clear the flag when shutdown actually succeeded.
            self._nvml_initialized = False

    def _setup_metrics(self) -> None:
        """Set up observable gauges for resource metrics."""
        # Use the module-level cached helper for consistency with the rest
        # of this module (previously called metrics.get_meter directly).
        meter = get_meter("tuft.resources")

        # CPU metrics
        meter.create_observable_gauge(
            "tuft.resource.cpu.utilization_percent",
            callbacks=[self._cpu_utilization_callback],
            description="CPU utilization percentage",
            unit="%",
        )

        # Memory metrics
        meter.create_observable_gauge(
            "tuft.resource.memory.used_bytes",
            callbacks=[self._memory_used_callback],
            description="Memory used in bytes",
            unit="bytes",
        )
        meter.create_observable_gauge(
            "tuft.resource.memory.total_bytes",
            callbacks=[self._memory_total_callback],
            description="Total memory in bytes",
            unit="bytes",
        )
        meter.create_observable_gauge(
            "tuft.resource.memory.utilization_percent",
            callbacks=[self._memory_utilization_callback],
            description="Memory utilization percentage",
            unit="%",
        )

        # GPU metrics (if available)
        if self._gpu_available:
            meter.create_observable_gauge(
                "tuft.resource.gpu.utilization_percent",
                callbacks=[self._gpu_utilization_callback],
                description="GPU utilization percentage",
                unit="%",
            )
            meter.create_observable_gauge(
                "tuft.resource.gpu.memory_used_bytes",
                callbacks=[self._gpu_memory_used_callback],
                description="GPU memory used in bytes",
                unit="bytes",
            )
            meter.create_observable_gauge(
                "tuft.resource.gpu.memory_total_bytes",
                callbacks=[self._gpu_memory_total_callback],
                description="GPU total memory in bytes",
                unit="bytes",
            )

        # Process metrics
        meter.create_observable_gauge(
            "tuft.resource.process.memory_used_bytes",
            callbacks=[self._process_memory_callback],
            description="Process memory usage in bytes",
            unit="bytes",
        )

    def _cpu_utilization_callback(self, options: CallbackOptions) -> Iterable[Observation]:
        """Callback for CPU utilization metric."""
        yield Observation(psutil.cpu_percent(interval=None))

    def _memory_used_callback(self, options: CallbackOptions) -> Iterable[Observation]:
        """Callback for memory used metric."""
        yield Observation(psutil.virtual_memory().used)

    def _memory_total_callback(self, options: CallbackOptions) -> Iterable[Observation]:
        """Callback for total memory metric."""
        yield Observation(psutil.virtual_memory().total)

    def _memory_utilization_callback(self, options: CallbackOptions) -> Iterable[Observation]:
        """Callback for memory utilization metric."""
        yield Observation(psutil.virtual_memory().percent)

    def _observe_per_gpu(self, extract, what: str) -> Iterable[Observation]:
        """Yield one Observation per GPU device, labelled with its index.

        Shared implementation for the three GPU callbacks, which previously
        duplicated the same iteration and error handling.

        Args:
            extract: Callable mapping an NVML device handle to a numeric value.
            what: Human-readable metric name used in the failure log message.
        """
        if not self._nvml_initialized:
            return
        try:
            for index in range(pynvml.nvmlDeviceGetCount()):
                handle = pynvml.nvmlDeviceGetHandleByIndex(index)
                yield Observation(int(extract(handle)), {"gpu_id": str(index)})
        except pynvml.NVMLError as e:
            # Best-effort: an NVML failure ends this collection cycle early.
            logger.debug("Failed to get %s: %s", what, e)

    def _gpu_utilization_callback(self, options: CallbackOptions) -> Iterable[Observation]:
        """Callback for GPU utilization metric."""
        return self._observe_per_gpu(
            lambda handle: pynvml.nvmlDeviceGetUtilizationRates(handle).gpu,
            "GPU utilization",
        )

    def _gpu_memory_used_callback(self, options: CallbackOptions) -> Iterable[Observation]:
        """Callback for GPU memory used metric."""
        return self._observe_per_gpu(
            lambda handle: pynvml.nvmlDeviceGetMemoryInfo(handle).used,
            "GPU memory used",
        )

    def _gpu_memory_total_callback(self, options: CallbackOptions) -> Iterable[Observation]:
        """Callback for GPU total memory metric."""
        return self._observe_per_gpu(
            lambda handle: pynvml.nvmlDeviceGetMemoryInfo(handle).total,
            "GPU memory total",
        )

    def _process_memory_callback(self, options: CallbackOptions) -> Iterable[Observation]:
        """Callback for process memory metric."""
        process = psutil.Process()
        yield Observation(process.memory_info().rss)

    @classmethod
    def start(cls, checkpoint_dir: str | None = None) -> "ResourceMetricsCollector":
        """Start the resource metrics collector.

        Note:
            If a collector is already running, the existing instance is
            returned and *checkpoint_dir* is ignored.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls(checkpoint_dir)
        return cls._instance

    @classmethod
    def shutdown(cls) -> None:
        """Shutdown the resource metrics collector."""
        if cls._instance is not None:
            with cls._lock:
                if cls._instance is not None:
                    cls._instance._shutdown_gpu()
                    cls._instance = None
|