tuft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

tuft/state.py ADDED
@@ -0,0 +1,352 @@
+ """In-memory state containers backing the FastAPI endpoints."""
+
+ from __future__ import annotations
+
+ import uuid
+ from datetime import datetime, timezone
+ from typing import Dict, TypeVar
+
+ from pydantic import BaseModel, Field
+ from tinker import types
+
+ from .auth import AuthenticationDB, User
+ from .checkpoints import CheckpointRecord
+ from .config import AppConfig
+ from .exceptions import SessionNotFoundException, UserMismatchException
+ from .futures import FutureStore
+ from .persistence import get_redis_store, is_persistence_enabled, load_record, save_record
+ from .sampling_controller import SamplingController
+ from .training_controller import TrainingController, TrainingRunRecord
+
+
+ T = TypeVar("T")
+
+
+ def _now() -> datetime:
+     return datetime.now(timezone.utc)
+
+
+ class SessionRecord(BaseModel):
+     """Session record with persistence support.
+
+     Sessions are permanent records (no TTL) as they represent user sessions
+     that may need to be accessed at any time.
+     """
+
+     session_id: str
+     tags: list[str]
+     user_metadata: dict[str, str] | None = None
+     user_id: str
+     sdk_version: str
+     created_at: datetime = Field(default_factory=_now)
+     last_heartbeat: datetime = Field(default_factory=_now)
+
+
+ class SessionManager:
+     """Maintains session metadata and heartbeats so other controllers can enforce ownership."""
+
+     REDIS_KEY_PREFIX = "session"
+
+     def __init__(self) -> None:
+         self._sessions: Dict[str, SessionRecord] = {}
+         self._restore_from_redis()
+
+     def _build_key(self, session_id: str) -> str:
+         return get_redis_store().build_key(self.REDIS_KEY_PREFIX, session_id)
+
+     def _restore_from_redis(self) -> None:
+         if not is_persistence_enabled():
+             return
+         store = get_redis_store()
+         pattern = store.build_key(self.REDIS_KEY_PREFIX, "*")
+         for key in store.keys(pattern):
+             record = load_record(key, SessionRecord)
+             if record is not None:
+                 self._sessions[record.session_id] = record
+
+     def _save_session(self, session_id: str) -> None:
+         """Save session to Redis (no TTL - permanent record)."""
+         if not is_persistence_enabled():
+             return
+         record = self._sessions.get(session_id)
+         if record is not None:
+             save_record(self._build_key(session_id), record)
+
+     def _delete_session(self, session_id: str) -> None:
+         if not is_persistence_enabled():
+             return
+         get_redis_store().delete(self._build_key(session_id))
+
+     def create_session(self, request: types.CreateSessionRequest, user: User) -> SessionRecord:
+         """Create a new session for the given user and request."""
+         session_id = str(uuid.uuid4())
+         record = SessionRecord(
+             session_id=session_id,
+             tags=request.tags,
+             user_id=user.user_id,
+             user_metadata=request.user_metadata,
+             sdk_version=request.sdk_version,
+         )
+         self._sessions[session_id] = record
+         self._save_session(session_id)
+         return record
+
+     def require(self, session_id: str) -> SessionRecord:
+         record = self._sessions.get(session_id)
+         if record is None:
+             raise SessionNotFoundException(session_id)
+         return record
+
+     def heartbeat(self, session_id: str, user_id: str) -> None:
+         record = self.require(session_id)
+         if record.user_id != user_id:
+             raise UserMismatchException()
+         record.last_heartbeat = _now()
+         self._save_session(session_id)
+
+     def list_sessions(self, user_id: str) -> list[str]:
+         return [k for k, v in self._sessions.items() if v.user_id == user_id]
+
+
+ class ServerState:
+     """Application-wide container that wires controllers together
+     and exposes a simple façade to FastAPI.
+     """
+
+     def __init__(self, config: AppConfig | None = None) -> None:
+         self.config = config or AppConfig()
+         self.config.ensure_directories()
+         self.config.check_validity()
+         self.sessions = SessionManager()
+         self.training = TrainingController(self.config)
+         self.sampling = SamplingController(self.config)
+         self.auth_db = AuthenticationDB(self.config.authorized_users)
+         self.future_store = FutureStore()
+
+     async def async_init(self) -> None:
+         """Put any async initialization logic here"""
+         await self.sampling.async_init()
+         await self._restore_from_checkpoints()
+
+     async def _restore_from_checkpoints(self) -> None:
+         """Restore server state from checkpoints after Redis restore.
+
+         This method handles checkpoint-based recovery:
+         1. For each training run restored from Redis, create adapter and load latest checkpoint
+         2. Mark ALL futures created after checkpoint's future_id as failed
+         3. For training runs without checkpoints, mark all futures as failed
+         """
+         # Restore training runs (adapter + checkpoint)
+         for model_id, record in self.training.training_runs.items():
+             if record.backend is None or record.corrupted:
+                 continue
+             latest_ckpt = await self.training.restore_from_checkpoint(model_id)
+
+             if latest_ckpt is None:
+                 self.future_store.mark_futures_failed_after_checkpoint(
+                     model_id=model_id,
+                     checkpoint_future_id=None,
+                     error_message=f"No checkpoint found for model {model_id}. Please retry.",
+                 )
+             else:
+                 self.future_store.mark_futures_failed_after_checkpoint(
+                     model_id=model_id,
+                     checkpoint_future_id=latest_ckpt.future_id,
+                     error_message=(
+                         f"Server restored from checkpoint {latest_ckpt.checkpoint_id}. "
+                         "Operations after this checkpoint need to be retried."
+                     ),
+                 )
+
+     def create_session(self, request: types.CreateSessionRequest, user: User) -> SessionRecord:
+         return self.sessions.create_session(request, user)
+
+     def heartbeat(self, session_id: str, user_id: str) -> None:
+         self.sessions.heartbeat(session_id, user_id)
+
+     async def create_model(
+         self,
+         session_id: str,
+         base_model: str,
+         lora_config: types.LoraConfig,
+         model_owner: str,
+         user_metadata: dict[str, str] | None,
+     ) -> TrainingRunRecord:
+         self.sessions.require(session_id)
+         return await self.training.create_model(
+             session_id=session_id,
+             base_model=base_model,
+             lora_config=lora_config,
+             model_owner=model_owner,
+             user_metadata=user_metadata,
+         )
+
+     def build_supported_models(self) -> list[types.SupportedModel]:
+         return self.training.build_supported_models()
+
+     def get_user(self, api_key: str) -> User | None:
+         return self.auth_db.authenticate(api_key)
+
+     async def run_forward(
+         self,
+         model_id: str,
+         user_id: str,
+         data: list[types.Datum],
+         loss_fn: types.LossFnType,
+         loss_fn_config: dict[str, float] | None,
+         seq_id: int | None,
+         *,
+         backward: bool,
+     ) -> types.ForwardBackwardOutput:
+         return await self.training.run_forward(
+             model_id=model_id,
+             user_id=user_id,
+             data=data,
+             loss_fn=loss_fn,
+             loss_fn_config=loss_fn_config,
+             seq_id=seq_id,
+             backward=backward,
+         )
+
+     async def run_optim_step(
+         self, model_id: str, user_id: str, params: types.AdamParams, seq_id: int | None
+     ) -> types.OptimStepResponse:
+         return await self.training.run_optim_step(
+             model_id=model_id, user_id=user_id, params=params, seq_id=seq_id
+         )
+
+     async def create_sampling_session(
+         self,
+         session_id: str,
+         base_model: str | None,
+         model_path: str | None,
+         user_id: str,
+         *,
+         session_seq_id: int,
+     ) -> str:
+         self.sessions.require(session_id)
+         return await self.sampling.create_sampling_session(
+             session_id=session_id,
+             user_id=user_id,
+             base_model=base_model,
+             model_path=model_path,
+             session_seq_id=session_seq_id,
+         )
+
+     async def run_sample(self, request: types.SampleRequest, user_id: str) -> types.SampleResponse:
+         return await self.sampling.run_sample(request, user_id=user_id)
+
+     async def save_checkpoint(
+         self,
+         model_id: str,
+         user_id: str,
+         name: str | None,
+         checkpoint_type: types.CheckpointType,
+         seq_id: int | None = None,
+     ) -> CheckpointRecord:
+         current_future_id = self.future_store.get_current_future_id()
+         return await self.training.save_checkpoint(
+             model_id=model_id,
+             user_id=user_id,
+             name=name,
+             checkpoint_type=checkpoint_type,
+             future_id=current_future_id,
+             seq_id=seq_id,
+         )
+
+     async def load_checkpoint(
+         self, model_id: str, user_id: str, path: str, optimizer: bool, seq_id: int | None = None
+     ) -> None:
+         return await self.training.load_checkpoint(
+             model_id=model_id,
+             user_id=user_id,
+             path=path,
+             optimizer=optimizer,
+             seq_id=seq_id,
+         )
+
+     def delete_checkpoint(self, model_id: str, user_id: str, checkpoint_id: str) -> None:
+         self.training.delete_checkpoint(model_id, user_id, checkpoint_id)
+
+     def list_checkpoints(self, model_id: str, user_id: str) -> list[types.Checkpoint]:
+         return self.training.list_checkpoints(model_id, user_id)
+
+     def list_user_checkpoints(self, user_id: str) -> list[types.Checkpoint]:
+         return self.training.list_user_checkpoints(user_id)
+
+     def set_checkpoint_visibility(
+         self,
+         model_id: str,
+         user_id: str,
+         checkpoint_id: str,
+         *,
+         public: bool,
+     ) -> None:
+         self.training.set_visibility(
+             model_id=model_id,
+             user_id=user_id,
+             checkpoint_id=checkpoint_id,
+             public=public,
+         )
+
+     def get_weights_info(self, tinker_path: str, user_id: str) -> types.WeightsInfoResponse:
+         parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
+         return self.training.get_weights_info(parsed.training_run_id, user_id)
+
+     def build_archive_url(
+         self,
+         model_id: str,
+         user_id: str,
+         checkpoint_id: str,
+     ) -> types.CheckpointArchiveUrlResponse:
+         return self.training.build_archive_url(model_id, user_id, checkpoint_id)
+
+     def list_training_runs(
+         self, *, user_id: str, limit: int | None = None, offset: int = 0
+     ) -> types.TrainingRunsResponse:
+         return self.training.list_training_runs(user_id=user_id, limit=limit, offset=offset)
+
+     def get_training_run_view(self, model_id: str, user_id: str) -> types.TrainingRun:
+         return self.training.get_training_run_view(model_id, user_id)
+
+     def get_model_info(self, model_id: str, user_id: str) -> types.GetInfoResponse:
+         return self.training.get_model_info(model_id, user_id=user_id)
+
+     async def unload_model(self, model_id: str, user_id: str) -> None:
+         await self.training.unload_model(model_id, user_id=user_id)
+         await self.sampling.evict_model(model_id, user_id=user_id)
+
+     def get_session_overview(self, session_id: str, user_id: str) -> types.GetSessionResponse:
+         record = self.sessions.require(session_id)
+         if record.user_id != user_id:
+             raise UserMismatchException()
+         training_run_ids = [
+             run_id
+             for run_id, run in self.training.training_runs.items()
+             if run.session_id == session_id
+         ]
+         sampler_ids = [
+             sid
+             for sid, record in self.sampling.sampling_sessions.items()
+             if record.session_id == session_id
+         ]
+         return types.GetSessionResponse(training_run_ids=training_run_ids, sampler_ids=sampler_ids)
+
+     def list_sessions(
+         self, user_id: str, *, limit: int | None = None, offset: int = 0
+     ) -> types.ListSessionsResponse:
+         sessions = self.sessions.list_sessions(user_id=user_id)
+         total = len(sessions)
+         start = min(offset, total)
+         if limit is None:
+             subset = sessions[start:]
+         else:
+             subset = sessions[start : min(start + limit, total)]
+         return types.ListSessionsResponse(sessions=subset)
+
+     def get_sampler_info(self, sampler_id: str, user_id: str) -> types.GetSamplerResponse:
+         return self.sampling.get_sampler_info(
+             sampler_id=sampler_id,
+             user_id=user_id,
+             default_base_model=self.config.supported_models[0].model_name,
+         )
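
Note: a minimal usage sketch of the ServerState façade added above, assuming an AppConfig with at least one authorized user. The CreateSessionRequest fields (tags, user_metadata, sdk_version) are inferred from how create_session reads the request; the API key shown is purely illustrative.

import asyncio

from tinker import types
from tuft.config import AppConfig
from tuft.state import ServerState

async def main() -> None:
    state = ServerState(AppConfig())          # wires sessions, training, sampling, auth, futures
    await state.async_init()                  # async setup, including checkpoint-based restore
    user = state.get_user("example-api-key")  # hypothetical key; returns None if not authorized
    if user is not None:
        request = types.CreateSessionRequest(
            tags=["demo"], user_metadata=None, sdk_version="0.1.2"
        )
        record = state.create_session(request, user)
        state.heartbeat(record.session_id, user.user_id)

asyncio.run(main())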
@@ -0,0 +1,17 @@
+ """OpenTelemetry integration for TuFT.
+
+ This module provides observability through Traces, Metrics, and Logs.
+ """
+
+ from __future__ import annotations
+
+ from tuft.config import TelemetryConfig
+
+ from .provider import init_telemetry, shutdown_telemetry
+
+
+ __all__ = [
+     "TelemetryConfig",
+     "init_telemetry",
+     "shutdown_telemetry",
+ ]
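
Note: a hedged sketch of how this public surface might be wired into application startup. The exact init_telemetry signature lives in the provider module not shown in this diff, so passing a TelemetryConfig is an assumption, as is the package path.

from tuft.telemetry import TelemetryConfig, init_telemetry, shutdown_telemetry  # package path assumed

config = TelemetryConfig()   # assumed to carry exporter settings with usable defaults
init_telemetry(config)       # assumption: the provider accepts the config object
try:
    ...                      # run the FastAPI application
finally:
    shutdown_telemetry()     # flush exporters and release telemetry resources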
@@ -0,0 +1,335 @@
+ """Metrics utilities for TuFT.
+
+ Provides meter access and predefined metric instruments.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import threading
+ from collections.abc import Iterable
+ from typing import Any
+
+ import psutil
+ import pynvml
+ from opentelemetry import metrics
+ from opentelemetry.metrics import CallbackOptions, Observation
+
+
+ logger = logging.getLogger(__name__)
+
+ # Module-level meter cache
+ _meters: dict[str, Any] = {}
+
+
+ def get_meter(name: str = "tuft"):
+     """Get a meter instance by name.
+
+     Args:
+         name: Name for the meter (typically module name).
+
+     Returns:
+         A Meter instance. When no MeterProvider is configured,
+         OpenTelemetry automatically returns a NoOpMeter.
+     """
+     if name in _meters:
+         return _meters[name]
+
+     meter = metrics.get_meter(name)
+     _meters[name] = meter
+     return meter
+
+
+ def clear_meters() -> None:
+     """Clear the meter cache. Used during shutdown."""
+     _meters.clear()
+
+
+ class TuftMetrics:
+     """Centralized metrics registry for TuFT.
+
+     Provides access to all predefined metrics with proper typing.
+     """
+
+     _instance: "TuftMetrics | None" = None
+     _lock = threading.Lock()
+
+     def __init__(self):
+         meter = get_meter("tuft")
+
+         # Training metrics
+         self.training_models_active = meter.create_up_down_counter(
+             "tuft.training.models.active",
+             description="Number of active training models",
+         )
+
+         self.training_tokens_per_second = meter.create_histogram(
+             "tuft.training.tokens_per_second",
+             description="Training tokens per second",
+             unit="tokens/s",
+         )
+
+         self.training_checkpoints_saved = meter.create_counter(
+             "tuft.training.checkpoints.saved",
+             description="Number of checkpoints saved",
+         )
+
+         self.training_checkpoint_size = meter.create_histogram(
+             "tuft.training.checkpoint.size_bytes",
+             description="Checkpoint size in bytes",
+             unit="bytes",
+         )
+
+         # Sampling metrics
+         self.sampling_sessions_active = meter.create_up_down_counter(
+             "tuft.sampling.sessions.active",
+             description="Number of active sampling sessions",
+         )
+
+         self.sampling_requests = meter.create_counter(
+             "tuft.sampling.requests",
+             description="Number of sampling requests",
+         )
+
+         self.sampling_duration = meter.create_histogram(
+             "tuft.sampling.duration",
+             description="Sampling request duration",
+             unit="s",
+         )
+
+         self.sampling_tokens_per_second = meter.create_histogram(
+             "tuft.sampling.tokens_per_second",
+             description="Sampling tokens per second",
+             unit="tokens/s",
+         )
+
+         self.sampling_output_tokens = meter.create_histogram(
+             "tuft.sampling.output_tokens",
+             description="Number of output tokens per sample",
+             unit="tokens",
+         )
+
+         # Future queue metrics
+         self.futures_queue_length = meter.create_up_down_counter(
+             "tuft.futures.queue_length",
+             description="Number of futures in queue",
+         )
+
+         self.futures_created = meter.create_counter(
+             "tuft.futures.created",
+             description="Number of futures created",
+         )
+
+         self.futures_completed = meter.create_counter(
+             "tuft.futures.completed",
+             description="Number of futures completed",
+         )
+
+         self.futures_wait_time = meter.create_histogram(
+             "tuft.futures.wait_time",
+             description="Time waiting for future completion",
+             unit="s",
+         )
+
+         self.futures_execution_time = meter.create_histogram(
+             "tuft.futures.execution_time",
+             description="Future execution time",
+             unit="s",
+         )
+
+         # Redis metrics
+         self.redis_operation_duration = meter.create_histogram(
+             "tuft.redis.operation.duration",
+             description="Redis operation duration",
+             unit="s",
+         )
+
+     @classmethod
+     def get_instance(cls) -> "TuftMetrics":
+         """Get the singleton metrics instance."""
+         if cls._instance is None:
+             with cls._lock:
+                 if cls._instance is None:
+                     cls._instance = cls()
+         return cls._instance
+
+
+ def get_metrics() -> TuftMetrics:
+     """Get the TuFT metrics instance."""
+     return TuftMetrics.get_instance()
+
+
+ class ResourceMetricsCollector:
+     """Collects system resource metrics periodically.
+
+     Collects CPU, memory, disk, GPU, and network metrics.
+     """
+
+     _instance: "ResourceMetricsCollector | None" = None
+     _lock = threading.Lock()
+
+     def __init__(self, checkpoint_dir: str | None = None):
+         self._checkpoint_dir = checkpoint_dir
+         self._running = False
+         self._thread: threading.Thread | None = None
+         self._nvml_initialized = False
+         self._gpu_available = self._check_and_init_gpu()
+         self._setup_metrics()
+
+     def _check_and_init_gpu(self) -> bool:
+         """Check if GPU monitoring is available and initialize NVML once."""
+         try:
+             pynvml.nvmlInit()
+             self._nvml_initialized = True
+             return True
+         except pynvml.NVMLError as e:
+             logger.debug("GPU monitoring not available: %s", e)
+             return False
+
+     def _shutdown_gpu(self) -> None:
+         """Shutdown NVML if it was initialized."""
+         if self._nvml_initialized:
+             try:
+                 pynvml.nvmlShutdown()
+                 self._nvml_initialized = False
+             except pynvml.NVMLError as e:
+                 logger.warning("Failed to shutdown NVML: %s", e)
+
+     def _setup_metrics(self) -> None:
+         """Set up observable gauges for resource metrics."""
+         meter = metrics.get_meter("tuft.resources")
+
+         # CPU metrics
+         meter.create_observable_gauge(
+             "tuft.resource.cpu.utilization_percent",
+             callbacks=[self._cpu_utilization_callback],
+             description="CPU utilization percentage",
+             unit="%",
+         )
+
+         # Memory metrics
+         meter.create_observable_gauge(
+             "tuft.resource.memory.used_bytes",
+             callbacks=[self._memory_used_callback],
+             description="Memory used in bytes",
+             unit="bytes",
+         )
+         meter.create_observable_gauge(
+             "tuft.resource.memory.total_bytes",
+             callbacks=[self._memory_total_callback],
+             description="Total memory in bytes",
+             unit="bytes",
+         )
+         meter.create_observable_gauge(
+             "tuft.resource.memory.utilization_percent",
+             callbacks=[self._memory_utilization_callback],
+             description="Memory utilization percentage",
+             unit="%",
+         )
+
+         # GPU metrics (if available)
+         if self._gpu_available:
+             meter.create_observable_gauge(
+                 "tuft.resource.gpu.utilization_percent",
+                 callbacks=[self._gpu_utilization_callback],
+                 description="GPU utilization percentage",
+                 unit="%",
+             )
+             meter.create_observable_gauge(
+                 "tuft.resource.gpu.memory_used_bytes",
+                 callbacks=[self._gpu_memory_used_callback],
+                 description="GPU memory used in bytes",
+                 unit="bytes",
+             )
+             meter.create_observable_gauge(
+                 "tuft.resource.gpu.memory_total_bytes",
+                 callbacks=[self._gpu_memory_total_callback],
+                 description="GPU total memory in bytes",
+                 unit="bytes",
+             )
+
+         # Process metrics
+         meter.create_observable_gauge(
+             "tuft.resource.process.memory_used_bytes",
+             callbacks=[self._process_memory_callback],
+             description="Process memory usage in bytes",
+             unit="bytes",
+         )
+
+     def _cpu_utilization_callback(self, options: CallbackOptions) -> Iterable[Observation]:
+         """Callback for CPU utilization metric."""
+         yield Observation(psutil.cpu_percent(interval=None))
+
+     def _memory_used_callback(self, options: CallbackOptions) -> Iterable[Observation]:
+         """Callback for memory used metric."""
+         yield Observation(psutil.virtual_memory().used)
+
+     def _memory_total_callback(self, options: CallbackOptions) -> Iterable[Observation]:
+         """Callback for total memory metric."""
+         yield Observation(psutil.virtual_memory().total)
+
+     def _memory_utilization_callback(self, options: CallbackOptions) -> Iterable[Observation]:
+         """Callback for memory utilization metric."""
+         yield Observation(psutil.virtual_memory().percent)
+
+     def _gpu_utilization_callback(self, options: CallbackOptions) -> Iterable[Observation]:
+         """Callback for GPU utilization metric."""
+         if not self._nvml_initialized:
+             return
+         try:
+             device_count = pynvml.nvmlDeviceGetCount()
+             for i in range(device_count):
+                 handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+                 util = pynvml.nvmlDeviceGetUtilizationRates(handle)
+                 yield Observation(int(util.gpu), {"gpu_id": str(i)})
+         except pynvml.NVMLError as e:
+             logger.debug("Failed to get GPU utilization: %s", e)
+
+     def _gpu_memory_used_callback(self, options: CallbackOptions) -> Iterable[Observation]:
+         """Callback for GPU memory used metric."""
+         if not self._nvml_initialized:
+             return
+         try:
+             device_count = pynvml.nvmlDeviceGetCount()
+             for i in range(device_count):
+                 handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+                 mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+                 yield Observation(int(mem_info.used), {"gpu_id": str(i)})
+         except pynvml.NVMLError as e:
+             logger.debug("Failed to get GPU memory used: %s", e)
+
+     def _gpu_memory_total_callback(self, options: CallbackOptions) -> Iterable[Observation]:
+         """Callback for GPU total memory metric."""
+         if not self._nvml_initialized:
+             return
+         try:
+             device_count = pynvml.nvmlDeviceGetCount()
+             for i in range(device_count):
+                 handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+                 mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+                 yield Observation(int(mem_info.total), {"gpu_id": str(i)})
+         except pynvml.NVMLError as e:
+             logger.debug("Failed to get GPU memory total: %s", e)
+
+     def _process_memory_callback(self, options: CallbackOptions) -> Iterable[Observation]:
+         """Callback for process memory metric."""
+         process = psutil.Process()
+         yield Observation(process.memory_info().rss)
+
+     @classmethod
+     def start(cls, checkpoint_dir: str | None = None) -> "ResourceMetricsCollector":
+         """Start the resource metrics collector."""
+         if cls._instance is None:
+             with cls._lock:
+                 if cls._instance is None:
+                     cls._instance = cls(checkpoint_dir)
+         return cls._instance
+
+     @classmethod
+     def shutdown(cls) -> None:
+         """Shutdown the resource metrics collector."""
+         if cls._instance is not None:
+             with cls._lock:
+                 if cls._instance is not None:
+                     cls._instance._shutdown_gpu()
+                     cls._instance = None
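
Note: a short sketch of how calling code might use the instruments above. The attribute names (sampling_requests, sampling_duration) and the collector lifecycle come straight from this file; the module path and the surrounding MeterProvider setup (e.g. via init_telemetry) are assumptions.

import time

from tuft.telemetry.metrics import ResourceMetricsCollector, get_metrics  # module path assumed

collector = ResourceMetricsCollector.start()      # registers observable CPU/memory/GPU gauges
tuft_metrics = get_metrics()                      # singleton TuftMetrics registry

tuft_metrics.sampling_requests.add(1)             # count one sampling request
started = time.perf_counter()
# ... perform the sampling request ...
tuft_metrics.sampling_duration.record(time.perf_counter() - started)  # duration in seconds

ResourceMetricsCollector.shutdown()               # releases NVML if it was initialized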