cudf-polars-cu13 25.12.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +28 -7
  4. cudf_polars/containers/column.py +51 -26
  5. cudf_polars/dsl/expressions/binaryop.py +1 -1
  6. cudf_polars/dsl/expressions/boolean.py +1 -1
  7. cudf_polars/dsl/expressions/selection.py +1 -1
  8. cudf_polars/dsl/expressions/string.py +29 -20
  9. cudf_polars/dsl/expressions/ternary.py +25 -1
  10. cudf_polars/dsl/expressions/unary.py +11 -8
  11. cudf_polars/dsl/ir.py +351 -281
  12. cudf_polars/dsl/translate.py +18 -15
  13. cudf_polars/dsl/utils/aggregations.py +10 -5
  14. cudf_polars/experimental/base.py +10 -0
  15. cudf_polars/experimental/benchmarks/pdsh.py +1 -1
  16. cudf_polars/experimental/benchmarks/utils.py +83 -2
  17. cudf_polars/experimental/distinct.py +2 -0
  18. cudf_polars/experimental/explain.py +1 -1
  19. cudf_polars/experimental/expressions.py +8 -5
  20. cudf_polars/experimental/groupby.py +2 -0
  21. cudf_polars/experimental/io.py +64 -42
  22. cudf_polars/experimental/join.py +15 -2
  23. cudf_polars/experimental/parallel.py +10 -7
  24. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  25. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  26. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  27. cudf_polars/experimental/rapidsmpf/{shuffle.py → collectives/shuffle.py} +90 -114
  28. cudf_polars/experimental/rapidsmpf/core.py +194 -67
  29. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  30. cudf_polars/experimental/rapidsmpf/dispatch.py +6 -3
  31. cudf_polars/experimental/rapidsmpf/io.py +162 -70
  32. cudf_polars/experimental/rapidsmpf/join.py +162 -77
  33. cudf_polars/experimental/rapidsmpf/nodes.py +421 -180
  34. cudf_polars/experimental/rapidsmpf/repartition.py +130 -65
  35. cudf_polars/experimental/rapidsmpf/union.py +24 -5
  36. cudf_polars/experimental/rapidsmpf/utils.py +228 -16
  37. cudf_polars/experimental/shuffle.py +18 -4
  38. cudf_polars/experimental/sort.py +13 -6
  39. cudf_polars/experimental/spilling.py +1 -1
  40. cudf_polars/testing/plugin.py +6 -3
  41. cudf_polars/utils/config.py +67 -0
  42. cudf_polars/utils/versions.py +3 -3
  43. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/METADATA +9 -10
  44. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/RECORD +47 -43
  45. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  46. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  47. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
@@ -8,10 +8,10 @@ from collections import defaultdict
8
8
  from concurrent.futures import ThreadPoolExecutor
9
9
  from typing import TYPE_CHECKING, Any
10
10
 
11
- from rapidsmpf.buffer.buffer import MemoryType
12
- from rapidsmpf.buffer.resource import BufferResource, LimitAvailableMemory
13
11
  from rapidsmpf.communicator.single import new_communicator
14
12
  from rapidsmpf.config import Options, get_environment_variables
13
+ from rapidsmpf.memory.buffer import MemoryType
14
+ from rapidsmpf.memory.buffer_resource import BufferResource, LimitAvailableMemory
15
15
  from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor
16
16
  from rapidsmpf.streaming.core.context import Context
17
17
  from rapidsmpf.streaming.core.leaf_node import pull_from_channel
@@ -22,17 +22,22 @@ from rapidsmpf.streaming.cudf.table_chunk import TableChunk
22
22
 
23
23
  import rmm
24
24
 
25
+ import cudf_polars.experimental.rapidsmpf.collectives.shuffle
25
26
  import cudf_polars.experimental.rapidsmpf.io
26
27
  import cudf_polars.experimental.rapidsmpf.join
27
28
  import cudf_polars.experimental.rapidsmpf.lower
28
29
  import cudf_polars.experimental.rapidsmpf.repartition
29
- import cudf_polars.experimental.rapidsmpf.shuffle
30
30
  import cudf_polars.experimental.rapidsmpf.union # noqa: F401
31
31
  from cudf_polars.containers import DataFrame
32
32
  from cudf_polars.dsl.ir import DataFrameScan, IRExecutionContext, Join, Scan, Union
33
33
  from cudf_polars.dsl.traversal import CachingVisitor, traversal
34
+ from cudf_polars.experimental.rapidsmpf.collectives import ReserveOpIDs
34
35
  from cudf_polars.experimental.rapidsmpf.dispatch import FanoutInfo, lower_ir_node
35
- from cudf_polars.experimental.rapidsmpf.nodes import generate_ir_sub_network_wrapper
36
+ from cudf_polars.experimental.rapidsmpf.nodes import (
37
+ generate_ir_sub_network_wrapper,
38
+ metadata_drain_node,
39
+ )
40
+ from cudf_polars.experimental.rapidsmpf.utils import empty_table_chunk
36
41
  from cudf_polars.experimental.statistics import collect_statistics
37
42
  from cudf_polars.experimental.utils import _concat
38
43
  from cudf_polars.utils.config import CUDAStreamPoolConfig
@@ -40,10 +45,13 @@ from cudf_polars.utils.config import CUDAStreamPoolConfig
40
45
  if TYPE_CHECKING:
41
46
  from collections.abc import MutableMapping
42
47
 
48
+ from rapidsmpf.streaming.core.channel import Channel
43
49
  from rapidsmpf.streaming.core.leaf_node import DeferredMessages
44
50
 
45
51
  import polars as pl
46
52
 
53
+ from rmm.pylibrmm.cuda_stream_pool import CudaStreamPool
54
+
47
55
  from cudf_polars.dsl.ir import IR
48
56
  from cudf_polars.experimental.base import PartitionInfo, StatsCollector
49
57
  from cudf_polars.experimental.parallel import ConfigOptions
@@ -53,12 +61,15 @@ if TYPE_CHECKING:
53
61
  LowerState,
54
62
  SubNetGenerator,
55
63
  )
64
+ from cudf_polars.experimental.rapidsmpf.utils import Metadata
56
65
 
57
66
 
58
67
  def evaluate_logical_plan(
59
68
  ir: IR,
60
69
  config_options: ConfigOptions,
61
- ) -> pl.DataFrame:
70
+ *,
71
+ collect_metadata: bool = False,
72
+ ) -> tuple[pl.DataFrame, list[Metadata] | None]:
62
73
  """
63
74
  Evaluate a logical plan with the RapidsMPF streaming runtime.
64
75
 
@@ -68,59 +79,136 @@ def evaluate_logical_plan(
68
79
  The IR node.
69
80
  config_options
70
81
  The configuration options.
82
+ collect_metadata
83
+ Whether to collect runtime metadata.
71
84
 
72
85
  Returns
73
86
  -------
74
- The output DataFrame.
87
+ The output DataFrame and metadata collector.
75
88
  """
76
89
  assert config_options.executor.name == "streaming", "Executor must be streaming"
77
90
  assert config_options.executor.runtime == "rapidsmpf", "Runtime must be rapidsmpf"
78
91
 
79
- if (
80
- config_options.executor.scheduler == "distributed"
81
- ): # pragma: no cover; Requires distributed
82
- # TODO: Add distributed-execution support
83
- raise NotImplementedError(
84
- "The rapidsmpf engine does not support distributed execution yet."
85
- )
86
-
87
92
  # Lower the IR graph on the client process (for now).
88
93
  ir, partition_info, stats = lower_ir_graph(ir, config_options)
89
94
 
90
- # Configure the context.
91
- # TODO: Multi-GPU version will be different. The rest of this function
92
- # will be executed on each rank independently.
93
- # TODO: Need a way to configure options specific to the rapidmspf engine.
94
- options = Options(get_environment_variables())
95
- comm = new_communicator(options)
96
- mr = RmmResourceAdaptor(rmm.mr.get_current_device_resource())
97
- rmm.mr.set_current_device_resource(mr)
98
- memory_available: MutableMapping[MemoryType, LimitAvailableMemory] | None = None
99
- single_spill_device = config_options.executor.client_device_threshold
100
- if single_spill_device > 0.0 and single_spill_device < 1.0:
101
- total_memory = rmm.mr.available_device_memory()[1]
102
- memory_available = {
103
- MemoryType.DEVICE: LimitAvailableMemory(
104
- mr, limit=int(total_memory * single_spill_device)
95
+ # Reserve shuffle IDs for the entire pipeline execution
96
+ with ReserveOpIDs(ir) as shuffle_id_map:
97
+ # Build and execute the streaming pipeline.
98
+ # This must be done on all worker processes
99
+ # for cluster == "distributed".
100
+ if (
101
+ config_options.executor.cluster == "distributed"
102
+ ): # pragma: no cover; block depends on executor type and Distributed cluster
103
+ # Distributed execution: Use client.run
104
+
105
+ # NOTE: Distributed execution requires Dask for now
106
+ from cudf_polars.experimental.rapidsmpf.dask import evaluate_pipeline_dask
107
+
108
+ result, metadata_collector = evaluate_pipeline_dask(
109
+ evaluate_pipeline,
110
+ ir,
111
+ partition_info,
112
+ config_options,
113
+ stats,
114
+ shuffle_id_map,
115
+ collect_metadata=collect_metadata,
116
+ )
117
+ else:
118
+ # Single-process execution: Run locally
119
+ result, metadata_collector = evaluate_pipeline(
120
+ ir,
121
+ partition_info,
122
+ config_options,
123
+ stats,
124
+ shuffle_id_map,
125
+ collect_metadata=collect_metadata,
105
126
  )
106
- }
107
127
 
108
- # We have a couple of cases to consider here:
109
- # 1: we want to use the same stream pool for cudf-polars and rapidsmpf
110
- # 2: rapidsmpf uses its own pool and cudf-polars uses the default stream
111
- if isinstance(config_options.cuda_stream_policy, CUDAStreamPoolConfig):
112
- stream_pool = config_options.cuda_stream_policy.build()
113
- else:
114
- stream_pool = None
128
+ return result, metadata_collector
115
129
 
116
- br = BufferResource(mr, memory_available=memory_available, stream_pool=stream_pool)
117
- rmpf_context = Context(comm, br, options)
118
130
 
119
- executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="cpse")
131
+ def evaluate_pipeline(
132
+ ir: IR,
133
+ partition_info: MutableMapping[IR, PartitionInfo],
134
+ config_options: ConfigOptions,
135
+ stats: StatsCollector,
136
+ collective_id_map: dict[IR, int],
137
+ rmpf_context: Context | None = None,
138
+ *,
139
+ collect_metadata: bool = False,
140
+ ) -> tuple[pl.DataFrame, list[Metadata] | None]:
141
+ """
142
+ Build and evaluate a RapidsMPF streaming pipeline.
143
+
144
+ Parameters
145
+ ----------
146
+ ir
147
+ The IR node.
148
+ partition_info
149
+ The partition information.
150
+ config_options
151
+ The configuration options.
152
+ stats
153
+ The statistics collector.
154
+ collective_id_map
155
+ The mapping of IR nodes to collective IDs.
156
+ rmpf_context
157
+ The RapidsMPF context.
158
+ collect_metadata
159
+ Whether to collect runtime metadata.
120
160
 
121
- # Create the IR execution context.
122
- if stream_pool is not None:
123
- # both cudf-polars and rapidsmpf are using the same stream pool
161
+ Returns
162
+ -------
163
+ The output DataFrame and metadata collector.
164
+ """
165
+ assert config_options.executor.name == "streaming", "Executor must be streaming"
166
+ assert config_options.executor.runtime == "rapidsmpf", "Runtime must be rapidsmpf"
167
+
168
+ _initial_mr: Any = None
169
+ stream_pool: CudaStreamPool | bool = False
170
+ if rmpf_context is not None:
171
+ # Using "distributed" mode.
172
+ # Always use the RapidsMPF stream pool for now.
173
+ br = rmpf_context.br()
174
+ stream_pool = True
175
+ else:
176
+ # Using "single" mode.
177
+ # Create a new local RapidsMPF context.
178
+ _original_mr = rmm.mr.get_current_device_resource()
179
+ mr = RmmResourceAdaptor(_original_mr)
180
+ rmm.mr.set_current_device_resource(mr)
181
+ memory_available: MutableMapping[MemoryType, LimitAvailableMemory] | None = None
182
+ single_spill_device = config_options.executor.client_device_threshold
183
+ if single_spill_device > 0.0 and single_spill_device < 1.0:
184
+ total_memory = rmm.mr.available_device_memory()[1]
185
+ memory_available = {
186
+ MemoryType.DEVICE: LimitAvailableMemory(
187
+ mr, limit=int(total_memory * single_spill_device)
188
+ )
189
+ }
190
+
191
+ options = Options(
192
+ {
193
+ # By default, set the number of streaming threads to the max
194
+ # number of IO threads. The user may override this with an
195
+ # environment variable (i.e. RAPIDSMPF_NUM_STREAMING_THREADS)
196
+ "num_streaming_threads": str(
197
+ max(config_options.executor.max_io_threads, 1)
198
+ )
199
+ }
200
+ | get_environment_variables()
201
+ )
202
+ if isinstance(config_options.cuda_stream_policy, CUDAStreamPoolConfig):
203
+ stream_pool = config_options.cuda_stream_policy.build()
204
+ local_comm = new_communicator(options)
205
+ br = BufferResource(
206
+ mr, memory_available=memory_available, stream_pool=stream_pool
207
+ )
208
+ rmpf_context = Context(local_comm, br, options)
209
+
210
+ # Create the IR execution context
211
+ if stream_pool:
124
212
  ir_context = IRExecutionContext(
125
213
  get_cuda_stream=rmpf_context.get_stream_from_pool
126
214
  )
@@ -128,6 +216,8 @@ def evaluate_logical_plan(
128
216
  ir_context = IRExecutionContext.from_config_options(config_options)
129
217
 
130
218
  # Generate network nodes
219
+ assert rmpf_context is not None, "RapidsMPF context must defined."
220
+ metadata_collector: list[Metadata] | None = [] if collect_metadata else None
131
221
  nodes, output = generate_network(
132
222
  rmpf_context,
133
223
  ir,
@@ -135,9 +225,12 @@ def evaluate_logical_plan(
135
225
  config_options,
136
226
  stats,
137
227
  ir_context=ir_context,
228
+ collective_id_map=collective_id_map,
229
+ metadata_collector=metadata_collector,
138
230
  )
139
231
 
140
232
  # Run the network
233
+ executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="cpse")
141
234
  run_streaming_pipeline(nodes=nodes, py_executor=executor)
142
235
 
143
236
  # Extract/return the concatenated result.
@@ -150,16 +243,29 @@ def evaluate_logical_plan(
150
243
  )
151
244
  for msg in messages
152
245
  ]
153
- dfs = [
154
- DataFrame.from_table(
246
+ dfs: list[DataFrame] = []
247
+ if chunks:
248
+ dfs = [
249
+ DataFrame.from_table(
250
+ chunk.table_view(),
251
+ list(ir.schema.keys()),
252
+ list(ir.schema.values()),
253
+ chunk.stream,
254
+ )
255
+ for chunk in chunks
256
+ ]
257
+ df = _concat(*dfs, context=ir_context)
258
+ else:
259
+ # No chunks received - create an empty DataFrame with correct schema
260
+ stream = ir_context.get_cuda_stream()
261
+ chunk = empty_table_chunk(ir, rmpf_context, stream)
262
+ df = DataFrame.from_table(
155
263
  chunk.table_view(),
156
264
  list(ir.schema.keys()),
157
265
  list(ir.schema.values()),
158
- chunk.stream,
266
+ stream,
159
267
  )
160
- for chunk in chunks
161
- ]
162
- df = _concat(*dfs, context=ir_context)
268
+
163
269
  # We need to materialize the polars dataframe before we drop the rapidsmpf
164
270
  # context, which keeps the CUDA streams alive.
165
271
  stream = df.stream
@@ -170,7 +276,11 @@ def evaluate_logical_plan(
170
276
  # before the Context, which ultimately contains the rmm MR, goes out of scope.
171
277
  del nodes, output, messages, chunks, dfs, df
172
278
 
173
- return result
279
+ # Restore the initial RMM memory resource
280
+ if _initial_mr is not None:
281
+ rmm.mr.set_current_device_resource(_original_mr)
282
+
283
+ return result, metadata_collector
174
284
 
175
285
 
176
286
  def lower_ir_graph(
@@ -186,12 +296,15 @@ def lower_ir_graph(
186
296
  Root of the graph to rewrite.
187
297
  config_options
188
298
  GPUEngine configuration options.
299
+ stats
300
+ The statistics collector.
189
301
 
190
302
  Returns
191
303
  -------
192
304
  new_ir, partition_info, stats
193
- The rewritten graph, and a mapping from unique nodes
194
- in the new graph to associated partitioning information.
305
+ The rewritten graph, a mapping from unique nodes
306
+ in the new graph to associated partitioning information,
307
+ and the statistics collector.
195
308
 
196
309
  Notes
197
310
  -----
@@ -287,6 +400,8 @@ def generate_network(
287
400
  stats: StatsCollector,
288
401
  *,
289
402
  ir_context: IRExecutionContext,
403
+ collective_id_map: dict[IR, int],
404
+ metadata_collector: list[Metadata] | None,
290
405
  ) -> tuple[list[Any], DeferredMessages]:
291
406
  """
292
407
  Translate the IR graph to a RapidsMPF streaming network.
@@ -305,6 +420,12 @@ def generate_network(
305
420
  Statistics collector.
306
421
  ir_context
307
422
  The execution context for the IR node.
423
+ collective_id_map
424
+ The mapping of IR nodes to collective IDs.
425
+ metadata_collector
426
+ The list to collect the final metadata.
427
+ This list will be mutated when the network is executed.
428
+ If None, metadata will not be collected.
308
429
 
309
430
  Returns
310
431
  -------
@@ -322,8 +443,9 @@ def generate_network(
322
443
  # Determine which nodes need fanout
323
444
  fanout_nodes = determine_fanout_nodes(ir, partition_info, ir_dep_count)
324
445
 
325
- # TODO: Make this configurable
326
- max_io_threads_global = 2
446
+ # Get max_io_threads from config (default: 2)
447
+ assert config_options.executor.name == "streaming", "Executor must be streaming"
448
+ max_io_threads_global = config_options.executor.max_io_threads
327
449
  max_io_threads_local = max(1, max_io_threads_global // max(1, num_io_nodes))
328
450
 
329
451
  # Generate the network
@@ -335,27 +457,32 @@ def generate_network(
335
457
  "ir_context": ir_context,
336
458
  "max_io_threads": max_io_threads_local,
337
459
  "stats": stats,
460
+ "collective_id_map": collective_id_map,
338
461
  }
339
462
  mapper: SubNetGenerator = CachingVisitor(
340
463
  generate_ir_sub_network_wrapper, state=state
341
464
  )
342
- nodes, channels = mapper(ir)
343
-
344
- # Deduplicate nodes.
345
- # TODO: Remove after https://github.com/rapidsai/cudf/pull/20586
346
- nodes = list(set(nodes))
347
-
465
+ nodes_dict, channels = mapper(ir)
348
466
  ch_out = channels[ir].reserve_output_slot()
349
467
 
350
- # TODO: We will need an additional node here to drain
351
- # the metadata channel once we start plumbing metadata
352
- # through the network. This node could also drop
353
- # "duplicated" data on all but rank 0.
468
+ # Add node to drain metadata channel before pull_from_channel
469
+ # (since pull_from_channel doesn't accept a ChannelPair)
470
+ ch_final_data: Channel[TableChunk] = context.create_channel()
471
+ drain_node = metadata_drain_node(
472
+ context,
473
+ ir,
474
+ ir_context,
475
+ ch_out,
476
+ ch_final_data,
477
+ metadata_collector,
478
+ )
354
479
 
355
480
  # Add final node to pull from the output data channel
356
- # (metadata channel is unused)
357
- output_node, output = pull_from_channel(context, ch_in=ch_out.data)
358
- nodes.append(output_node)
481
+ output_node, output = pull_from_channel(context, ch_in=ch_final_data)
482
+
483
+ # Flatten the nodes dictionary into a list for run_streaming_pipeline
484
+ nodes: list[Any] = [node for node_list in nodes_dict.values() for node in node_list]
485
+ nodes.extend([drain_node, output_node])
359
486
 
360
487
  # Return network and output hook
361
488
  return nodes, output
@@ -0,0 +1,172 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Dask-based execution with the streaming RapidsMPF runtime."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Any, Protocol
8
+
9
+ from distributed import get_client
10
+ from rapidsmpf.config import Options, get_environment_variables
11
+ from rapidsmpf.integrations.dask import get_worker_context
12
+ from rapidsmpf.streaming.core.context import Context
13
+
14
+ import polars as pl
15
+
16
+ from cudf_polars.experimental.dask_registers import DaskRegisterManager
17
+
18
+ if TYPE_CHECKING:
19
+ from collections.abc import MutableMapping
20
+
21
+ from distributed import Client
22
+
23
+ from cudf_polars.dsl.ir import IR
24
+ from cudf_polars.experimental.base import PartitionInfo, StatsCollector
25
+ from cudf_polars.experimental.parallel import ConfigOptions
26
+ from cudf_polars.experimental.rapidsmpf.utils import Metadata
27
+
28
+
29
class EvaluatePipelineCallback(Protocol):
    """
    Structural type for the pipeline-evaluation callable.

    ``evaluate_pipeline_dask`` accepts any callable matching this signature
    and ``_evaluate_pipeline_dask`` invokes it with six positional arguments
    followed by the ``collect_metadata`` keyword.
    """

    def __call__(
        self,
        ir: IR,
        partition_info: MutableMapping[IR, PartitionInfo],
        config_options: ConfigOptions,
        stats: StatsCollector,
        collective_id_map: dict[IR, int],
        rmpf_context: Context | None = None,
        *,
        collect_metadata: bool = False,
    ) -> tuple[pl.DataFrame, list[Metadata] | None]:
        """
        Evaluate a pipeline and return the result DataFrame and metadata.

        Parameters
        ----------
        ir
            The IR node.
        partition_info
            The partition information.
        config_options
            The configuration options.
        stats
            The statistics collector.
        collective_id_map
            The mapping of IR nodes to collective IDs.
        rmpf_context
            The RapidsMPF context, if one already exists.
        collect_metadata
            Whether to collect metadata.

        Returns
        -------
        The output DataFrame and the collected metadata (or None).
        """
45
+
46
+
47
def get_dask_client() -> Client:
    """
    Return the current distributed Dask client.

    The client is resolved with ``distributed.get_client``. Before it is
    handed back, ``DaskRegisterManager.register_once()`` is called and then
    ``DaskRegisterManager.run_on_cluster`` is called with that client.

    Returns
    -------
    The Dask client.
    """
    # Resolve the client first: the registration calls below only run
    # once a client has actually been obtained.
    dask_client = get_client()
    # NOTE(review): presumably registers cudf-polars hooks locally and then
    # on the cluster - confirm against DaskRegisterManager.
    DaskRegisterManager.register_once()
    DaskRegisterManager.run_on_cluster(dask_client)
    return dask_client
53
+
54
+
55
def evaluate_pipeline_dask(
    callback: EvaluatePipelineCallback,
    ir: IR,
    partition_info: MutableMapping[IR, PartitionInfo],
    config_options: ConfigOptions,
    stats: StatsCollector,
    shuffle_id_map: dict[IR, int],
    *,
    collect_metadata: bool = False,
) -> tuple[pl.DataFrame, list[Metadata] | None]:
    """
    Evaluate a RapidsMPF streaming pipeline on a Dask cluster.

    ``_evaluate_pipeline_dask`` is dispatched through ``client.run``; the
    DataFrames it returns are concatenated and the metadata lists it
    returns are merged into a single list.

    Parameters
    ----------
    callback
        The callback function to evaluate the pipeline.
    ir
        The IR node.
    partition_info
        The partition information.
    config_options
        The configuration options.
    stats
        The statistics collector.
    shuffle_id_map
        Mapping from Shuffle/Repartition/Join IR nodes to reserved shuffle IDs.
    collect_metadata
        Whether to collect metadata.

    Returns
    -------
    The output DataFrame and metadata collector.
    The metadata collector is a list (possibly empty) when
    ``collect_metadata`` is True, and None otherwise.
    """
    client = get_dask_client()
    # ``client.run`` hands back a mapping; only its values (one
    # ``(DataFrame, metadata)`` pair each) are needed here.
    per_worker = client.run(
        _evaluate_pipeline_dask,
        callback,
        ir,
        partition_info,
        config_options,
        stats,
        shuffle_id_map,
        collect_metadata=collect_metadata,
    )
    dfs: list[pl.DataFrame] = []
    metadata_collector: list[Metadata] = []
    for df, md in per_worker.values():
        dfs.append(df)
        if md is not None:
            metadata_collector.extend(md)

    # Decide on the flag rather than on list truthiness: a requested but
    # empty collection stays ``[]``, as in the single-process path,
    # instead of being collapsed to None.
    return pl.concat(dfs), (metadata_collector if collect_metadata else None)
108
+
109
+
110
def _evaluate_pipeline_dask(
    callback: EvaluatePipelineCallback,
    ir: IR,
    partition_info: MutableMapping[IR, PartitionInfo],
    config_options: ConfigOptions,
    stats: StatsCollector,
    shuffle_id_map: dict[IR, int],
    dask_worker: Any = None,
    *,
    collect_metadata: bool = False,
) -> tuple[pl.DataFrame, list[Metadata] | None]:
    """
    Build a RapidsMPF context for a Dask worker and run the callback.

    Parameters
    ----------
    callback
        The callback function to evaluate the pipeline.
    ir
        The IR node.
    partition_info
        The partition information.
    config_options
        The configuration options.
    stats
        The statistics collector.
    shuffle_id_map
        Mapping from Shuffle/Repartition/Join IR nodes to reserved shuffle IDs.
    dask_worker
        Dask worker reference.
        This kwarg is automatically populated by Dask
        when this function is called with `client.run`.
    collect_metadata
        Whether to collect metadata.

    Returns
    -------
    The output DataFrame and metadata collector.
    """
    assert dask_worker is not None, "Dask worker must be provided"
    assert config_options.executor.name == "streaming", "Executor must be streaming"

    # NOTE: The Dask-CUDA cluster must be bootstrapped
    # ahead of time using bootstrap_dask_cluster
    # (rapidsmpf.integrations.dask.bootstrap_dask_cluster).
    # TODO: Automatically bootstrap the cluster if necessary.

    # Default the streaming-thread count to the IO-thread limit (at least 1).
    # The environment mapping is the right-hand operand of ``|``, so any
    # key it shares with the defaults takes precedence.
    streaming_threads = max(config_options.executor.max_io_threads, 1)
    default_options = {"num_streaming_threads": str(streaming_threads)}
    options = Options(default_options | get_environment_variables())

    # Reuse the communicator and buffer resource held by the worker context.
    worker_context = get_worker_context(dask_worker)
    rmpf_context = Context(worker_context.comm, worker_context.br, options)

    # IDs are already reserved by the caller, just pass them through
    return callback(
        ir,
        partition_info,
        config_options,
        stats,
        shuffle_id_map,
        rmpf_context,
        collect_metadata=collect_metadata,
    )
@@ -75,6 +75,8 @@ class GenState(TypedDict):
75
75
  a single IO node.
76
76
  stats
77
77
  Statistics collector.
78
+ collective_id_map
79
+ The mapping of IR nodes to collective IDs.
78
80
  """
79
81
 
80
82
  context: Context
@@ -84,10 +86,11 @@ class GenState(TypedDict):
84
86
  ir_context: IRExecutionContext
85
87
  max_io_threads: int
86
88
  stats: StatsCollector
89
+ collective_id_map: dict[IR, int]
87
90
 
88
91
 
89
92
  SubNetGenerator: TypeAlias = GenericTransformer[
90
- "IR", "tuple[list[Any], dict[IR, ChannelManager]]", GenState
93
+ "IR", "tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]", GenState
91
94
  ]
92
95
  """Protocol for Generating a streaming sub-network."""
93
96
 
@@ -128,7 +131,7 @@ def lower_ir_node(
128
131
  @singledispatch
129
132
  def generate_ir_sub_network(
130
133
  ir: IR, rec: SubNetGenerator
131
- ) -> tuple[list[Any], dict[IR, ChannelManager]]:
134
+ ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
132
135
  """
133
136
  Generate a sub-network for the RapidsMPF streaming runtime.
134
137
 
@@ -142,7 +145,7 @@ def generate_ir_sub_network(
142
145
  Returns
143
146
  -------
144
147
  nodes
145
- List of streaming-network node(s).
148
+ Dictionary mapping each IR node to its list of streaming-network node(s).
146
149
  channels
147
150
  Dictionary mapping between each IR node and its
148
151
  corresponding output ChannelManager object.