dask-cuda 25.12.0__py3-none-manylinux_2_28_aarch64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. dask_cuda/GIT_COMMIT +1 -0
  2. dask_cuda/VERSION +1 -0
  3. dask_cuda/__init__.py +57 -0
  4. dask_cuda/_compat.py +19 -0
  5. dask_cuda/_version.py +19 -0
  6. dask_cuda/benchmarks/__init__.py +2 -0
  7. dask_cuda/benchmarks/common.py +216 -0
  8. dask_cuda/benchmarks/local_cudf_groupby.py +278 -0
  9. dask_cuda/benchmarks/local_cudf_merge.py +373 -0
  10. dask_cuda/benchmarks/local_cudf_shuffle.py +327 -0
  11. dask_cuda/benchmarks/local_cupy.py +327 -0
  12. dask_cuda/benchmarks/local_cupy_map_overlap.py +198 -0
  13. dask_cuda/benchmarks/read_parquet.py +270 -0
  14. dask_cuda/benchmarks/utils.py +936 -0
  15. dask_cuda/cli.py +546 -0
  16. dask_cuda/cuda_worker.py +237 -0
  17. dask_cuda/device_host_file.py +325 -0
  18. dask_cuda/disk_io.py +227 -0
  19. dask_cuda/explicit_comms/__init__.py +2 -0
  20. dask_cuda/explicit_comms/comms.py +359 -0
  21. dask_cuda/explicit_comms/dataframe/__init__.py +2 -0
  22. dask_cuda/explicit_comms/dataframe/shuffle.py +722 -0
  23. dask_cuda/get_device_memory_objects.py +155 -0
  24. dask_cuda/initialize.py +245 -0
  25. dask_cuda/is_device_object.py +44 -0
  26. dask_cuda/is_spillable_object.py +59 -0
  27. dask_cuda/local_cuda_cluster.py +459 -0
  28. dask_cuda/plugins.py +209 -0
  29. dask_cuda/proxify_device_objects.py +263 -0
  30. dask_cuda/proxify_host_file.py +795 -0
  31. dask_cuda/proxy_object.py +951 -0
  32. dask_cuda/tests/conftest.py +41 -0
  33. dask_cuda/tests/test_cudf_builtin_spilling.py +155 -0
  34. dask_cuda/tests/test_dask_cuda_worker.py +696 -0
  35. dask_cuda/tests/test_dask_setup.py +193 -0
  36. dask_cuda/tests/test_device_host_file.py +204 -0
  37. dask_cuda/tests/test_dgx.py +227 -0
  38. dask_cuda/tests/test_explicit_comms.py +566 -0
  39. dask_cuda/tests/test_from_array.py +20 -0
  40. dask_cuda/tests/test_gds.py +47 -0
  41. dask_cuda/tests/test_initialize.py +434 -0
  42. dask_cuda/tests/test_local_cuda_cluster.py +661 -0
  43. dask_cuda/tests/test_proxify_host_file.py +534 -0
  44. dask_cuda/tests/test_proxy.py +698 -0
  45. dask_cuda/tests/test_spill.py +504 -0
  46. dask_cuda/tests/test_utils.py +348 -0
  47. dask_cuda/tests/test_version.py +13 -0
  48. dask_cuda/tests/test_worker_spec.py +83 -0
  49. dask_cuda/utils.py +974 -0
  50. dask_cuda/utils_test.py +48 -0
  51. dask_cuda/worker_common.py +196 -0
  52. dask_cuda/worker_spec.py +131 -0
  53. dask_cuda-25.12.0.dist-info/METADATA +75 -0
  54. dask_cuda-25.12.0.dist-info/RECORD +61 -0
  55. dask_cuda-25.12.0.dist-info/WHEEL +6 -0
  56. dask_cuda-25.12.0.dist-info/entry_points.txt +6 -0
  57. dask_cuda-25.12.0.dist-info/licenses/LICENSE +201 -0
  58. dask_cuda-25.12.0.dist-info/top_level.txt +6 -0
  59. shared-actions/check_nightly_success/check-nightly-success/check.py +148 -0
  60. shared-actions/telemetry-impls/summarize/bump_time.py +54 -0
  61. shared-actions/telemetry-impls/summarize/send_trace.py +425 -0
dask_cuda/GIT_COMMIT ADDED
@@ -0,0 +1 @@
+ 7edf2c69a732ebd24f8dd3e76cb06235a473f7a5
dask_cuda/VERSION ADDED
@@ -0,0 +1 @@
+ 25.12.000
dask_cuda/__init__.py ADDED
@@ -0,0 +1,57 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import sys
+
+ if sys.platform != "linux":
+     raise ImportError("Only Linux is supported by Dask-CUDA at this time")
+
+ import dask
+ import dask.utils
+ from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
+ from distributed.protocol.serialize import dask_deserialize, dask_serialize
+
+ from ._version import __git_commit__, __version__
+ from .cuda_worker import CUDAWorker
+
+ from .local_cuda_cluster import LocalCUDACluster
+
+
+ try:
+     import dask.dataframe as dask_dataframe
+ except ImportError:
+     # Dask DataFrame (optional) isn't installed
+     dask_dataframe = None
+
+
+ if dask_dataframe is not None:
+     from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
+     from .proxify_device_objects import proxify_decorator, unproxify_decorator
+
+     # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
+     patch_shuffle_expression()
+     # Monkey patching Dask to make use of proxify and unproxify in compatibility mode
+     dask_dataframe.shuffle.shuffle_group = proxify_decorator(
+         dask.dataframe.shuffle.shuffle_group
+     )
+     dask_dataframe.core._concat = unproxify_decorator(dask.dataframe.core._concat)
+
+ def _register_cudf_spill_aware():
+     import cudf
+
+     # Only enable Dask/cuDF spilling if cuDF spilling is disabled, see
+     # https://github.com/rapidsai/dask-cuda/issues/1363
+     if not cudf.get_option("spill"):
+         # This reproduces the implementation of `_register_cudf`, see
+         # https://github.com/dask/distributed/blob/40fcd65e991382a956c3b879e438be1b100dff97/distributed/protocol/__init__.py#L106-L115
+         from cudf.comm import serialize
+
+ for registry in [
+     cuda_serialize,
+     cuda_deserialize,
+     dask_serialize,
+     dask_deserialize,
+ ]:
+     for lib in ["cudf", "dask_cudf"]:
+         if lib in registry._lazy:
+             registry._lazy[lib] = _register_cudf_spill_aware
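As an aside, a minimal usage sketch of the package's main entry points (illustrative, not shipped in the wheel). It assumes a Linux host with at least one NVIDIA GPU and uses only the documented LocalCUDACluster and distributed.Client APIs; the serializer registration above happens as a side effect of importing dask_cuda.

# Minimal usage sketch (illustrative only).
from distributed import Client

from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    # By default LocalCUDACluster starts one worker per visible GPU.
    with LocalCUDACluster() as cluster, Client(cluster) as client:
        print(list(client.scheduler_info()["workers"]))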
dask_cuda/_compat.py ADDED
@@ -0,0 +1,19 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import functools
+ import importlib.metadata
+
+ import packaging.version
+
+
+ @functools.lru_cache(maxsize=None)
+ def get_dask_version() -> packaging.version.Version:
+     return packaging.version.parse(importlib.metadata.version("dask"))
+
+
+ @functools.lru_cache(maxsize=None)
+ def DASK_2025_4_0():
+     # dask 2025.4.0 isn't currently released, so we're relying
+     # on strictly greater than here.
+     return get_dask_version() > packaging.version.parse("2025.3.0")
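A small illustrative call site for this version gate (assumes only that dask-cuda and dask are installed); because both helpers are lru_cache'd, repeated checks are cheap.

# Illustrative check against the cached Dask version gate above.
from dask_cuda._compat import DASK_2025_4_0, get_dask_version

if DASK_2025_4_0():
    print(f"dask {get_dask_version()} is newer than 2025.3.0")
else:
    print(f"dask {get_dask_version()} is 2025.3.0 or older")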
dask_cuda/_version.py ADDED
@@ -0,0 +1,19 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import importlib.resources
+
+ __version__ = (
+     importlib.resources.files(__package__).joinpath("VERSION").read_text().strip()
+ )
+ try:
+     __git_commit__ = (
+         importlib.resources.files(__package__)
+         .joinpath("GIT_COMMIT")
+         .read_text()
+         .strip()
+     )
+ except FileNotFoundError:
+     __git_commit__ = ""
+
+ __all__ = ["__git_commit__", "__version__"]
dask_cuda/benchmarks/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
dask_cuda/benchmarks/common.py ADDED
@@ -0,0 +1,216 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import contextlib
+ from argparse import Namespace
+ from functools import partial
+ from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Tuple
+ from warnings import filterwarnings
+
+ import numpy as np
+ import pandas as pd
+
+ import dask
+ from distributed import Client, performance_report
+
+ from dask_cuda.benchmarks.utils import (
+     address_to_index,
+     aggregate_transfer_log_data,
+     bandwidth_statistics,
+     get_cluster_options,
+     peer_to_peer_bandwidths,
+     save_benchmark_data,
+     setup_memory_pools,
+     wait_for_cluster,
+ )
+ from dask_cuda.utils import all_to_all
+
+ __all__ = ("execute_benchmark", "Config")
+
+
+ class Config(NamedTuple):
+     """Benchmark configuration"""
+
+     args: Namespace
+     """Parsed benchmark arguments"""
+     bench_once: Callable[[Client, Namespace, Optional[str]], Any]
+     """Callable to run a single benchmark iteration
+
+     Parameters
+     ----------
+     client
+         distributed Client object
+     args
+         Benchmark parsed arguments
+     write_profile
+         Should a profile be written?
+
+     Returns
+     -------
+     Benchmark data to be interpreted by ``pretty_print_results`` and
+     ``create_tidy_results``.
+     """
+     create_tidy_results: Callable[
+         [Namespace, np.ndarray, List[Any]], Tuple[pd.DataFrame, np.ndarray]
+     ]
+     """Callable to create tidy results for saving to disk
+
+     Parameters
+     ----------
+     args
+         Benchmark parsed arguments
+     p2p_bw
+         Array of point-to-point bandwidths
+     results: list
+         List of results from running ``bench_once``
+     Returns
+     -------
+     tuple
+         two-tuple of a pandas dataframe and the point-to-point bandwidths
+     """
+     pretty_print_results: Callable[
+         [Namespace, Mapping[str, int], np.ndarray, List[Any], Optional[Client]], None
+     ]
+     """Callable to pretty-print results for human consumption
+
+     Parameters
+     ----------
+     args
+         Benchmark parsed arguments
+     address_to_index
+         Mapping from worker addresses to indices
+     p2p_bw
+         Array of point-to-point bandwidths
+     results: list
+         List of results from running ``bench_once``
+     """
+
+
+ def run_benchmark(client: Client, args: Namespace, config: Config):
+     """Run a benchmark a specified number of times
+
+     If ``args.profile`` is set, the final run is profiled.
+     """
+
+     results = []
+     for _ in range(max(0, args.warmup_runs)):
+         config.bench_once(client, args, write_profile=None)
+
+     ctx = contextlib.nullcontext()
+     if args.profile is not None:
+         ctx = performance_report(filename=args.profile)
+     with ctx:
+         for _ in range(max(1, args.runs) - 1):
+             res = config.bench_once(client, args, write_profile=None)
+             results.append(res)
+         results.append(config.bench_once(client, args, write_profile=args.profile_last))
+     return results
+
+
+ def gather_bench_results(client: Client, args: Namespace, config: Config):
+     """Collect benchmark results from the workers"""
+     address2index = address_to_index(client)
+     if args.all_to_all:
+         all_to_all(client)
+     results = run_benchmark(client, args, config)
+     # Collect aggregated peer-to-peer bandwidth
+     message_data = client.run(
+         partial(aggregate_transfer_log_data, bandwidth_statistics, args.ignore_size)
+     )
+     return address2index, results, message_data
+
+
+ def run(client: Client, args: Namespace, config: Config):
+     """Run the full benchmark on the cluster
+
+     Waits for the cluster, sets up memory pools, prints and saves results
+     """
+
+     wait_for_cluster(client, shutdown_on_failure=True)
+     assert len(client.scheduler_info(n_workers=-1)["workers"]) > 0
+     setup_memory_pools(
+         client=client,
+         is_gpu=args.type == "gpu",
+         disable_rmm=args.disable_rmm,
+         disable_rmm_pool=args.disable_rmm_pool,
+         pool_size=args.rmm_pool_size,
+         maximum_pool_size=args.rmm_maximum_pool_size,
+         rmm_async=args.enable_rmm_async,
+         rmm_managed=args.enable_rmm_managed,
+         release_threshold=args.rmm_release_threshold,
+         log_directory=args.rmm_log_directory,
+         statistics=args.enable_rmm_statistics,
+         rmm_track_allocations=args.enable_rmm_track_allocations,
+     )
+     address_to_index, results, message_data = gather_bench_results(client, args, config)
+     p2p_bw = peer_to_peer_bandwidths(message_data, address_to_index)
+     config.pretty_print_results(args, address_to_index, p2p_bw, results, client=client)
+     if args.output_basename:
+         df, p2p_bw = config.create_tidy_results(args, p2p_bw, results)
+         df["num_workers"] = len(address_to_index)
+         save_benchmark_data(
+             args.output_basename,
+             address_to_index,
+             df,
+             p2p_bw,
+         )
+
+
+ def run_client_from_existing_scheduler(args: Namespace, config: Config):
+     """Set up a client by connecting to a scheduler
+
+     Shuts down the cluster at the end of the benchmark conditional on
+     ``args.shutdown_cluster``.
+     """
+     if args.scheduler_address is not None:
+         kwargs = {"address": args.scheduler_address}
+     elif args.scheduler_file is not None:
+         kwargs = {"scheduler_file": args.scheduler_file}
+     else:
+         raise RuntimeError(
+             "Need to specify either --scheduler-file or --scheduler-address"
+         )
+     with Client(**kwargs) as client:
+         run(client, args, config)
+         if args.shutdown_cluster:
+             client.shutdown()
+
+
+ def run_create_client(args: Namespace, config: Config):
+     """Create a client + cluster and run
+
+     Shuts down the cluster at the end of the benchmark
+     """
+     cluster_options = get_cluster_options(args)
+     Cluster = cluster_options["class"]
+     cluster_args = cluster_options["args"]
+     cluster_kwargs = cluster_options["kwargs"]
+     scheduler_addr = cluster_options["scheduler_addr"]
+
+     filterwarnings("ignore", message=".*NVLink.*rmm_pool_size.*", category=UserWarning)
+
+     with Cluster(*cluster_args, **cluster_kwargs) as cluster:
+         # Use the scheduler address with an SSHCluster rather than the cluster
+         # object, otherwise we can't shut it down.
+         with Client(scheduler_addr if args.multi_node else cluster) as client:
+             run(client, args, config)
+             # An SSHCluster will not automatically shut down, we have to
+             # ensure it does.
+             if args.multi_node:
+                 client.shutdown()
+
+
+ def execute_benchmark(config: Config):
+     """Run complete benchmark given a configuration"""
+     args = config.args
+     if args.multiprocessing_method == "forkserver":
+         import multiprocessing.forkserver as f
+
+         f.ensure_running()
+     with dask.config.set(
+         {"distributed.worker.multiprocessing-method": args.multiprocessing_method}
+     ):
+         if args.scheduler_file is not None or args.scheduler_address is not None:
+             run_client_from_existing_scheduler(args, config)
+         else:
+             run_create_client(args, config)
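The harness above expects each benchmark module to supply the three callables named in Config. A skeleton of that wiring is sketched below with placeholder bodies; the real, complete pattern follows in local_cudf_groupby.py. The only assumption beyond the code above is that parse_benchmark_args accepts description and args_list keywords, as used by the benchmark modules in this wheel.

# Skeleton benchmark built on the Config/execute_benchmark harness
# (placeholder bodies; only the wiring matters here).
from time import perf_counter as clock

import pandas as pd

from dask_cuda.benchmarks.common import Config, execute_benchmark
from dask_cuda.benchmarks.utils import parse_benchmark_args


def bench_once(client, args, write_profile=None):
    t1 = clock()
    # ... submit work through `client` here ...
    return clock() - t1


def pretty_print_results(args, address_to_index, p2p_bw, results, client=None):
    print(f"ran {len(results)} iterations")


def create_tidy_results(args, p2p_bw, results):
    return pd.DataFrame({"wallclock": results}), p2p_bw


if __name__ == "__main__":
    execute_benchmark(
        Config(
            args=parse_benchmark_args(description="Example benchmark", args_list=[]),
            bench_once=bench_once,
            create_tidy_results=create_tidy_results,
            pretty_print_results=pretty_print_results,
        )
    )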
dask_cuda/benchmarks/local_cudf_groupby.py ADDED
@@ -0,0 +1,278 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import contextlib
+ from collections import ChainMap
+ from time import perf_counter as clock
+
+ import pandas as pd
+
+ import dask
+ import dask.dataframe as dd
+ from dask.distributed import performance_report, wait
+ from dask.utils import format_bytes
+
+ from dask_cuda.benchmarks.common import Config, execute_benchmark
+ from dask_cuda.benchmarks.utils import (
+     as_noop,
+     parse_benchmark_args,
+     print_key_value,
+     print_separator,
+     print_throughput_bandwidth,
+ )
+
+
+ def apply_groupby(
+     df,
+     backend,
+     sort=False,
+     split_out=1,
+     split_every=8,
+     shuffle=None,
+ ):
+     if backend == "dask-noop" and shuffle == "explicit-comms":
+         raise RuntimeError("dask-noop not valid for explicit-comms shuffle")
+     # Handle special "explicit-comms" case
+     config = {}
+     if shuffle == "explicit-comms":
+         shuffle = "tasks"
+         config = {"explicit-comms": True}
+
+     with dask.config.set(config):
+         agg = df.groupby("key", sort=sort).agg(
+             {"int64": ["max", "count"], "float64": "mean"},
+             split_out=split_out,
+             split_every=split_every,
+             shuffle=shuffle,
+         )
+     if backend == "dask-noop":
+         agg = as_noop(agg)
+
+     wait(agg.persist())
+     return agg
+
+
+ def generate_chunk(chunk_info, unique_size=1, gpu=True):
+     # Setting a seed that triggers max amount of comm in the two-GPU case.
+     if gpu:
+         import cupy as xp
+
+         import cudf as xdf
+     else:
+         import numpy as xp
+         import pandas as xdf
+
+     i_chunk, local_size = chunk_info
+     xp.random.seed(i_chunk * 1_000)
+     return xdf.DataFrame(
+         {
+             "key": xp.random.randint(0, unique_size, size=local_size, dtype="int64"),
+             "int64": xp.random.permutation(xp.arange(local_size, dtype="int64")),
+             "float64": xp.random.permutation(xp.arange(local_size, dtype="float64")),
+         }
+     )
+
+
+ def get_random_ddf(args):
+     total_size = args.chunk_size * args.in_parts
+     chunk_kwargs = {
+         "unique_size": max(int(args.unique_ratio * total_size), 1),
+         "gpu": True if args.type == "gpu" else False,
+     }
+
+     return dd.from_map(
+         generate_chunk,
+         [(i, args.chunk_size) for i in range(args.in_parts)],
+         meta=generate_chunk((0, 1), **chunk_kwargs),
+         enforce_metadata=False,
+         **chunk_kwargs,
+     )
+
+
+ def bench_once(client, args, write_profile=None):
+     # Generate random Dask dataframe
+     df = get_random_ddf(args)
+
+     data_processed = len(df) * sum([t.itemsize for t in df.dtypes])
+     shuffle = {
+         "True": "tasks",
+         "False": False,
+     }.get(args.shuffle, args.shuffle)
+
+     ctx = contextlib.nullcontext()
+     if write_profile is not None:
+         ctx = performance_report(filename=write_profile)
+
+     with ctx:
+         t1 = clock()
+         agg = apply_groupby(
+             df,
+             backend=args.backend,
+             sort=args.sort,
+             split_out=args.split_out,
+             split_every=args.split_every,
+             shuffle=shuffle,
+         )
+         t2 = clock()
+
+     output_size = agg.memory_usage(index=True, deep=True).compute().sum()
+     return (data_processed, output_size, t2 - t1)
+
+
+ def pretty_print_results(args, address_to_index, p2p_bw, results, client=None):
+     if args.markdown:
+         print("```")
+     print("Groupby benchmark")
+     print_separator(separator="-")
+     print_key_value(key="Use shuffle", value=f"{args.shuffle}")
+     print_key_value(key="Backend", value=f"{args.backend}")
+     print_key_value(key="Output partitions", value=f"{args.split_out}")
+     print_key_value(key="Input partitions", value=f"{args.in_parts}")
+     print_key_value(key="Sort Groups", value=f"{args.sort}")
+     print_key_value(key="Rows-per-chunk", value=f"{args.chunk_size}")
+     print_key_value(key="Unique-group ratio", value=f"{args.unique_ratio}")
+     print_key_value(key="Protocol", value=f"{args.protocol}")
+     print_key_value(key="Device(s)", value=f"{args.devs}")
+     print_key_value(key="Tree-reduction width", value=f"{args.split_every}")
+     if args.device_memory_limit:
+         print_key_value(
+             key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
+         )
+     print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
+     if args.protocol in ["ucx", "ucxx"]:
+         print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
+         print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
+         print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
+     print_key_value(key="Worker thread(s)", value=f"{args.threads_per_worker}")
+     print_key_value(key="Data processed", value=f"{format_bytes(results[0][0])}")
+     print_key_value(key="Output size", value=f"{format_bytes(results[0][1])}")
+     if args.markdown:
+         print("\n```")
+     data_processed, output_size, durations = zip(*results)
+     print_throughput_bandwidth(
+         args, durations, data_processed, p2p_bw, address_to_index
+     )
+
+
+ def create_tidy_results(args, p2p_bw, results):
+     configuration = {
+         "dataframe_type": "cudf" if args.type == "gpu" else "pandas",
+         "shuffle": args.shuffle,
+         "backend": args.backend,
+         "sort": args.sort,
+         "split_out": args.split_out,
+         "split_every": args.split_every,
+         "in_parts": args.in_parts,
+         "rows_per_chunk": args.chunk_size,
+         "unique_ratio": args.unique_ratio,
+         "protocol": args.protocol,
+         "devs": args.devs,
+         "device_memory_limit": args.device_memory_limit,
+         "rmm_pool": not args.disable_rmm_pool,
+         "tcp": args.enable_tcp_over_ucx,
+         "ib": args.enable_infiniband,
+         "nvlink": args.enable_nvlink,
+     }
+     timing_data = pd.DataFrame(
+         [
+             pd.Series(
+                 data=ChainMap(
+                     configuration,
+                     {
+                         "wallclock": duration,
+                         "data_processed": data_processed,
+                         "output_size": output_size,
+                     },
+                 )
+             )
+             for data_processed, output_size, duration in results
+         ]
+     )
+     return timing_data, p2p_bw
+
+
+ def parse_args():
+     special_args = [
+         {
+             "name": "--in-parts",
+             "default": 100,
+             "metavar": "n",
+             "type": int,
+             "help": "Number of input partitions (default '100')",
+         },
+         {
+             "name": [
+                 "-c",
+                 "--chunk-size",
+             ],
+             "default": 1_000_000,
+             "metavar": "n",
+             "type": int,
+             "help": "Chunk size (default 1_000_000)",
+         },
+         {
+             "name": "--unique-ratio",
+             "default": 0.01,
+             "type": float,
+             "help": "Fraction of rows that are unique groups",
+         },
+         {
+             "name": "--sort",
+             "default": False,
+             "action": "store_true",
+             "help": "Whether to sort the output group order.",
+         },
+         {
+             "name": "--split_out",
+             "default": 1,
+             "type": int,
+             "help": "How many partitions to return.",
+         },
+         {
+             "name": "--split_every",
+             "default": 8,
+             "type": int,
+             "help": "Tree-reduction width.",
+         },
+         {
+             "name": "--shuffle",
+             "choices": ["False", "True", "tasks", "explicit-comms"],
+             "default": "False",
+             "type": str,
+             "help": "Whether to use shuffle-based groupby.",
+         },
+         {
+             "name": "--backend",
+             "choices": ["dask", "dask-noop"],
+             "default": "dask",
+             "type": str,
+             "help": (
+                 "Compute engine to use, dask-noop turns the graph into a noop graph"
+             ),
+         },
+         {
+             "name": [
+                 "-t",
+                 "--type",
+             ],
+             "choices": ["cpu", "gpu"],
+             "default": "gpu",
+             "type": str,
+             "help": "Do shuffle with GPU or CPU dataframes (default 'gpu')",
+         },
+     ]
+
+     return parse_benchmark_args(
+         description="Distributed groupby (dask/cudf) benchmark", args_list=special_args
+     )
+
+
+ if __name__ == "__main__":
+     execute_benchmark(
+         Config(
+             args=parse_args(),
+             bench_once=bench_once,
+             create_tidy_results=create_tidy_results,
+             pretty_print_results=pretty_print_results,
+         )
+     )
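To make the measured workload concrete, here is a tiny pandas-only illustration of the aggregation this benchmark times; the benchmark itself runs the same aggregation across many partitions, optionally on cuDF and optionally through a task or explicit-comms shuffle.

# CPU-only illustration of the timed aggregation: max/count of "int64" and
# mean of "float64", grouped by "key".
import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "key": np.random.randint(0, 3, size=10),
        "int64": np.arange(10, dtype="int64"),
        "float64": np.arange(10, dtype="float64"),
    }
)
print(df.groupby("key").agg({"int64": ["max", "count"], "float64": "mean"}))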