wandb 0.20.1__py3-none-win32.whl → 0.20.2rc20250616__py3-none-win32.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their respective public registries.
Files changed (72)
  1. wandb/__init__.py +3 -6
  2. wandb/__init__.pyi +1 -1
  3. wandb/analytics/sentry.py +2 -2
  4. wandb/apis/importers/internals/internal.py +0 -3
  5. wandb/apis/public/api.py +2 -2
  6. wandb/apis/public/registries/{utils.py → _utils.py} +12 -12
  7. wandb/apis/public/registries/registries_search.py +2 -2
  8. wandb/apis/public/registries/registry.py +19 -18
  9. wandb/bin/gpu_stats.exe +0 -0
  10. wandb/bin/wandb-core +0 -0
  11. wandb/cli/beta.py +1 -7
  12. wandb/cli/cli.py +0 -30
  13. wandb/env.py +0 -6
  14. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  15. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  16. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  17. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  18. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  19. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  20. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  21. wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
  22. wandb/sdk/artifacts/storage_handlers/s3_handler.py +42 -1
  23. wandb/sdk/backend/backend.py +1 -1
  24. wandb/sdk/internal/handler.py +1 -69
  25. wandb/sdk/lib/printer.py +6 -7
  26. wandb/sdk/lib/progress.py +1 -3
  27. wandb/sdk/lib/service/ipc_support.py +13 -0
  28. wandb/sdk/lib/{service_connection.py → service/service_connection.py} +20 -56
  29. wandb/sdk/lib/service/service_port_file.py +105 -0
  30. wandb/sdk/lib/service/service_process.py +111 -0
  31. wandb/sdk/lib/service/service_token.py +164 -0
  32. wandb/sdk/lib/sock_client.py +8 -12
  33. wandb/sdk/wandb_init.py +0 -3
  34. wandb/sdk/wandb_require.py +9 -20
  35. wandb/sdk/wandb_run.py +0 -24
  36. wandb/sdk/wandb_settings.py +0 -9
  37. wandb/sdk/wandb_setup.py +2 -13
  38. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/METADATA +1 -3
  39. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/RECORD +42 -68
  40. wandb/sdk/internal/flow_control.py +0 -263
  41. wandb/sdk/internal/internal.py +0 -401
  42. wandb/sdk/internal/internal_util.py +0 -97
  43. wandb/sdk/internal/system/__init__.py +0 -0
  44. wandb/sdk/internal/system/assets/__init__.py +0 -25
  45. wandb/sdk/internal/system/assets/aggregators.py +0 -31
  46. wandb/sdk/internal/system/assets/asset_registry.py +0 -20
  47. wandb/sdk/internal/system/assets/cpu.py +0 -163
  48. wandb/sdk/internal/system/assets/disk.py +0 -210
  49. wandb/sdk/internal/system/assets/gpu.py +0 -416
  50. wandb/sdk/internal/system/assets/gpu_amd.py +0 -233
  51. wandb/sdk/internal/system/assets/interfaces.py +0 -205
  52. wandb/sdk/internal/system/assets/ipu.py +0 -177
  53. wandb/sdk/internal/system/assets/memory.py +0 -166
  54. wandb/sdk/internal/system/assets/network.py +0 -125
  55. wandb/sdk/internal/system/assets/open_metrics.py +0 -293
  56. wandb/sdk/internal/system/assets/tpu.py +0 -154
  57. wandb/sdk/internal/system/assets/trainium.py +0 -393
  58. wandb/sdk/internal/system/env_probe_helpers.py +0 -13
  59. wandb/sdk/internal/system/system_info.py +0 -248
  60. wandb/sdk/internal/system/system_monitor.py +0 -224
  61. wandb/sdk/internal/writer.py +0 -204
  62. wandb/sdk/lib/service_token.py +0 -93
  63. wandb/sdk/service/__init__.py +0 -0
  64. wandb/sdk/service/_startup_debug.py +0 -22
  65. wandb/sdk/service/port_file.py +0 -53
  66. wandb/sdk/service/server.py +0 -107
  67. wandb/sdk/service/server_sock.py +0 -286
  68. wandb/sdk/service/service.py +0 -252
  69. wandb/sdk/service/streams.py +0 -425
  70. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/WHEEL +0 -0
  71. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/entry_points.txt +0 -0
  72. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/internal/system/assets/gpu.py (deleted)
@@ -1,416 +0,0 @@
- import logging
- import threading
- from collections import deque
- from typing import TYPE_CHECKING, List
-
- try:
-     import psutil
- except ImportError:
-     psutil = None
-
- from wandb.vendor.pynvml import pynvml
-
- from .aggregators import aggregate_mean
- from .asset_registry import asset_registry
- from .interfaces import Interface, Metric, MetricsMonitor
-
- if TYPE_CHECKING:
-     from typing import Deque
-
-     from wandb.sdk.internal.settings_static import SettingsStatic
-
- GPUHandle = object
-
-
- logger = logging.getLogger(__name__)
-
-
- def gpu_in_use_by_this_process(gpu_handle: "GPUHandle", pid: int) -> bool:
-     if psutil is None:
-         return False
-
-     try:
-         base_process = psutil.Process(pid=pid)
-     except psutil.NoSuchProcess:
-         # do not report any gpu metrics if the base process can't be found
-         return False
-
-     our_processes = base_process.children(recursive=True)
-     our_processes.append(base_process)
-
-     our_pids = {process.pid for process in our_processes}
-
-     compute_pids = {
-         process.pid
-         for process in pynvml.nvmlDeviceGetComputeRunningProcesses(gpu_handle)  # type: ignore
-     }
-     graphics_pids = {
-         process.pid
-         for process in pynvml.nvmlDeviceGetGraphicsRunningProcesses(gpu_handle)  # type: ignore
-     }
-
-     pids_using_device = compute_pids | graphics_pids
-
-     return len(pids_using_device & our_pids) > 0
-
-
- class GPUMemoryUtilization:
-     """GPU memory utilization in percent for each GPU."""
-
-     # name = "memory_utilization"
-     name = "gpu.{}.memory"
-     # samples: Deque[Tuple[datetime.datetime, float]]
-     samples: "Deque[List[float]]"
-
-     def __init__(self, pid: int) -> None:
-         self.pid = pid
-         self.samples = deque([])
-
-     def sample(self) -> None:
-         memory_utilization_rate = []
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             memory_utilization_rate.append(
-                 pynvml.nvmlDeviceGetUtilizationRates(handle).memory  # type: ignore
-             )
-         self.samples.append(memory_utilization_rate)
-
-     def clear(self) -> None:
-         self.samples.clear()
-
-     def aggregate(self) -> dict:
-         if not self.samples:
-             return {}
-         stats = {}
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             samples = [sample[i] for sample in self.samples]
-             aggregate = aggregate_mean(samples)
-             stats[self.name.format(i)] = aggregate
-
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             if gpu_in_use_by_this_process(handle, self.pid):
-                 stats[self.name.format(f"process.{i}")] = aggregate
-
-         return stats
-
-
- class GPUMemoryAllocated:
-     """GPU memory allocated in percent for each GPU."""
-
-     # name = "memory_allocated"
-     name = "gpu.{}.memoryAllocated"
-     # samples: Deque[Tuple[datetime.datetime, float]]
-     samples: "Deque[List[float]]"
-
-     def __init__(self, pid: int) -> None:
-         self.pid = pid
-         self.samples = deque([])
-
-     def sample(self) -> None:
-         memory_allocated = []
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)  # type: ignore
-             memory_allocated.append(memory_info.used / memory_info.total * 100)
-         self.samples.append(memory_allocated)
-
-     def clear(self) -> None:
-         self.samples.clear()
-
-     def aggregate(self) -> dict:
-         if not self.samples:
-             return {}
-         stats = {}
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             samples = [sample[i] for sample in self.samples]
-             aggregate = aggregate_mean(samples)
-             stats[self.name.format(i)] = aggregate
-
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             if gpu_in_use_by_this_process(handle, self.pid):
-                 stats[self.name.format(f"process.{i}")] = aggregate
-
-         return stats
-
-
- class GPUMemoryAllocatedBytes:
-     """GPU memory allocated in bytes for each GPU."""
-
-     # name = "memory_allocated"
-     name = "gpu.{}.memoryAllocatedBytes"
-     # samples: Deque[Tuple[datetime.datetime, float]]
-     samples: "Deque[List[float]]"
-
-     def __init__(self, pid: int) -> None:
-         self.pid = pid
-         self.samples = deque([])
-
-     def sample(self) -> None:
-         memory_allocated = []
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)  # type: ignore
-             memory_allocated.append(memory_info.used)
-         self.samples.append(memory_allocated)
-
-     def clear(self) -> None:
-         self.samples.clear()
-
-     def aggregate(self) -> dict:
-         if not self.samples:
-             return {}
-         stats = {}
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             samples = [sample[i] for sample in self.samples]
-             aggregate = aggregate_mean(samples)
-             stats[self.name.format(i)] = aggregate
-
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             if gpu_in_use_by_this_process(handle, self.pid):
-                 stats[self.name.format(f"process.{i}")] = aggregate
-
-         return stats
-
-
- class GPUUtilization:
-     """GPU utilization in percent for each GPU."""
-
-     # name = "gpu_utilization"
-     name = "gpu.{}.gpu"
-     # samples: Deque[Tuple[datetime.datetime, float]]
-     samples: "Deque[List[float]]"
-
-     def __init__(self, pid: int) -> None:
-         self.pid = pid
-         self.samples = deque([])
-
-     def sample(self) -> None:
-         gpu_utilization_rate = []
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             gpu_utilization_rate.append(
-                 pynvml.nvmlDeviceGetUtilizationRates(handle).gpu  # type: ignore
-             )
-         self.samples.append(gpu_utilization_rate)
-
-     def clear(self) -> None:
-         self.samples.clear()
-
-     def aggregate(self) -> dict:
-         if not self.samples:
-             return {}
-         stats = {}
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             samples = [sample[i] for sample in self.samples]
-             aggregate = aggregate_mean(samples)
-             stats[self.name.format(i)] = aggregate
-
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             if gpu_in_use_by_this_process(handle, self.pid):
-                 stats[self.name.format(f"process.{i}")] = aggregate
-
-         return stats
-
-
- class GPUTemperature:
-     """GPU temperature in Celsius for each GPU."""
-
-     # name = "gpu_temperature"
-     name = "gpu.{}.temp"
-     # samples: Deque[Tuple[datetime.datetime, float]]
-     samples: "Deque[List[float]]"
-
-     def __init__(self, pid: int) -> None:
-         self.pid = pid
-         self.samples = deque([])
-
-     def sample(self) -> None:
-         temperature = []
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             temperature.append(
-                 pynvml.nvmlDeviceGetTemperature(  # type: ignore
-                     handle,
-                     pynvml.NVML_TEMPERATURE_GPU,
-                 )
-             )
-         self.samples.append(temperature)
-
-     def clear(self) -> None:
-         self.samples.clear()
-
-     def aggregate(self) -> dict:
-         if not self.samples:
-             return {}
-         stats = {}
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             samples = [sample[i] for sample in self.samples]
-             aggregate = aggregate_mean(samples)
-             stats[self.name.format(i)] = aggregate
-
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             if gpu_in_use_by_this_process(handle, self.pid):
-                 stats[self.name.format(f"process.{i}")] = aggregate
-
-         return stats
-
-
- class GPUPowerUsageWatts:
-     """GPU power usage in Watts for each GPU."""
-
-     name = "gpu.{}.powerWatts"
-     # samples: Deque[Tuple[datetime.datetime, float]]
-     samples: "Deque[List[float]]"
-
-     def __init__(self, pid: int) -> None:
-         self.pid = pid
-         self.samples = deque([])
-
-     def sample(self) -> None:
-         power_usage = []
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             power_watts = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000  # type: ignore
-             power_usage.append(power_watts)
-         self.samples.append(power_usage)
-
-     def clear(self) -> None:
-         self.samples.clear()
-
-     def aggregate(self) -> dict:
-         stats = {}
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             samples = [sample[i] for sample in self.samples]
-             aggregate = aggregate_mean(samples)
-             stats[self.name.format(i)] = aggregate
-
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             if gpu_in_use_by_this_process(handle, self.pid):
-                 stats[self.name.format(f"process.{i}")] = aggregate
-
-         return stats
-
-
- class GPUPowerUsagePercent:
-     """GPU power usage in percent for each GPU."""
-
-     name = "gpu.{}.powerPercent"
-     # samples: Deque[Tuple[datetime.datetime, float]]
-     samples: "Deque[List[float]]"
-
-     def __init__(self, pid: int) -> None:
-         self.pid = pid
-         self.samples = deque([])
-
-     def sample(self) -> None:
-         power_usage = []
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             power_watts = pynvml.nvmlDeviceGetPowerUsage(handle)  # type: ignore
-             power_capacity_watts = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle)  # type: ignore
-             power_usage.append((power_watts / power_capacity_watts) * 100)
-         self.samples.append(power_usage)
-
-     def clear(self) -> None:
-         self.samples.clear()
-
-     def aggregate(self) -> dict:
-         if not self.samples:
-             return {}
-         stats = {}
-         device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-         for i in range(device_count):
-             samples = [sample[i] for sample in self.samples]
-             aggregate = aggregate_mean(samples)
-             stats[self.name.format(i)] = aggregate
-
-             handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-             if gpu_in_use_by_this_process(handle, self.pid):
-                 stats[self.name.format(f"process.{i}")] = aggregate
-
-         return stats
-
-
- @asset_registry.register
- class GPU:
-     def __init__(
-         self,
-         interface: "Interface",
-         settings: "SettingsStatic",
-         shutdown_event: threading.Event,
-     ) -> None:
-         self.name = self.__class__.__name__.lower()
-         self.metrics: List[Metric] = [
-             GPUMemoryAllocated(settings.x_stats_pid),
-             GPUMemoryAllocatedBytes(settings.x_stats_pid),
-             GPUMemoryUtilization(settings.x_stats_pid),
-             GPUUtilization(settings.x_stats_pid),
-             GPUTemperature(settings.x_stats_pid),
-             GPUPowerUsageWatts(settings.x_stats_pid),
-             GPUPowerUsagePercent(settings.x_stats_pid),
-         ]
-         self.metrics_monitor = MetricsMonitor(
-             self.name,
-             self.metrics,
-             interface,
-             settings,
-             shutdown_event,
-         )
-
-     @classmethod
-     def is_available(cls) -> bool:
-         try:
-             pynvml.nvmlInit()  # type: ignore
-             return True
-         except pynvml.NVMLError_LibraryNotFound:  # type: ignore
-             return False
-         except Exception:
-             logger.exception("Error initializing NVML.")
-             return False
-
-     def start(self) -> None:
-         self.metrics_monitor.start()
-
-     def finish(self) -> None:
-         self.metrics_monitor.finish()
-
-     def probe(self) -> dict:
-         info = {}
-         try:
-             pynvml.nvmlInit()  # type: ignore
-             # todo: this is an adapter for the legacy stats system:
-             info["gpu"] = pynvml.nvmlDeviceGetName(pynvml.nvmlDeviceGetHandleByIndex(0))  # type: ignore
-             info["gpu_count"] = pynvml.nvmlDeviceGetCount()  # type: ignore
-
-             device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
-             devices = []
-             for i in range(device_count):
-                 handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
-                 gpu_info = pynvml.nvmlDeviceGetMemoryInfo(handle)  # type: ignore
-                 devices.append(
-                     {
-                         "name": pynvml.nvmlDeviceGetName(handle),
-                         "memory_total": gpu_info.total,
-                     }
-                 )
-             info["gpu_devices"] = devices
-
-         except pynvml.NVMLError:
-             pass
-         except Exception:
-             logger.exception("Error Probing GPU.")
-
-         return info
wandb/sdk/internal/system/assets/gpu_amd.py (deleted)
@@ -1,233 +0,0 @@
- import json
- import logging
- import shutil
- import subprocess
- import threading
- from collections import deque
- from typing import TYPE_CHECKING, Any, Dict, Final, List, Literal, Union
-
- from wandb.sdk.lib import telemetry
-
- from .aggregators import aggregate_mean
- from .asset_registry import asset_registry
- from .interfaces import Interface, Metric, MetricsMonitor
-
- if TYPE_CHECKING:
-     from typing import Deque
-
-     from wandb.sdk.internal.settings_static import SettingsStatic
-
-
- logger = logging.getLogger(__name__)
- ROCM_SMI_CMD: Final[str] = shutil.which("rocm-smi") or "/usr/bin/rocm-smi"
-
-
- _StatsKeys = Literal[
-     "gpu",
-     "memoryAllocated",
-     "temp",
-     "powerWatts",
-     "powerPercent",
- ]
- _Stats = Dict[_StatsKeys, float]
-
-
- _InfoDict = Dict[str, Union[int, List[Dict[str, Any]]]]
-
-
- def get_rocm_smi_stats() -> Dict[str, Any]:
-     command = [str(ROCM_SMI_CMD), "-a", "--json"]
-     output = subprocess.check_output(command, universal_newlines=True).strip()
-     if "No AMD GPUs specified" in output:
-         return {}
-     return json.loads(output.split("\n")[0])  # type: ignore
-
-
- def parse_stats(stats: Dict[str, str]) -> _Stats:
-     """Parse stats from rocm-smi output."""
-     parsed_stats: _Stats = {}
-
-     try:
-         parsed_stats["gpu"] = float(stats.get("GPU use (%)"))  # type: ignore
-     except (TypeError, ValueError):
-         logger.warning("Could not parse GPU usage as float")
-     try:
-         parsed_stats["memoryAllocated"] = float(stats.get("GPU memory use (%)"))  # type: ignore
-     except (TypeError, ValueError):
-         logger.warning("Could not parse GPU memory allocation as float")
-     try:
-         parsed_stats["temp"] = float(stats.get("Temperature (Sensor memory) (C)"))  # type: ignore
-     except (TypeError, ValueError):
-         logger.warning("Could not parse GPU temperature as float")
-     try:
-         parsed_stats["powerWatts"] = float(
-             stats.get("Average Graphics Package Power (W)")  # type: ignore
-         )
-     except (TypeError, ValueError):
-         logger.warning("Could not parse GPU power as float")
-     try:
-         parsed_stats["powerPercent"] = (
-             float(stats.get("Average Graphics Package Power (W)"))  # type: ignore
-             / float(stats.get("Max Graphics Package Power (W)"))  # type: ignore
-             * 100
-         )
-     except (TypeError, ValueError):
-         logger.warning("Could not parse GPU average/max power as float")
-
-     return parsed_stats
-
-
- class GPUAMDStats:
-     """Stats for AMD GPU devices."""
-
-     name = "gpu.{gpu_id}.{key}"
-     samples: "Deque[List[_Stats]]"
-
-     def __init__(self) -> None:
-         self.samples = deque()
-
-     def sample(self) -> None:
-         try:
-             raw_stats = get_rocm_smi_stats()
-             cards = []
-
-             card_keys = [
-                 key for key in sorted(raw_stats.keys()) if key.startswith("card")
-             ]
-
-             for card_key in card_keys:
-                 card_stats = raw_stats[card_key]
-                 stats = parse_stats(card_stats)
-                 if stats:
-                     cards.append(stats)
-
-             if cards:
-                 self.samples.append(cards)
-
-         except (OSError, ValueError, TypeError, subprocess.CalledProcessError):
-             logger.exception("GPU stats error")
-
-     def clear(self) -> None:
-         self.samples.clear()
-
-     def aggregate(self) -> dict:
-         if not self.samples:
-             return {}
-         stats = {}
-         device_count = len(self.samples[0])
-
-         for i in range(device_count):
-             samples = [sample[i] for sample in self.samples]
-
-             for key in samples[0].keys():
-                 samples_key = [s[key] for s in samples]
-                 aggregate = aggregate_mean(samples_key)
-                 stats[self.name.format(gpu_id=i, key=key)] = aggregate
-
-         return stats
-
-
- @asset_registry.register
- class GPUAMD:
-     """GPUAMD is a class for monitoring AMD GPU devices.
-
-     Uses AMD's rocm_smi tool to get GPU stats.
-     For the list of supported environments and devices, see
-     https://github.com/RadeonOpenCompute/ROCm/blob/develop/docs/deploy/
-     """
-
-     def __init__(
-         self,
-         interface: "Interface",
-         settings: "SettingsStatic",
-         shutdown_event: threading.Event,
-     ) -> None:
-         self.name = self.__class__.__name__.lower()
-         self.metrics: List[Metric] = [
-             GPUAMDStats(),
-         ]
-         self.metrics_monitor = MetricsMonitor(
-             self.name,
-             self.metrics,
-             interface,
-             settings,
-             shutdown_event,
-         )
-         telemetry_record = telemetry.TelemetryRecord()
-         telemetry_record.env.amd_gpu = True
-         interface._publish_telemetry(telemetry_record)
-
-     @classmethod
-     def is_available(cls) -> bool:
-         rocm_smi_available = shutil.which(ROCM_SMI_CMD) is not None
-         if not rocm_smi_available:
-             # If rocm-smi is not available, we can't monitor AMD GPUs
-             return False
-
-         is_driver_initialized = False
-
-         try:
-             # inspired by https://github.com/ROCm/rocm_smi_lib/blob/5d2cd0c2715ae45b8f9cfe1e777c6c2cd52fb601/python_smi_tools/rocm_smi.py#L71C1-L81C17
-             with open("/sys/module/amdgpu/initstate") as file:
-                 file_content = file.read()
-             if "live" in file_content:
-                 is_driver_initialized = True
-         except FileNotFoundError:
-             pass
-
-         can_read_rocm_smi = False
-         try:
-             # try to read stats from rocm-smi and parse them
-             raw_stats = get_rocm_smi_stats()
-             card_keys = [
-                 key for key in sorted(raw_stats.keys()) if key.startswith("card")
-             ]
-
-             for card_key in card_keys:
-                 card_stats = raw_stats[card_key]
-                 parse_stats(card_stats)
-
-             can_read_rocm_smi = True
-         except Exception:
-             pass
-
-         return is_driver_initialized and can_read_rocm_smi
-
-     def start(self) -> None:
-         self.metrics_monitor.start()
-
-     def finish(self) -> None:
-         self.metrics_monitor.finish()
-
-     def probe(self) -> dict:
-         info: _InfoDict = {}
-         try:
-             stats = get_rocm_smi_stats()
-
-             info["gpu_count"] = len(
-                 [key for key in stats.keys() if key.startswith("card")]
-             )
-             key_mapping = {
-                 "id": "GPU ID",
-                 "unique_id": "Unique ID",
-                 "vbios_version": "VBIOS version",
-                 "performance_level": "Performance Level",
-                 "gpu_overdrive": "GPU OverDrive value (%)",
-                 "gpu_memory_overdrive": "GPU Memory OverDrive value (%)",
-                 "max_power": "Max Graphics Package Power (W)",
-                 "series": "Card series",
-                 "model": "Card model",
-                 "vendor": "Card vendor",
-                 "sku": "Card SKU",
-                 "sclk_range": "Valid sclk range",
-                 "mclk_range": "Valid mclk range",
-             }
-
-             info["gpu_devices"] = [
-                 {k: stats[key][v] for k, v in key_mapping.items() if stats[key].get(v)}
-                 for key in stats.keys()
-                 if key.startswith("card")
-             ]
-         except Exception:
-             logger.exception("GPUAMD probe error")
-         return info