dask-cuda 25.2.0__py3-none-any.whl → 25.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
+ # Copyright (c) 2021-2025 NVIDIA CORPORATION.
+
  import asyncio
  import multiprocessing as mp
  import os
@@ -19,18 +21,16 @@ from distributed.deploy.local import LocalCluster

  import dask_cuda
  from dask_cuda.explicit_comms import comms
- from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
+ from dask_cuda.explicit_comms.dataframe.shuffle import (
+     _contains_shuffle_expr,
+     shuffle as explicit_comms_shuffle,
+ )
  from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

  mp = mp.get_context("spawn")  # type: ignore
  ucp = pytest.importorskip("ucp")


- # Set default shuffle method to "tasks"
- if dask.config.get("dataframe.shuffle.method", None) is None:
-     dask.config.set({"dataframe.shuffle.method": "tasks"})
-
-
  # Notice, all of the following tests is executed in a new process such
  # that UCX options of the different tests doesn't conflict.

@@ -415,3 +415,133 @@ def test_lock_workers():
          p.join()

      assert all(p.exitcode == 0 for p in ps)
+
+
+ def test_create_destroy_create():
+     # https://github.com/rapidsai/dask-cuda/issues/1450
+     assert len(comms._comms_cache) == 0
+     with LocalCluster(n_workers=1) as cluster:
+         with Client(cluster) as client:
+             context = comms.default_comms()
+             scheduler_addresses_old = list(client.scheduler_info()["workers"].keys())
+             comms_addresses_old = list(comms.default_comms().worker_addresses)
+             assert comms.default_comms() is context
+             assert len(comms._comms_cache) == 1
+
+             # Add a worker, which should have a new comms object
+             cluster.scale(2)
+             client.wait_for_workers(2, timeout=5)
+             context2 = comms.default_comms()
+             assert context is not context2
+             assert len(comms._comms_cache) == 2
+
+     del context
+     del context2
+     assert len(comms._comms_cache) == 0
+     assert scheduler_addresses_old == comms_addresses_old
+
+     # A new cluster should have a new comms object. Previously, this failed
+     # because we referenced the old cluster's addresses.
+     with LocalCluster(n_workers=1) as cluster:
+         with Client(cluster) as client:
+             scheduler_addresses_new = list(client.scheduler_info()["workers"].keys())
+             comms_addresses_new = list(comms.default_comms().worker_addresses)
+
+     assert scheduler_addresses_new == comms_addresses_new
+
+
+ def test_scaled_cluster_gets_new_comms_context():
+     # Ensure that if we create a CommsContext, scale the cluster,
+     # and create a new CommsContext, then the new CommsContext
+     # should include the new worker.
+     # https://github.com/rapidsai/dask-cuda/issues/1450
+
+     name = "explicit-comms-shuffle"
+     ddf = dd.from_pandas(pd.DataFrame({"key": np.arange(10)}), npartitions=2)
+
+     with LocalCluster(n_workers=2) as cluster:
+         with Client(cluster) as client:
+             context_1 = comms.default_comms()
+
+             def check(dask_worker, session_id: int):
+                 has_state = hasattr(dask_worker, "_explicit_comm_state")
+                 has_state_for_session = (
+                     has_state and session_id in dask_worker._explicit_comm_state
+                 )
+                 if has_state_for_session:
+                     n_workers = dask_worker._explicit_comm_state[session_id]["nworkers"]
+                 else:
+                     n_workers = None
+                 return {
+                     "has_state": has_state,
+                     "has_state_for_session": has_state_for_session,
+                     "n_workers": n_workers,
+                 }
+
+             result_1 = client.run(check, session_id=context_1.sessionId)
+             expected_values = {
+                 "has_state": True,
+                 "has_state_for_session": True,
+                 "n_workers": 2,
+             }
+             expected_1 = {
+                 k: expected_values for k in client.scheduler_info()["workers"]
+             }
+             assert result_1 == expected_1
+
+             # Run a shuffle with the initial setup as a sanity test
+             with dask.config.set(explicit_comms=True):
+                 shuffled = ddf.shuffle(on="key", npartitions=4)
+                 assert any(name in str(key) for key in shuffled.dask)
+                 result = shuffled.compute()
+
+             with dask.config.set(explicit_comms=False):
+                 shuffled = ddf.shuffle(on="key", npartitions=4)
+                 expected = shuffled.compute()
+
+             assert_eq(result, expected)
+
+             # --- Scale the cluster ---
+             cluster.scale(3)
+             client.wait_for_workers(3, timeout=5)
+
+             context_2 = comms.default_comms()
+             result_2 = client.run(check, session_id=context_2.sessionId)
+             expected_values = {
+                 "has_state": True,
+                 "has_state_for_session": True,
+                 "n_workers": 3,
+             }
+             expected_2 = {
+                 k: expected_values for k in client.scheduler_info()["workers"]
+             }
+             assert result_2 == expected_2
+
+             # Run a shuffle with the new setup
+             with dask.config.set(explicit_comms=True):
+                 shuffled = ddf.shuffle(on="key", npartitions=4)
+                 assert any(name in str(key) for key in shuffled.dask)
+                 result = shuffled.compute()
+
+             with dask.config.set(explicit_comms=False):
+                 shuffled = ddf.shuffle(on="key", npartitions=4)
+                 expected = shuffled.compute()
+
+             assert_eq(result, expected)
+
+
+ def test_contains_shuffle_expr():
+     df = dd.from_pandas(pd.DataFrame({"key": np.arange(10)}), npartitions=2)
+     assert not _contains_shuffle_expr(df)
+
+     with dask.config.set(explicit_comms=True):
+         shuffled = df.shuffle(on="key")
+
+         assert _contains_shuffle_expr(shuffled)
+         assert not _contains_shuffle_expr(df)
+
+         # this requires an active client.
+         with LocalCluster(n_workers=1) as cluster:
+             with Client(cluster):
+                 explict_shuffled = explicit_comms_shuffle(df, ["key"])
+                 assert not _contains_shuffle_expr(explict_shuffled)
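For context on the tests added above: explicit-comms shuffling is opted into through the `explicit_comms` config key, and `comms.default_comms()` lazily builds a CommsContext for the workers of the active client (now cached per cluster, which is what issue 1450 was about). The following is a minimal usage sketch, not part of the diff; it assumes only that dask, distributed, pandas, numpy and dask-cuda are installed, and the cluster size, partition count and column name are illustrative:

    import numpy as np
    import pandas as pd

    import dask
    import dask.dataframe as dd
    from distributed import Client, LocalCluster

    from dask_cuda.explicit_comms import comms

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            ddf = dd.from_pandas(pd.DataFrame({"key": np.arange(100)}), npartitions=4)

            # Route the shuffle through dask-cuda's explicit-comms implementation
            # instead of the default task/p2p shuffle.
            with dask.config.set(explicit_comms=True):
                result = ddf.shuffle(on="key", npartitions=4).compute()

            # The CommsContext for the current client/cluster is created lazily
            # and, as of this release, cached per cluster.
            context = comms.default_comms()
            print(len(context.worker_addresses), "workers in the comms context")

Whether the explicit-comms code path actually engages also depends on the Dask version pinned by this release, so treat the sketch as a starting point rather than a guarantee.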
@@ -1,4 +1,5 @@
  import multiprocessing as mp
+ import sys

  import numpy
  import psutil
@@ -214,3 +215,38 @@ def test_initialize_ucx_all(protocol):
      p.start()
      p.join()
      assert not p.exitcode
+
+
+ def _test_dask_cuda_import():
+     # Check that importing `dask_cuda` does NOT
+     # require `dask.dataframe` or `dask.array`.
+
+     # Patch sys.modules so that `dask.dataframe`
+     # and `dask.array` cannot be found.
+     with pytest.MonkeyPatch.context() as monkeypatch:
+         for k in list(sys.modules):
+             if k.startswith("dask.dataframe") or k.startswith("dask.array"):
+                 monkeypatch.setitem(sys.modules, k, None)
+         monkeypatch.delitem(sys.modules, "dask_cuda")
+
+         # Check that top-level imports still succeed.
+         import dask_cuda  # noqa: F401
+         from dask_cuda import CUDAWorker  # noqa: F401
+         from dask_cuda import LocalCUDACluster
+
+         with LocalCUDACluster(
+             dashboard_address=None,
+             n_workers=1,
+             threads_per_worker=1,
+             processes=True,
+             worker_class=IncreasedCloseTimeoutNanny,
+         ) as cluster:
+             with Client(cluster) as client:
+                 client.run(lambda *args: None)
+
+
+ def test_dask_cuda_import():
+     p = mp.Process(target=_test_dask_cuda_import)
+     p.start()
+     p.join()
+     assert not p.exitcode
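The import-isolation helper above leans on a standard Python mechanism: setting a module's entry in `sys.modules` to `None` makes any later import of that module raise `ImportError`, which lets the test prove that `dask_cuda`'s top-level imports never touch `dask.dataframe` or `dask.array`. A stripped-down sketch of the same technique (the blocked module and test name are illustrative, not part of the diff):

    import sys

    import pytest


    def test_import_blocked_module():
        # Blocking a module via sys.modules: a None entry makes the import
        # machinery raise ImportError instead of loading the real module.
        with pytest.MonkeyPatch.context() as monkeypatch:
            monkeypatch.setitem(sys.modules, "dask.array", None)

            with pytest.raises(ImportError):
                import dask.array  # noqa: F401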
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import asyncio
  import os
  import pkgutil
@@ -16,7 +19,7 @@ from dask_cuda.utils import (
      get_cluster_configuration,
      get_device_total_memory,
      get_gpu_count_mig,
-     get_gpu_uuid_from_index,
+     get_gpu_uuid,
      print_cluster_config,
  )
  from dask_cuda.utils_test import MockWorker
@@ -419,7 +422,7 @@ async def test_available_mig_workers():

  @gen_test(timeout=20)
  async def test_gpu_uuid():
-     gpu_uuid = get_gpu_uuid_from_index(0)
+     gpu_uuid = get_gpu_uuid(0)

      async with LocalCUDACluster(
          CUDA_VISIBLE_DEVICES=gpu_uuid,
@@ -1,3 +1,5 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION.
+
  from typing import Iterable
  from unittest.mock import patch

@@ -414,7 +416,7 @@ async def test_compatibility_mode_dataframe_shuffle(compatibility_mode, npartiti
          ddf = dask.dataframe.from_pandas(
              cudf.DataFrame({"key": np.arange(10)}), npartitions=npartitions
          )
-         res = ddf.shuffle(on="key", shuffle_method="tasks").persist()
+         [res] = client.persist([ddf.shuffle(on="key", shuffle_method="tasks")])

          # With compatibility mode on, we shouldn't encounter any proxy objects
          if compatibility_mode:
@@ -440,7 +442,7 @@ async def test_worker_force_spill_to_disk():
          async with Client(cluster, asynchronous=True) as client:
              # Create a df that are spilled to host memory immediately
              df = cudf.DataFrame({"key": np.arange(10**8)})
-             ddf = dask.dataframe.from_pandas(df, npartitions=1).persist()
+             [ddf] = client.persist([dask.dataframe.from_pandas(df, npartitions=1)])
              await ddf

              async def f(dask_worker):
@@ -498,3 +500,14 @@ def test_on_demand_debug_info():
      assert f"WARNING - RMM allocation of {size} failed" in log
      assert f"RMM allocs: {size}" in log
      assert "traceback:" in log
+
+
+ def test_sizeof_owner_with_cai():
+     cudf = pytest.importorskip("cudf")
+     s = cudf.Series([1, 2, 3])
+
+     items = dask_cuda.get_device_memory_objects.dispatch(s)
+     assert len(items) == 1
+     item = items[0]
+     result = dask.sizeof.sizeof(item)
+     assert result == 24
@@ -1,14 +1,18 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION.
+
  import gc
  import os
  from time import sleep
+ from typing import TypedDict

  import pytest

  import dask
  from dask import array as da
- from distributed import Client, wait
+ from distributed import Client, Worker, wait
  from distributed.metrics import time
  from distributed.sizeof import sizeof
+ from distributed.utils import Deadline
  from distributed.utils_test import gen_cluster, gen_test, loop  # noqa: F401

  import dask_cudf
@@ -72,24 +76,66 @@ def cudf_spill(request):


  def device_host_file_size_matches(
-     dhf, total_bytes, device_chunk_overhead=0, serialized_chunk_overhead=1024
+     dask_worker: Worker,
+     total_bytes,
+     device_chunk_overhead=0,
+     serialized_chunk_overhead=1024,
  ):
-     byte_sum = dhf.device_buffer.fast.total_weight
+     worker_data_sizes = collect_device_host_file_size(
+         dask_worker,
+         device_chunk_overhead=device_chunk_overhead,
+         serialized_chunk_overhead=serialized_chunk_overhead,
+     )
+     byte_sum = (
+         worker_data_sizes["device_fast"]
+         + worker_data_sizes["host_fast"]
+         + worker_data_sizes["host_buffer"]
+         + worker_data_sizes["disk"]
+     )
+     return (
+         byte_sum >= total_bytes
+         and byte_sum
+         <= total_bytes
+         + worker_data_sizes["device_overhead"]
+         + worker_data_sizes["host_overhead"]
+         + worker_data_sizes["disk_overhead"]
+     )
+
+
+ class WorkerDataSizes(TypedDict):
+     device_fast: int
+     host_fast: int
+     host_buffer: int
+     disk: int
+     device_overhead: int
+     host_overhead: int
+     disk_overhead: int
+
+
+ def collect_device_host_file_size(
+     dask_worker: Worker,
+     device_chunk_overhead: int,
+     serialized_chunk_overhead: int,
+ ) -> WorkerDataSizes:
+     dhf = dask_worker.data

-     # `dhf.host_buffer.fast` is only available when Worker's `memory_limit != 0`
+     device_fast = dhf.device_buffer.fast.total_weight or 0
      if hasattr(dhf.host_buffer, "fast"):
-         byte_sum += dhf.host_buffer.fast.total_weight
+         host_fast = dhf.host_buffer.fast.total_weight or 0
+         host_buffer = 0
      else:
-         byte_sum += sum([sizeof(b) for b in dhf.host_buffer.values()])
+         host_buffer = sum([sizeof(b) for b in dhf.host_buffer.values()])
+         host_fast = 0

-     # `dhf.disk` is only available when Worker's `memory_limit != 0`
      if dhf.disk is not None:
          file_path = [
              os.path.join(dhf.disk.directory, fname)
              for fname in dhf.disk.filenames.values()
          ]
          file_size = [os.path.getsize(f) for f in file_path]
-         byte_sum += sum(file_size)
+         disk = sum(file_size)
+     else:
+         disk = 0

      # Allow up to chunk_overhead bytes overhead per chunk
      device_overhead = len(dhf.device) * device_chunk_overhead
@@ -98,17 +144,25 @@ def device_host_file_size_matches(
          len(dhf.disk) * serialized_chunk_overhead if dhf.disk is not None else 0
      )

-     return (
-         byte_sum >= total_bytes
-         and byte_sum <= total_bytes + device_overhead + host_overhead + disk_overhead
+     return WorkerDataSizes(
+         device_fast=device_fast,
+         host_fast=host_fast,
+         host_buffer=host_buffer,
+         disk=disk,
+         device_overhead=device_overhead,
+         host_overhead=host_overhead,
+         disk_overhead=disk_overhead,
      )


  def assert_device_host_file_size(
-     dhf, total_bytes, device_chunk_overhead=0, serialized_chunk_overhead=1024
+     dask_worker: Worker,
+     total_bytes,
+     device_chunk_overhead=0,
+     serialized_chunk_overhead=1024,
  ):
      assert device_host_file_size_matches(
-         dhf, total_bytes, device_chunk_overhead, serialized_chunk_overhead
+         dask_worker, total_bytes, device_chunk_overhead, serialized_chunk_overhead
      )


@@ -119,7 +173,7 @@ def worker_assert(
      dask_worker=None,
  ):
      assert_device_host_file_size(
-         dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
+         dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
      )


@@ -131,12 +185,12 @@ def delayed_worker_assert(
  ):
      start = time()
      while not device_host_file_size_matches(
-         dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead
+         dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead
      ):
          sleep(0.01)
          if time() < start + 3:
              assert_device_host_file_size(
-                 dask_worker.data,
+                 dask_worker,
                  total_size,
                  device_chunk_overhead,
                  serialized_chunk_overhead,
@@ -224,8 +278,8 @@ async def test_cupy_cluster_device_spill(params):
                  x = rs.random(int(50e6), chunks=2e6)
                  await wait(x)

-                 xx = x.persist()
-                 await wait(xx)
+                 [xx] = client.persist([x])
+                 await xx

                  # Allow up to 1024 bytes overhead per chunk serialized
                  await client.run(
@@ -344,19 +398,38 @@ async def test_cudf_cluster_device_spill(params, cudf_spill):
                  sizes = sizes.to_arrow().to_pylist()
                  nbytes = sum(sizes)

-                 cdf2 = cdf.persist()
-                 await wait(cdf2)
+                 [cdf2] = client.persist([cdf])
+                 await cdf2

                  del cdf
                  gc.collect()

                  if enable_cudf_spill:
-                     await client.run(
-                         worker_assert,
-                         0,
-                         0,
-                         0,
+                     expected_data = WorkerDataSizes(
+                         device_fast=0,
+                         host_fast=0,
+                         host_buffer=0,
+                         disk=0,
+                         device_overhead=0,
+                         host_overhead=0,
+                         disk_overhead=0,
                      )
+
+                     deadline = Deadline.after(duration=3)
+                     while not deadline.expired:
+                         data = await client.run(
+                             collect_device_host_file_size,
+                             device_chunk_overhead=0,
+                             serialized_chunk_overhead=0,
+                         )
+                         expected = {k: expected_data for k in data}
+                         if data == expected:
+                             break
+                         sleep(0.01)
+
+                     # final assertion for pytest to reraise with a nice traceback
+                     assert data == expected
+
                  else:
                      await client.run(
                          assert_host_chunks,
@@ -419,8 +492,8 @@ async def test_cudf_spill_cluster(cudf_spill):
                  }
              )

-             ddf = dask_cudf.from_cudf(cdf, npartitions=2).sum().persist()
-             await wait(ddf)
+             [ddf] = client.persist([dask_cudf.from_cudf(cdf, npartitions=2).sum()])
+             await ddf

              await client.run(_assert_cudf_spill_stats, enable_cudf_spill)
              _assert_cudf_spill_stats(enable_cudf_spill)
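A recurring change in the spill and proxy test hunks above is replacing `collection.persist()` plus `wait(...)` with the list form of `Client.persist`, which pins the operation to an explicit client and returns a list of persisted collections that the async tests can await directly. A small synchronous sketch of the two spellings, not taken from the diff; the array shape and cluster size are illustrative:

    import dask.array as da
    from distributed import Client, LocalCluster

    if __name__ == "__main__":
        with LocalCluster(n_workers=1) as cluster, Client(cluster) as client:
            x = da.random.random((1_000, 1_000), chunks=(100, 100))

            # Old spelling: persist on the collection, using the "current" client.
            y_old = x.persist()

            # Spelling used by the updated tests: persist through a specific
            # client, passing and unpacking a list of collections.
            [y_new] = client.persist([x])

            print(y_old.sum().compute(), y_new.sum().compute())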
dask_cuda/utils.py CHANGED
@@ -1,3 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
  import math
  import operator
  import os
@@ -86,6 +89,45 @@ def get_gpu_count():
      return pynvml.nvmlDeviceGetCount()


+ def get_gpu_handle(device_id=0):
+     """Get GPU handle from device index or UUID.
+
+     Parameters
+     ----------
+     device_id: int or str
+         The index or UUID of the device from which to obtain the handle.
+
+     Raises
+     ------
+     ValueError
+         If acquiring the device handle for the device specified failed.
+     pynvml.NVMLError
+         If any NVML error occurred while initializing.
+
+     Examples
+     --------
+     >>> get_gpu_handle(device_id=0)
+
+     >>> get_gpu_handle(device_id="GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
+     """
+     pynvml.nvmlInit()
+
+     try:
+         if device_id and not str(device_id).isnumeric():
+             # This means device_id is UUID.
+             # This works for both MIG and non-MIG device UUIDs.
+             handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_id))
+             if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
+                 # Additionally get parent device handle
+                 # if the device itself is a MIG instance
+                 handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
+         else:
+             handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+         return handle
+     except pynvml.NVMLError:
+         raise ValueError(f"Invalid device index or UUID: {device_id}")
+
+
  @toolz.memoize
  def get_gpu_count_mig(return_uuids=False):
      """Return the number of MIG instances available
@@ -129,7 +171,7 @@ def get_cpu_affinity(device_index=None):
      Parameters
      ----------
      device_index: int or str
-         Index or UUID of the GPU device
+         The index or UUID of the device from which to obtain the CPU affinity.

      Examples
      --------
@@ -148,26 +190,15 @@ def get_cpu_affinity(device_index=None):
      40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
      60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
      """
-     pynvml.nvmlInit()
-
      try:
-         if device_index and not str(device_index).isnumeric():
-             # This means device_index is UUID.
-             # This works for both MIG and non-MIG device UUIDs.
-             handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_index))
-             if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
-                 # Additionally get parent device handle
-                 # if the device itself is a MIG instance
-                 handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
-         else:
-             handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
+         handle = get_gpu_handle(device_index)
          # Result is a list of 64-bit integers, thus ceil(get_cpu_count() / 64)
          affinity = pynvml.nvmlDeviceGetCpuAffinity(
              handle,
              math.ceil(get_cpu_count() / 64),
          )
          return unpack_bitmask(affinity)
-     except pynvml.NVMLError:
+     except (pynvml.NVMLError, ValueError):
          warnings.warn(
              "Cannot get CPU affinity for device with index %d, setting default affinity"
              % device_index
@@ -182,18 +213,15 @@ def get_n_gpus():
      return get_gpu_count()


- def get_device_total_memory(index=0):
-     """
-     Return total memory of CUDA device with index or with device identifier UUID
-     """
-     pynvml.nvmlInit()
+ def get_device_total_memory(device_index=0):
+     """Return total memory of CUDA device with index or with device identifier UUID.

-     if index and not str(index).isnumeric():
-         # This means index is UUID. This works for both MIG and non-MIG device UUIDs.
-         handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(str(index)))
-     else:
-         # This is a device index
-         handle = pynvml.nvmlDeviceGetHandleByIndex(index)
+     Parameters
+     ----------
+     device_index: int or str
+         The index or UUID of the device from which to obtain the CPU affinity.
+     """
+     handle = get_gpu_handle(device_index)
      return pynvml.nvmlDeviceGetMemoryInfo(handle).total


@@ -553,26 +581,26 @@ def parse_device_memory_limit(device_memory_limit, device_index=0, alignment_siz
      return _align(int(device_memory_limit), alignment_size)


- def get_gpu_uuid_from_index(device_index=0):
+ def get_gpu_uuid(device_index=0):
      """Get GPU UUID from CUDA device index.

      Parameters
      ----------
      device_index: int or str
-         The index of the device from which to obtain the UUID. Default: 0.
+         The index or UUID of the device from which to obtain the UUID.

      Examples
      --------
-     >>> get_gpu_uuid_from_index()
+     >>> get_gpu_uuid()
      'GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005'

-     >>> get_gpu_uuid_from_index(3)
+     >>> get_gpu_uuid(3)
      'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
-     """
-     import pynvml

-     pynvml.nvmlInit()
-     handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
+     >>> get_gpu_uuid("GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
+     'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
+     """
+     handle = get_gpu_handle(device_index)
      try:
          return pynvml.nvmlDeviceGetUUID(handle).decode("utf-8")
      except AttributeError:
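Taken together, the dask_cuda/utils.py hunks centralize NVML handle lookup in the new `get_gpu_handle` helper: `get_cpu_affinity`, `get_device_total_memory`, and the renamed `get_gpu_uuid` (formerly `get_gpu_uuid_from_index`) now all accept either a device index or a UUID string and delegate to it. A minimal usage sketch, not part of the diff, assuming pynvml is installed and at least one GPU is visible:

    from dask_cuda.utils import get_device_total_memory, get_gpu_uuid

    if __name__ == "__main__":
        # Both helpers accept a device index or a UUID string.
        uuid = get_gpu_uuid(0)
        total_bytes = get_device_total_memory(uuid)
        print(f"GPU 0: {uuid}, {total_bytes / 2**30:.1f} GiB total memory")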
@@ -1,9 +1,9 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: dask-cuda
- Version: 25.2.0
+ Version: 25.6.0
  Summary: Utilities for Dask and CUDA interactions
  Author: NVIDIA Corporation
- License: Apache 2.0
+ License: Apache-2.0
  Project-URL: Homepage, https://github.com/rapidsai/dask-cuda
  Project-URL: Documentation, https://docs.rapids.ai/api/dask-cuda/stable/
  Project-URL: Source, https://github.com/rapidsai/dask-cuda
@@ -15,21 +15,23 @@ Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: click>=8.1
- Requires-Dist: numba<0.61.0a0,>=0.59.1
+ Requires-Dist: numba<0.62.0a0,>=0.59.1
  Requires-Dist: numpy<3.0a0,>=1.23
  Requires-Dist: pandas>=1.3
  Requires-Dist: pynvml<13.0.0a0,>=12.0.0
- Requires-Dist: rapids-dask-dependency==25.2.*
+ Requires-Dist: rapids-dask-dependency==25.6.*
  Requires-Dist: zict>=2.0.0
  Provides-Extra: docs
  Requires-Dist: numpydoc>=1.1.0; extra == "docs"
  Requires-Dist: sphinx; extra == "docs"
  Requires-Dist: sphinx-click>=2.7.1; extra == "docs"
  Requires-Dist: sphinx-rtd-theme>=0.5.1; extra == "docs"
+ Dynamic: license-file

  Dask CUDA
  =========