dask-cuda 24.6.0__py3-none-any.whl → 24.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. dask_cuda/VERSION +1 -1
  2. dask_cuda/__init__.py +19 -0
  3. dask_cuda/_version.py +12 -2
  4. dask_cuda/benchmarks/common.py +12 -10
  5. dask_cuda/benchmarks/local_cudf_groupby.py +1 -8
  6. dask_cuda/benchmarks/local_cudf_merge.py +1 -8
  7. dask_cuda/benchmarks/local_cudf_shuffle.py +0 -7
  8. dask_cuda/benchmarks/local_cupy.py +1 -8
  9. dask_cuda/benchmarks/local_cupy_map_overlap.py +1 -8
  10. dask_cuda/benchmarks/read_parquet.py +268 -0
  11. dask_cuda/benchmarks/utils.py +109 -31
  12. dask_cuda/cli.py +27 -4
  13. dask_cuda/cuda_worker.py +18 -1
  14. dask_cuda/explicit_comms/dataframe/shuffle.py +24 -20
  15. dask_cuda/local_cuda_cluster.py +31 -1
  16. dask_cuda/plugins.py +15 -0
  17. dask_cuda/tests/test_cudf_builtin_spilling.py +1 -1
  18. dask_cuda/tests/test_dask_cuda_worker.py +85 -0
  19. dask_cuda/tests/test_explicit_comms.py +38 -8
  20. dask_cuda/tests/test_gds.py +1 -1
  21. dask_cuda/tests/test_local_cuda_cluster.py +48 -0
  22. dask_cuda/tests/test_proxify_host_file.py +1 -1
  23. dask_cuda/tests/test_proxy.py +5 -5
  24. dask_cuda/tests/test_spill.py +116 -16
  25. dask_cuda/tests/test_version.py +12 -0
  26. {dask_cuda-24.6.0.dist-info → dask_cuda-24.10.0.dist-info}/METADATA +20 -20
  27. dask_cuda-24.10.0.dist-info/RECORD +55 -0
  28. {dask_cuda-24.6.0.dist-info → dask_cuda-24.10.0.dist-info}/WHEEL +1 -1
  29. dask_cuda-24.6.0.dist-info/RECORD +0 -53
  30. {dask_cuda-24.6.0.dist-info → dask_cuda-24.10.0.dist-info}/LICENSE +0 -0
  31. {dask_cuda-24.6.0.dist-info → dask_cuda-24.10.0.dist-info}/entry_points.txt +0 -0
  32. {dask_cuda-24.6.0.dist-info → dask_cuda-24.10.0.dist-info}/top_level.txt +0 -0
dask_cuda/VERSION CHANGED
@@ -1 +1 @@
- 24.06.00
+ 24.10.00
dask_cuda/__init__.py CHANGED
@@ -9,6 +9,8 @@ import dask.dataframe.core
  import dask.dataframe.shuffle
  import dask.dataframe.multi
  import dask.bag.core
+ from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
+ from distributed.protocol.serialize import dask_deserialize, dask_serialize

  from ._version import __git_commit__, __version__
  from .cuda_worker import CUDAWorker
@@ -48,3 +50,20 @@ dask.dataframe.shuffle.shuffle_group = proxify_decorator(
      dask.dataframe.shuffle.shuffle_group
  )
  dask.dataframe.core._concat = unproxify_decorator(dask.dataframe.core._concat)
+
+
+ def _register_cudf_spill_aware():
+     import cudf
+
+     # Only enable Dask/cuDF spilling if cuDF spilling is disabled, see
+     # https://github.com/rapidsai/dask-cuda/issues/1363
+     if not cudf.get_option("spill"):
+         # This reproduces the implementation of `_register_cudf`, see
+         # https://github.com/dask/distributed/blob/40fcd65e991382a956c3b879e438be1b100dff97/distributed/protocol/__init__.py#L106-L115
+         from cudf.comm import serialize
+
+
+ for registry in [cuda_serialize, cuda_deserialize, dask_serialize, dask_deserialize]:
+     for lib in ["cudf", "dask_cudf"]:
+         if lib in registry._lazy:
+             registry._lazy[lib] = _register_cudf_spill_aware
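
The new lazy hook can be illustrated in isolation; a minimal sketch of the check it performs (assumes `cudf` is installed — enabling cuDF's own spilling via `cudf.set_option` is an assumption for the example, the diff only reads the option):

    import cudf

    cudf.set_option("spill", True)  # let cuDF manage spilling itself

    if not cudf.get_option("spill"):
        # Only reached when cuDF spilling is off: register Dask's cuDF
        # (de)serializers, mirroring distributed's `_register_cudf`.
        from cudf.comm import serialize  # noqa: F401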
dask_cuda/_version.py CHANGED
@@ -15,6 +15,16 @@
  import importlib.resources

  __version__ = (
-     importlib.resources.files("dask_cuda").joinpath("VERSION").read_text().strip()
+     importlib.resources.files(__package__).joinpath("VERSION").read_text().strip()
  )
- __git_commit__ = "2fc151b061e90fae0cf95b45dbd62507aa8dd7e6"
+ try:
+     __git_commit__ = (
+         importlib.resources.files(__package__)
+         .joinpath("GIT_COMMIT")
+         .read_text()
+         .strip()
+     )
+ except FileNotFoundError:
+     __git_commit__ = ""
+
+ __all__ = ["__git_commit__", "__version__"]
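
A small sketch of how the packaged metadata now resolves at runtime (values shown are illustrative; `"dask_cuda"` is used in place of `__package__` for a standalone snippet):

    import importlib.resources

    version = (
        importlib.resources.files("dask_cuda").joinpath("VERSION").read_text().strip()
    )
    try:
        commit = (
            importlib.resources.files("dask_cuda").joinpath("GIT_COMMIT").read_text().strip()
        )
    except FileNotFoundError:
        commit = ""  # GIT_COMMIT is optional, e.g. in a plain source checkout
    print(version, commit)  # e.g. "24.10.00" and a 40-character SHA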
dask_cuda/benchmarks/common.py CHANGED
@@ -117,16 +117,18 @@ def run(client: Client, args: Namespace, config: Config):
      wait_for_cluster(client, shutdown_on_failure=True)
      assert len(client.scheduler_info()["workers"]) > 0
      setup_memory_pools(
-         client,
-         args.type == "gpu",
-         args.rmm_pool_size,
-         args.disable_rmm_pool,
-         args.enable_rmm_async,
-         args.enable_rmm_managed,
-         args.rmm_release_threshold,
-         args.rmm_log_directory,
-         args.enable_rmm_statistics,
-         args.enable_rmm_track_allocations,
+         client=client,
+         is_gpu=args.type == "gpu",
+         disable_rmm=args.disable_rmm,
+         disable_rmm_pool=args.disable_rmm_pool,
+         pool_size=args.rmm_pool_size,
+         maximum_pool_size=args.rmm_maximum_pool_size,
+         rmm_async=args.enable_rmm_async,
+         rmm_managed=args.enable_rmm_managed,
+         release_threshold=args.rmm_release_threshold,
+         log_directory=args.rmm_log_directory,
+         statistics=args.enable_rmm_statistics,
+         rmm_track_allocations=args.enable_rmm_track_allocations,
      )
      address_to_index, results, message_data = gather_bench_results(client, args, config)
      p2p_bw = peer_to_peer_bandwidths(message_data, address_to_index)
dask_cuda/benchmarks/local_cudf_groupby.py CHANGED
@@ -7,7 +7,7 @@ import pandas as pd
  import dask
  import dask.dataframe as dd
  from dask.distributed import performance_report, wait
- from dask.utils import format_bytes, parse_bytes
+ from dask.utils import format_bytes

  from dask_cuda.benchmarks.common import Config, execute_benchmark
  from dask_cuda.benchmarks.utils import (
@@ -260,13 +260,6 @@ def parse_args():
              "type": str,
              "help": "Do shuffle with GPU or CPU dataframes (default 'gpu')",
          },
-         {
-             "name": "--ignore-size",
-             "default": "1 MiB",
-             "metavar": "nbytes",
-             "type": parse_bytes,
-             "help": "Ignore messages smaller than this (default '1 MB')",
-         },
          {
              "name": "--runs",
              "default": 3,
dask_cuda/benchmarks/local_cudf_merge.py CHANGED
@@ -9,7 +9,7 @@ import pandas as pd
  import dask
  import dask.dataframe as dd
  from dask.distributed import performance_report, wait
- from dask.utils import format_bytes, parse_bytes
+ from dask.utils import format_bytes

  from dask_cuda.benchmarks.common import Config, execute_benchmark
  from dask_cuda.benchmarks.utils import (
@@ -335,13 +335,6 @@ def parse_args():
              "action": "store_true",
              "help": "Use shuffle join (takes precedence over '--broadcast-join').",
          },
-         {
-             "name": "--ignore-size",
-             "default": "1 MiB",
-             "metavar": "nbytes",
-             "type": parse_bytes,
-             "help": "Ignore messages smaller than this (default '1 MB')",
-         },
          {
              "name": "--frac-match",
              "default": 0.3,
dask_cuda/benchmarks/local_cudf_shuffle.py CHANGED
@@ -228,13 +228,6 @@ def parse_args():
              "type": str,
              "help": "Do shuffle with GPU or CPU dataframes (default 'gpu')",
          },
-         {
-             "name": "--ignore-size",
-             "default": "1 MiB",
-             "metavar": "nbytes",
-             "type": parse_bytes,
-             "help": "Ignore messages smaller than this (default '1 MB')",
-         },
          {
              "name": "--runs",
              "default": 3,
dask_cuda/benchmarks/local_cupy.py CHANGED
@@ -8,7 +8,7 @@ from nvtx import end_range, start_range

  from dask import array as da
  from dask.distributed import performance_report, wait
- from dask.utils import format_bytes, parse_bytes
+ from dask.utils import format_bytes

  from dask_cuda.benchmarks.common import Config, execute_benchmark
  from dask_cuda.benchmarks.utils import (
@@ -297,13 +297,6 @@ def parse_args():
              "type": int,
              "help": "Chunk size (default 2500).",
          },
-         {
-             "name": "--ignore-size",
-             "default": "1 MiB",
-             "metavar": "nbytes",
-             "type": parse_bytes,
-             "help": "Ignore messages smaller than this (default '1 MB').",
-         },
          {
              "name": "--runs",
              "default": 3,
dask_cuda/benchmarks/local_cupy_map_overlap.py CHANGED
@@ -10,7 +10,7 @@ from scipy.ndimage import convolve as sp_convolve

  from dask import array as da
  from dask.distributed import performance_report, wait
- from dask.utils import format_bytes, parse_bytes
+ from dask.utils import format_bytes

  from dask_cuda.benchmarks.common import Config, execute_benchmark
  from dask_cuda.benchmarks.utils import (
@@ -168,13 +168,6 @@ def parse_args():
              "type": int,
              "help": "Kernel size, 2*k+1, in each dimension (default 1)",
          },
-         {
-             "name": "--ignore-size",
-             "default": "1 MiB",
-             "metavar": "nbytes",
-             "type": parse_bytes,
-             "help": "Ignore messages smaller than this (default '1 MB')",
-         },
          {
              "name": "--runs",
              "default": 3,
dask_cuda/benchmarks/read_parquet.py ADDED
@@ -0,0 +1,268 @@
+ import contextlib
+ from collections import ChainMap
+ from time import perf_counter as clock
+
+ import fsspec
+ import pandas as pd
+
+ import dask
+ import dask.dataframe as dd
+ from dask.base import tokenize
+ from dask.distributed import performance_report
+ from dask.utils import format_bytes, parse_bytes
+
+ from dask_cuda.benchmarks.common import Config, execute_benchmark
+ from dask_cuda.benchmarks.utils import (
+     parse_benchmark_args,
+     print_key_value,
+     print_separator,
+     print_throughput_bandwidth,
+ )
+
+ DISK_SIZE_CACHE = {}
+ OPTIONS_CACHE = {}
+
+
+ def _noop(df):
+     return df
+
+
+ def read_data(paths, columns, backend, **kwargs):
+     with dask.config.set({"dataframe.backend": backend}):
+         return dd.read_parquet(
+             paths,
+             columns=columns,
+             **kwargs,
+         )
+
+
+ def get_fs_paths_kwargs(args):
+     kwargs = {}
+
+     storage_options = {}
+     if args.key:
+         storage_options["key"] = args.key
+     if args.secret:
+         storage_options["secret"] = args.secret
+
+     if args.filesystem == "arrow":
+         import pyarrow.fs as pa_fs
+         from fsspec.implementations.arrow import ArrowFSWrapper
+
+         _mapping = {
+             "key": "access_key",
+             "secret": "secret_key",
+         }  # See: pyarrow.fs.S3FileSystem docs
+         s3_args = {}
+         for k, v in storage_options.items():
+             s3_args[_mapping[k]] = v
+
+         fs = pa_fs.FileSystem.from_uri(args.path)[0]
+         try:
+             region = {"region": fs.region}
+         except AttributeError:
+             region = {}
+         kwargs["filesystem"] = type(fs)(**region, **s3_args)
+         fsspec_fs = ArrowFSWrapper(kwargs["filesystem"])
+
+         if args.type == "gpu":
+             kwargs["blocksize"] = args.blocksize
+     else:
+         fsspec_fs = fsspec.core.get_fs_token_paths(
+             args.path, mode="rb", storage_options=storage_options
+         )[0]
+         kwargs["filesystem"] = fsspec_fs
+         kwargs["blocksize"] = args.blocksize
+         kwargs["aggregate_files"] = args.aggregate_files
+
+     # Collect list of paths
+     stripped_url_path = fsspec_fs._strip_protocol(args.path)
+     if stripped_url_path.endswith("/"):
+         stripped_url_path = stripped_url_path[:-1]
+     paths = fsspec_fs.glob(f"{stripped_url_path}/*.parquet")
+     if args.file_count:
+         paths = paths[: args.file_count]
+
+     return fsspec_fs, paths, kwargs
+
+
+ def bench_once(client, args, write_profile=None):
+     global OPTIONS_CACHE
+     global DISK_SIZE_CACHE
+
+     # Construct kwargs
+     token = tokenize(args)
+     try:
+         fsspec_fs, paths, kwargs = OPTIONS_CACHE[token]
+     except KeyError:
+         fsspec_fs, paths, kwargs = get_fs_paths_kwargs(args)
+         OPTIONS_CACHE[token] = (fsspec_fs, paths, kwargs)
+
+     if write_profile is None:
+         ctx = contextlib.nullcontext()
+     else:
+         ctx = performance_report(filename=args.profile)
+
+     with ctx:
+         t1 = clock()
+         df = read_data(
+             paths,
+             columns=args.columns,
+             backend="cudf" if args.type == "gpu" else "pandas",
+             **kwargs,
+         )
+         num_rows = len(
+             # Use opaque `map_partitions` call to "block"
+             # dask-expr from using pq metadata to get length
+             df.map_partitions(
+                 _noop,
+                 meta=df._meta,
+                 enforce_metadata=False,
+             )
+         )
+         t2 = clock()
+
+     # Extract total size of files on disk
+     token = tokenize(paths)
+     try:
+         disk_size = DISK_SIZE_CACHE[token]
+     except KeyError:
+         disk_size = sum(fsspec_fs.sizes(paths))
+         DISK_SIZE_CACHE[token] = disk_size
+
+     return (disk_size, num_rows, t2 - t1)
+
+
+ def pretty_print_results(args, address_to_index, p2p_bw, results):
+     if args.markdown:
+         print("```")
+     print("Parquet read benchmark")
+     data_processed, row_count, durations = zip(*results)
+     print_separator(separator="-")
+     backend = "cudf" if args.type == "gpu" else "pandas"
+     print_key_value(key="Path", value=args.path)
+     print_key_value(key="Columns", value=f"{args.columns}")
+     print_key_value(key="Backend", value=f"{backend}")
+     print_key_value(key="Filesystem", value=f"{args.filesystem}")
+     print_key_value(key="Blocksize", value=f"{format_bytes(args.blocksize)}")
+     print_key_value(key="Aggregate files", value=f"{args.aggregate_files}")
+     print_key_value(key="Row count", value=f"{row_count[0]}")
+     print_key_value(key="Size on disk", value=f"{format_bytes(data_processed[0])}")
+     if args.markdown:
+         print("\n```")
+     args.no_show_p2p_bandwidth = True
+     print_throughput_bandwidth(
+         args, durations, data_processed, p2p_bw, address_to_index
+     )
+     print_separator(separator="=")
+
+
+ def create_tidy_results(args, p2p_bw, results):
+     configuration = {
+         "path": args.path,
+         "columns": args.columns,
+         "backend": "cudf" if args.type == "gpu" else "pandas",
+         "filesystem": args.filesystem,
+         "blocksize": args.blocksize,
+         "aggregate_files": args.aggregate_files,
+     }
+     timing_data = pd.DataFrame(
+         [
+             pd.Series(
+                 data=ChainMap(
+                     configuration,
+                     {
+                         "wallclock": duration,
+                         "data_processed": data_processed,
+                         "num_rows": num_rows,
+                     },
+                 )
+             )
+             for data_processed, num_rows, duration in results
+         ]
+     )
+     return timing_data, p2p_bw
+
+
+ def parse_args():
+     special_args = [
+         {
+             "name": "path",
+             "type": str,
+             "help": "Parquet directory to read from (must be a flat directory).",
+         },
+         {
+             "name": "--blocksize",
+             "default": "256MB",
+             "type": parse_bytes,
+             "help": "How to set the blocksize option",
+         },
+         {
+             "name": "--aggregate-files",
+             "default": False,
+             "action": "store_true",
+             "help": "How to set the aggregate_files option",
+         },
+         {
+             "name": "--file-count",
+             "type": int,
+             "help": "Maximum number of files to read.",
+         },
+         {
+             "name": "--columns",
+             "type": str,
+             "help": "Columns to read/select from data.",
+         },
+         {
+             "name": "--key",
+             "type": str,
+             "help": "Public S3 key.",
+         },
+         {
+             "name": "--secret",
+             "type": str,
+             "help": "Secret S3 key.",
+         },
+         {
+             "name": [
+                 "-t",
+                 "--type",
+             ],
+             "choices": ["cpu", "gpu"],
+             "default": "gpu",
+             "type": str,
+             "help": "Use GPU or CPU dataframes (default 'gpu')",
+         },
+         {
+             "name": "--filesystem",
+             "choices": ["arrow", "fsspec"],
+             "default": "fsspec",
+             "type": str,
+             "help": "Filesystem backend",
+         },
+         {
+             "name": "--runs",
+             "default": 3,
+             "type": int,
+             "help": "Number of runs",
+         },
+     ]
+
+     args = parse_benchmark_args(
+         description="Parquet read benchmark",
+         args_list=special_args,
+         check_explicit_comms=False,
+     )
+     args.no_show_p2p_bandwidth = True
+     return args
+
+
+ if __name__ == "__main__":
+     execute_benchmark(
+         Config(
+             args=parse_args(),
+             bench_once=bench_once,
+             create_tidy_results=create_tidy_results,
+             pretty_print_results=pretty_print_results,
+         )
+     )
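
Like the other benchmark modules, this one is normally launched as a script (roughly `python -m dask_cuda.benchmarks.read_parquet PATH [options]`). A hedged sketch of an equivalent programmatic invocation, mirroring the module's own `__main__` block; the S3 path and flag values are purely illustrative:

    import sys

    from dask_cuda.benchmarks.common import Config, execute_benchmark
    from dask_cuda.benchmarks.read_parquet import (
        bench_once,
        create_tidy_results,
        parse_args,
        pretty_print_results,
    )

    # Pretend these came from the command line (path and options are made up).
    sys.argv = [
        "read_parquet",
        "s3://my-bucket/my-dataset/",
        "--type", "gpu",
        "--filesystem", "arrow",
        "--blocksize", "256MB",
        "--runs", "3",
    ]

    execute_benchmark(
        Config(
            args=parse_args(),
            bench_once=bench_once,
            create_tidy_results=create_tidy_results,
            pretty_print_results=pretty_print_results,
        )
    )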
dask_cuda/benchmarks/utils.py CHANGED
@@ -17,6 +17,7 @@ from dask.utils import format_bytes, format_time, parse_bytes
  from distributed.comm.addressing import get_address_host

  from dask_cuda.local_cuda_cluster import LocalCUDACluster
+ from dask_cuda.utils import parse_device_memory_limit


  def as_noop(dsk):
@@ -93,15 +94,41 @@ def parse_benchmark_args(
          "'forkserver' can be used to avoid issues with fork not being allowed "
          "after the networking stack has been initialised.",
      )
+     cluster_args.add_argument(
+         "--disable-rmm",
+         action="store_true",
+         help="Disable RMM.",
+     )
+     cluster_args.add_argument(
+         "--disable-rmm-pool",
+         action="store_true",
+         help="Uses RMM for allocations but without a memory pool.",
+     )
      cluster_args.add_argument(
          "--rmm-pool-size",
          default=None,
          type=parse_bytes,
          help="The size of the RMM memory pool. Can be an integer (bytes) or a string "
-         "(like '4GB' or '5000M'). By default, 1/2 of the total GPU memory is used.",
+         "(like '4GB' or '5000M'). By default, 1/2 of the total GPU memory is used."
+         ""
+         ".. note::"
+         " This size is a per-worker configuration, and not cluster-wide.",
      )
      cluster_args.add_argument(
-         "--disable-rmm-pool", action="store_true", help="Disable the RMM memory pool"
+         "--rmm-maximum-pool-size",
+         default=None,
+         help="When ``--rmm-pool-size`` is specified, this argument indicates the "
+         "maximum pool size. Can be an integer (bytes), or a string (like '4GB' or "
+         "'5000M'). By default, the total available memory on the GPU is used. "
+         "``rmm_pool_size`` must be specified to use RMM pool and to set the maximum "
+         "pool size."
+         ""
+         ".. note::"
+         " When paired with `--enable-rmm-async` the maximum size cannot be "
+         " guaranteed due to fragmentation."
+         ""
+         ".. note::"
+         " This size is a per-worker configuration, and not cluster-wide.",
      )
      cluster_args.add_argument(
          "--enable-rmm-managed",
@@ -310,6 +337,13 @@ def parse_benchmark_args(
          "If the files already exist, new files are created with a uniquified "
          "BASENAME.",
      )
+     parser.add_argument(
+         "--ignore-size",
+         default="1 MiB",
+         metavar="nbytes",
+         type=parse_bytes,
+         help="Bandwidth statistics: ignore messages smaller than this (default '1 MB')",
+     )

      for args in args_list:
          name = args.pop("name")
@@ -407,10 +441,29 @@ def get_worker_device():
      return -1


+ def setup_rmm_resources(statistics=False, rmm_track_allocations=False):
+     import cupy
+
+     import rmm
+     from rmm.allocators.cupy import rmm_cupy_allocator
+
+     cupy.cuda.set_allocator(rmm_cupy_allocator)
+     if statistics:
+         rmm.mr.set_current_device_resource(
+             rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())
+         )
+     if rmm_track_allocations:
+         rmm.mr.set_current_device_resource(
+             rmm.mr.TrackingResourceAdaptor(rmm.mr.get_current_device_resource())
+         )
+
+
  def setup_memory_pool(
      dask_worker=None,
+     disable_rmm=None,
+     disable_rmm_pool=None,
      pool_size=None,
-     disable_pool=False,
+     maximum_pool_size=None,
      rmm_async=False,
      rmm_managed=False,
      release_threshold=None,
@@ -418,45 +471,66 @@ def setup_memory_pool(
      statistics=False,
      rmm_track_allocations=False,
  ):
-     import cupy
-
      import rmm
-     from rmm.allocators.cupy import rmm_cupy_allocator

      from dask_cuda.utils import get_rmm_log_file_name

      logging = log_directory is not None

-     if rmm_async:
-         rmm.mr.set_current_device_resource(
-             rmm.mr.CudaAsyncMemoryResource(
-                 initial_pool_size=pool_size, release_threshold=release_threshold
-             )
-         )
-     else:
-         rmm.reinitialize(
-             pool_allocator=not disable_pool,
-             managed_memory=rmm_managed,
-             initial_pool_size=pool_size,
-             logging=logging,
-             log_file_name=get_rmm_log_file_name(dask_worker, logging, log_directory),
-         )
-     cupy.cuda.set_allocator(rmm_cupy_allocator)
-     if statistics:
-         rmm.mr.set_current_device_resource(
-             rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())
+     if pool_size is not None:
+         pool_size = parse_device_memory_limit(pool_size, alignment_size=256)
+
+     if maximum_pool_size is not None:
+         maximum_pool_size = parse_device_memory_limit(
+             maximum_pool_size, alignment_size=256
          )
-     if rmm_track_allocations:
-         rmm.mr.set_current_device_resource(
-             rmm.mr.TrackingResourceAdaptor(rmm.mr.get_current_device_resource())
+
+     if release_threshold is not None:
+         release_threshold = parse_device_memory_limit(
+             release_threshold, alignment_size=256
          )

+     if not disable_rmm:
+         if rmm_async:
+             mr = rmm.mr.CudaAsyncMemoryResource(
+                 initial_pool_size=pool_size,
+                 release_threshold=release_threshold,
+             )
+
+             if maximum_pool_size is not None:
+                 mr = rmm.mr.LimitingResourceAdaptor(
+                     mr, allocation_limit=maximum_pool_size
+                 )
+
+             rmm.mr.set_current_device_resource(mr)
+
+             setup_rmm_resources(
+                 statistics=statistics, rmm_track_allocations=rmm_track_allocations
+             )
+         else:
+             rmm.reinitialize(
+                 pool_allocator=not disable_rmm_pool,
+                 managed_memory=rmm_managed,
+                 initial_pool_size=pool_size,
+                 maximum_pool_size=maximum_pool_size,
+                 logging=logging,
+                 log_file_name=get_rmm_log_file_name(
+                     dask_worker, logging, log_directory
+                 ),
+             )
+
+             setup_rmm_resources(
+                 statistics=statistics, rmm_track_allocations=rmm_track_allocations
+             )
+

  def setup_memory_pools(
      client,
      is_gpu,
+     disable_rmm,
+     disable_rmm_pool,
      pool_size,
-     disable_pool,
+     maximum_pool_size,
      rmm_async,
      rmm_managed,
      release_threshold,
@@ -468,8 +542,10 @@ def setup_memory_pools(
          return
      client.run(
          setup_memory_pool,
+         disable_rmm=disable_rmm,
+         disable_rmm_pool=disable_rmm_pool,
          pool_size=pool_size,
-         disable_pool=disable_pool,
+         maximum_pool_size=maximum_pool_size,
          rmm_async=rmm_async,
          rmm_managed=rmm_managed,
          release_threshold=release_threshold,
@@ -482,7 +558,9 @@ def setup_memory_pools(
      client.run_on_scheduler(
          setup_memory_pool,
          pool_size=1e9,
-         disable_pool=disable_pool,
+         disable_rmm=disable_rmm,
+         disable_rmm_pool=disable_rmm_pool,
+         maximum_pool_size=maximum_pool_size,
          rmm_async=rmm_async,
          rmm_managed=rmm_managed,
          release_threshold=release_threshold,
@@ -694,7 +772,7 @@ def print_throughput_bandwidth(
      )
      print_key_value(
          key="Wall clock",
-         value=f"{format_time(durations.mean())} +/- {format_time(durations.std()) }",
+         value=f"{format_time(durations.mean())} +/- {format_time(durations.std())}",
      )
      if not args.no_show_p2p_bandwidth:
          print_separator(separator="=")
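
For reference, the new async branch in `setup_memory_pool` amounts to the following per-worker RMM setup; a rough sketch with illustrative sizes (assumes `rmm` is installed):

    import rmm

    # --rmm-pool-size 1GB / --rmm-release-threshold 1GB (illustrative values)
    mr = rmm.mr.CudaAsyncMemoryResource(
        initial_pool_size=2**30,
        release_threshold=2**30,
    )
    # --rmm-maximum-pool-size 2GB is enforced by wrapping the resource in a
    # limiting adaptor, so allocations beyond the cap fail instead of growing
    # the pool (and, as the help text notes, fragmentation means the cap is
    # best-effort with the async allocator).
    mr = rmm.mr.LimitingResourceAdaptor(mr, allocation_limit=2 * 2**30)
    rmm.mr.set_current_device_resource(mr)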