cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
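Most of the new surface area in this release is the cudf_polars/experimental/rapidsmpf/ package, which adds a RapidsMPF-based streaming runtime behind the existing "streaming" executor. As orientation for the new-file diffs below, here is a hedged sketch of how one might opt into it from Polars; the executor_options keys ("runtime", "max_io_threads") are assumptions inferred from the ConfigOptions attributes read in core.py, not option names confirmed by this diff.

import polars as pl

# Hypothetical opt-in sketch. "streaming" is the executor name asserted in
# core.py; the executor_options keys are assumed from the
# config_options.executor attributes it reads (runtime, max_io_threads).
engine = pl.GPUEngine(
    executor="streaming",
    executor_options={
        "runtime": "rapidsmpf",  # assumed key; mirrors the executor.runtime checks
        "max_io_threads": 2,  # assumed key; mirrors executor.max_io_threads
    },
)

q = pl.scan_parquet("data.parquet").group_by("key").agg(pl.col("x").sum())
result = q.collect(engine=engine)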
--- /dev/null
+++ b/cudf_polars/experimental/rapidsmpf/core.py
@@ -0,0 +1,488 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Core RapidsMPF streaming-engine API."""
+
+from __future__ import annotations
+
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any
+
+from rapidsmpf.communicator.single import new_communicator
+from rapidsmpf.config import Options, get_environment_variables
+from rapidsmpf.memory.buffer import MemoryType
+from rapidsmpf.memory.buffer_resource import BufferResource, LimitAvailableMemory
+from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor
+from rapidsmpf.streaming.core.context import Context
+from rapidsmpf.streaming.core.leaf_node import pull_from_channel
+from rapidsmpf.streaming.core.node import (
+    run_streaming_pipeline,
+)
+from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+
+import rmm
+
+import cudf_polars.experimental.rapidsmpf.collectives.shuffle
+import cudf_polars.experimental.rapidsmpf.io
+import cudf_polars.experimental.rapidsmpf.join
+import cudf_polars.experimental.rapidsmpf.lower
+import cudf_polars.experimental.rapidsmpf.repartition
+import cudf_polars.experimental.rapidsmpf.union  # noqa: F401
+from cudf_polars.containers import DataFrame
+from cudf_polars.dsl.ir import DataFrameScan, IRExecutionContext, Join, Scan, Union
+from cudf_polars.dsl.traversal import CachingVisitor, traversal
+from cudf_polars.experimental.rapidsmpf.collectives import ReserveOpIDs
+from cudf_polars.experimental.rapidsmpf.dispatch import FanoutInfo, lower_ir_node
+from cudf_polars.experimental.rapidsmpf.nodes import (
+    generate_ir_sub_network_wrapper,
+    metadata_drain_node,
+)
+from cudf_polars.experimental.rapidsmpf.utils import empty_table_chunk
+from cudf_polars.experimental.statistics import collect_statistics
+from cudf_polars.experimental.utils import _concat
+from cudf_polars.utils.config import CUDAStreamPoolConfig
+
+if TYPE_CHECKING:
+    from collections.abc import MutableMapping
+
+    from rapidsmpf.streaming.core.channel import Channel
+    from rapidsmpf.streaming.core.leaf_node import DeferredMessages
+
+    import polars as pl
+
+    from rmm.pylibrmm.cuda_stream_pool import CudaStreamPool
+
+    from cudf_polars.dsl.ir import IR
+    from cudf_polars.experimental.base import PartitionInfo, StatsCollector
+    from cudf_polars.experimental.parallel import ConfigOptions
+    from cudf_polars.experimental.rapidsmpf.dispatch import (
+        GenState,
+        LowerIRTransformer,
+        LowerState,
+        SubNetGenerator,
+    )
+    from cudf_polars.experimental.rapidsmpf.utils import Metadata
+
+
+def evaluate_logical_plan(
+    ir: IR,
+    config_options: ConfigOptions,
+    *,
+    collect_metadata: bool = False,
+) -> tuple[pl.DataFrame, list[Metadata] | None]:
+    """
+    Evaluate a logical plan with the RapidsMPF streaming runtime.
+
+    Parameters
+    ----------
+    ir
+        The IR node.
+    config_options
+        The configuration options.
+    collect_metadata
+        Whether to collect runtime metadata.
+
+    Returns
+    -------
+    The output DataFrame and metadata collector.
+    """
+    assert config_options.executor.name == "streaming", "Executor must be streaming"
+    assert config_options.executor.runtime == "rapidsmpf", "Runtime must be rapidsmpf"
+
+    # Lower the IR graph on the client process (for now).
+    ir, partition_info, stats = lower_ir_graph(ir, config_options)
+
+    # Reserve shuffle IDs for the entire pipeline execution
+    with ReserveOpIDs(ir) as shuffle_id_map:
+        # Build and execute the streaming pipeline.
+        # This must be done on all worker processes
+        # for cluster == "distributed".
+        if (
+            config_options.executor.cluster == "distributed"
+        ):  # pragma: no cover; block depends on executor type and Distributed cluster
+            # Distributed execution: Use client.run
+
+            # NOTE: Distributed execution requires Dask for now
+            from cudf_polars.experimental.rapidsmpf.dask import evaluate_pipeline_dask
+
+            result, metadata_collector = evaluate_pipeline_dask(
+                evaluate_pipeline,
+                ir,
+                partition_info,
+                config_options,
+                stats,
+                shuffle_id_map,
+                collect_metadata=collect_metadata,
+            )
+        else:
+            # Single-process execution: Run locally
+            result, metadata_collector = evaluate_pipeline(
+                ir,
+                partition_info,
+                config_options,
+                stats,
+                shuffle_id_map,
+                collect_metadata=collect_metadata,
+            )
+
+    return result, metadata_collector
+
+
+def evaluate_pipeline(
+    ir: IR,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    config_options: ConfigOptions,
+    stats: StatsCollector,
+    collective_id_map: dict[IR, int],
+    rmpf_context: Context | None = None,
+    *,
+    collect_metadata: bool = False,
+) -> tuple[pl.DataFrame, list[Metadata] | None]:
+    """
+    Build and evaluate a RapidsMPF streaming pipeline.
+
+    Parameters
+    ----------
+    ir
+        The IR node.
+    partition_info
+        The partition information.
+    config_options
+        The configuration options.
+    stats
+        The statistics collector.
+    collective_id_map
+        The mapping of IR nodes to collective IDs.
+    rmpf_context
+        The RapidsMPF context.
+    collect_metadata
+        Whether to collect runtime metadata.
+
+    Returns
+    -------
+    The output DataFrame and metadata collector.
+    """
+    assert config_options.executor.name == "streaming", "Executor must be streaming"
+    assert config_options.executor.runtime == "rapidsmpf", "Runtime must be rapidsmpf"
+
+    _original_mr: Any = None
+    stream_pool: CudaStreamPool | bool = False
+    if rmpf_context is not None:
+        # Using "distributed" mode.
+        # Always use the RapidsMPF stream pool for now.
+        br = rmpf_context.br()
+        stream_pool = True
+    else:
+        # Using "single" mode.
+        # Create a new local RapidsMPF context.
+        _original_mr = rmm.mr.get_current_device_resource()
+        mr = RmmResourceAdaptor(_original_mr)
+        rmm.mr.set_current_device_resource(mr)
+        memory_available: MutableMapping[MemoryType, LimitAvailableMemory] | None = None
+        single_spill_device = config_options.executor.client_device_threshold
+        if single_spill_device > 0.0 and single_spill_device < 1.0:
+            total_memory = rmm.mr.available_device_memory()[1]
+            memory_available = {
+                MemoryType.DEVICE: LimitAvailableMemory(
+                    mr, limit=int(total_memory * single_spill_device)
+                )
+            }
+
+        options = Options(
+            {
+                # By default, set the number of streaming threads to the max
+                # number of IO threads. The user may override this with an
+                # environment variable (i.e. RAPIDSMPF_NUM_STREAMING_THREADS)
+                "num_streaming_threads": str(
+                    max(config_options.executor.max_io_threads, 1)
+                )
+            }
+            | get_environment_variables()
+        )
+        if isinstance(config_options.cuda_stream_policy, CUDAStreamPoolConfig):
+            stream_pool = config_options.cuda_stream_policy.build()
+        local_comm = new_communicator(options)
+        br = BufferResource(
+            mr, memory_available=memory_available, stream_pool=stream_pool
+        )
+        rmpf_context = Context(local_comm, br, options)
+
+    # Create the IR execution context
+    if stream_pool:
+        ir_context = IRExecutionContext(
+            get_cuda_stream=rmpf_context.get_stream_from_pool
+        )
+    else:
+        ir_context = IRExecutionContext.from_config_options(config_options)
+
+    # Generate network nodes
+    assert rmpf_context is not None, "RapidsMPF context must be defined."
+    metadata_collector: list[Metadata] | None = [] if collect_metadata else None
+    nodes, output = generate_network(
+        rmpf_context,
+        ir,
+        partition_info,
+        config_options,
+        stats,
+        ir_context=ir_context,
+        collective_id_map=collective_id_map,
+        metadata_collector=metadata_collector,
+    )
+
+    # Run the network
+    executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="cpse")
+    run_streaming_pipeline(nodes=nodes, py_executor=executor)
+
+    # Extract/return the concatenated result.
+    # Keep chunks alive until after concatenation to prevent
+    # use-after-free with stream-ordered allocations
+    messages = output.release()
+    chunks = [
+        TableChunk.from_message(msg).make_available_and_spill(
+            br, allow_overbooking=True
+        )
+        for msg in messages
+    ]
+    dfs: list[DataFrame] = []
+    if chunks:
+        dfs = [
+            DataFrame.from_table(
+                chunk.table_view(),
+                list(ir.schema.keys()),
+                list(ir.schema.values()),
+                chunk.stream,
+            )
+            for chunk in chunks
+        ]
+        df = _concat(*dfs, context=ir_context)
+    else:
+        # No chunks received - create an empty DataFrame with correct schema
+        stream = ir_context.get_cuda_stream()
+        chunk = empty_table_chunk(ir, rmpf_context, stream)
+        df = DataFrame.from_table(
+            chunk.table_view(),
+            list(ir.schema.keys()),
+            list(ir.schema.values()),
+            stream,
+        )
+
+    # We need to materialize the polars dataframe before we drop the rapidsmpf
+    # context, which keeps the CUDA streams alive.
+    stream = df.stream
+    result = df.to_polars()
+    stream.synchronize()
+
+    # Now we need to drop *all* GPU data. This ensures that no cudaFreeAsync runs
+    # before the Context, which ultimately contains the rmm MR, goes out of scope.
+    del nodes, output, messages, chunks, dfs, df
+
+    # Restore the initial RMM memory resource
+    if _original_mr is not None:
+        rmm.mr.set_current_device_resource(_original_mr)
+
+    return result, metadata_collector
+
+
+def lower_ir_graph(
+    ir: IR,
+    config_options: ConfigOptions,
+) -> tuple[IR, MutableMapping[IR, PartitionInfo], StatsCollector]:
+    """
+    Rewrite an IR graph and extract partitioning information.
+
+    Parameters
+    ----------
+    ir
+        Root of the graph to rewrite.
+    config_options
+        GPUEngine configuration options.
+    stats
+        The statistics collector.
+
+    Returns
+    -------
+    new_ir, partition_info, stats
+        The rewritten graph, a mapping from unique nodes
+        in the new graph to associated partitioning information,
+        and the statistics collector.
+
+    Notes
+    -----
+    This function is nearly identical to the `lower_ir_graph` function
+    in the `parallel` module, but with some differences:
+    - A distinct `lower_ir_node` function is used.
+    - A `Repartition` node is added to ensure a single chunk is produced.
+    - Statistics are returned.
+
+    See Also
+    --------
+    lower_ir_node
+    """
+    state: LowerState = {
+        "config_options": config_options,
+        "stats": collect_statistics(ir, config_options),
+    }
+    mapper: LowerIRTransformer = CachingVisitor(lower_ir_node, state=state)
+    return *mapper(ir), state["stats"]
+
+
+def determine_fanout_nodes(
+    ir: IR,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    ir_dep_count: defaultdict[IR, int],
+) -> dict[IR, FanoutInfo]:
+    """
+    Determine which IR nodes need fanout and what type.
+
+    Parameters
+    ----------
+    ir
+        The root IR node.
+    partition_info
+        Partition information for each IR node.
+    ir_dep_count
+        The number of IR dependencies for each IR node.
+
+    Returns
+    -------
+    Dictionary mapping IR nodes to FanoutInfo tuples where:
+    - num_consumers: number of consumers
+    - unbounded: whether the node needs unbounded fanout
+    Only includes nodes that need fanout (i.e., have multiple consumers).
+    """
+    # Determine which nodes need unbounded fanout
+    unbounded: set[IR] = set()
+
+    def _mark_children_unbounded(node: IR) -> None:
+        for child in node.children:
+            unbounded.add(child)
+
+    # Traverse the graph and identify nodes that need unbounded fanout
+    for node in traversal([ir]):
+        if node in unbounded:
+            _mark_children_unbounded(node)
+        elif isinstance(node, Union):
+            # Union processes children sequentially, so all children
+            # with multiple consumers need unbounded fanout
+            _mark_children_unbounded(node)
+        elif isinstance(node, Join):
+            # This may be a broadcast join
+            _mark_children_unbounded(node)
+        elif len(node.children) > 1:
+            # Check if this node is doing any broadcasting.
+            # When we move to dynamic partitioning, we will need a
+            # new way to indicate that a node is broadcasting 1+ children.
+            counts = [partition_info[c].count for c in node.children]
+            has_broadcast = any(c == 1 for c in counts) and not all(
+                c == 1 for c in counts
+            )
+            if has_broadcast:
+                # Broadcasting operation - children need unbounded fanout
+                _mark_children_unbounded(node)
+
+    # Build result dictionary: only include nodes with multiple consumers
+    fanout_nodes: dict[IR, FanoutInfo] = {}
+    for node, count in ir_dep_count.items():
+        if count > 1:
+            fanout_nodes[node] = FanoutInfo(
+                num_consumers=count,
+                unbounded=node in unbounded,
+            )
+
+    return fanout_nodes
+
+
+def generate_network(
+    context: Context,
+    ir: IR,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    config_options: ConfigOptions,
+    stats: StatsCollector,
+    *,
+    ir_context: IRExecutionContext,
+    collective_id_map: dict[IR, int],
+    metadata_collector: list[Metadata] | None,
+) -> tuple[list[Any], DeferredMessages]:
+    """
+    Translate the IR graph to a RapidsMPF streaming network.
+
+    Parameters
+    ----------
+    context
+        The rapidsmpf context.
+    ir
+        The IR node.
+    partition_info
+        The partition information.
+    config_options
+        The configuration options.
+    stats
+        Statistics collector.
+    ir_context
+        The execution context for the IR node.
+    collective_id_map
+        The mapping of IR nodes to collective IDs.
+    metadata_collector
+        The list to collect the final metadata.
+        This list will be mutated when the network is executed.
+        If None, metadata will not be collected.
+
+    Returns
+    -------
+    The network nodes and output hook.
+    """
+    # Count the number of IO nodes and the number of IR dependencies
+    num_io_nodes: int = 0
+    ir_dep_count: defaultdict[IR, int] = defaultdict(int)
+    for node in traversal([ir]):
+        if isinstance(node, (DataFrameScan, Scan)):
+            num_io_nodes += 1
+        for child in node.children:
+            ir_dep_count[child] += 1
+
+    # Determine which nodes need fanout
+    fanout_nodes = determine_fanout_nodes(ir, partition_info, ir_dep_count)
+
+    # Get max_io_threads from config (default: 2)
+    assert config_options.executor.name == "streaming", "Executor must be streaming"
+    max_io_threads_global = config_options.executor.max_io_threads
+    max_io_threads_local = max(1, max_io_threads_global // max(1, num_io_nodes))
+
+    # Generate the network
+    state: GenState = {
+        "context": context,
+        "config_options": config_options,
+        "partition_info": partition_info,
+        "fanout_nodes": fanout_nodes,
+        "ir_context": ir_context,
+        "max_io_threads": max_io_threads_local,
+        "stats": stats,
+        "collective_id_map": collective_id_map,
+    }
+    mapper: SubNetGenerator = CachingVisitor(
+        generate_ir_sub_network_wrapper, state=state
+    )
+    nodes_dict, channels = mapper(ir)
+    ch_out = channels[ir].reserve_output_slot()
+
+    # Add node to drain metadata channel before pull_from_channel
+    # (since pull_from_channel doesn't accept a ChannelPair)
+    ch_final_data: Channel[TableChunk] = context.create_channel()
+    drain_node = metadata_drain_node(
+        context,
+        ir,
+        ir_context,
+        ch_out,
+        ch_final_data,
+        metadata_collector,
+    )
+
+    # Add final node to pull from the output data channel
+    output_node, output = pull_from_channel(context, ch_in=ch_final_data)
+
+    # Flatten the nodes dictionary into a list for run_streaming_pipeline
+    nodes: list[Any] = [node for node_list in nodes_dict.values() for node in node_list]
+    nodes.extend([drain_node, output_node])

+    # Return network and output hook
+    return nodes, output
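A subtle point in evaluate_pipeline above is how the RapidsMPF Options are assembled: the computed "num_streaming_threads" default is merged with get_environment_variables(), and because the environment mapping is the right-hand operand of the dict merge, a user-supplied RAPIDSMPF_NUM_STREAMING_THREADS overrides the default. A minimal standalone sketch (the literal 2 stands in for config_options.executor.max_io_threads):

from rapidsmpf.config import Options, get_environment_variables

max_io_threads = 2  # stand-in for config_options.executor.max_io_threads

# The right-hand operand of `|` wins, so a value picked up from the
# environment (e.g. RAPIDSMPF_NUM_STREAMING_THREADS) overrides the default.
options = Options(
    {"num_streaming_threads": str(max(max_io_threads, 1))}
    | get_environment_variables()
)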
--- /dev/null
+++ b/cudf_polars/experimental/rapidsmpf/dask.py
@@ -0,0 +1,172 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Dask-based execution with the streaming RapidsMPF runtime."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Protocol
+
+from distributed import get_client
+from rapidsmpf.config import Options, get_environment_variables
+from rapidsmpf.integrations.dask import get_worker_context
+from rapidsmpf.streaming.core.context import Context
+
+import polars as pl
+
+from cudf_polars.experimental.dask_registers import DaskRegisterManager
+
+if TYPE_CHECKING:
+    from collections.abc import MutableMapping
+
+    from distributed import Client
+
+    from cudf_polars.dsl.ir import IR
+    from cudf_polars.experimental.base import PartitionInfo, StatsCollector
+    from cudf_polars.experimental.parallel import ConfigOptions
+    from cudf_polars.experimental.rapidsmpf.utils import Metadata
+
+
+class EvaluatePipelineCallback(Protocol):
+    """Protocol for the evaluate_pipeline callback."""
+
+    def __call__(
+        self,
+        ir: IR,
+        partition_info: MutableMapping[IR, PartitionInfo],
+        config_options: ConfigOptions,
+        stats: StatsCollector,
+        collective_id_map: dict[IR, int],
+        rmpf_context: Context | None = None,
+        *,
+        collect_metadata: bool = False,
+    ) -> tuple[pl.DataFrame, list[Metadata] | None]:
+        """Evaluate a pipeline and return the result DataFrame and metadata."""
+        ...
+
+
+def get_dask_client() -> Client:
+    """Get a distributed Dask client."""
+    client = get_client()
+    DaskRegisterManager.register_once()
+    DaskRegisterManager.run_on_cluster(client)
+    return client
+
+
+def evaluate_pipeline_dask(
+    callback: EvaluatePipelineCallback,
+    ir: IR,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    config_options: ConfigOptions,
+    stats: StatsCollector,
+    shuffle_id_map: dict[IR, int],
+    *,
+    collect_metadata: bool = False,
+) -> tuple[pl.DataFrame, list[Metadata] | None]:
+    """
+    Evaluate a RapidsMPF streaming pipeline on a Dask cluster.
+
+    Parameters
+    ----------
+    callback
+        The callback function to evaluate the pipeline.
+    ir
+        The IR node.
+    partition_info
+        The partition information.
+    config_options
+        The configuration options.
+    stats
+        The statistics collector.
+    shuffle_id_map
+        Mapping from Shuffle/Repartition/Join IR nodes to reserved shuffle IDs.
+    collect_metadata
+        Whether to collect metadata.
+
+    Returns
+    -------
+    The output DataFrame and metadata collector.
+    """
+    client = get_dask_client()
+    result = client.run(
+        _evaluate_pipeline_dask,
+        callback,
+        ir,
+        partition_info,
+        config_options,
+        stats,
+        shuffle_id_map,
+        collect_metadata=collect_metadata,
+    )
+    dfs: list[pl.DataFrame] = []
+    metadata_collector: list[Metadata] = []
+    for df, md in result.values():
+        dfs.append(df)
+        if md is not None:
+            metadata_collector.extend(md)
+
+    return pl.concat(dfs), metadata_collector or None
+
+
+def _evaluate_pipeline_dask(
+    callback: EvaluatePipelineCallback,
+    ir: IR,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    config_options: ConfigOptions,
+    stats: StatsCollector,
+    shuffle_id_map: dict[IR, int],
+    dask_worker: Any = None,
+    *,
+    collect_metadata: bool = False,
+) -> tuple[pl.DataFrame, list[Metadata] | None]:
+    """
+    Build and evaluate a RapidsMPF streaming pipeline.
+
+    Parameters
+    ----------
+    callback
+        The callback function to evaluate the pipeline.
+    ir
+        The IR node.
+    partition_info
+        The partition information.
+    config_options
+        The configuration options.
+    stats
+        The statistics collector.
+    shuffle_id_map
+        Mapping from Shuffle/Repartition/Join IR nodes to reserved shuffle IDs.
+    dask_worker
+        Dask worker reference.
+        This kwarg is automatically populated by Dask
+        when evaluate_pipeline is called with `client.run`.
+    collect_metadata
+        Whether to collect metadata.
+
+    Returns
+    -------
+    The output DataFrame and metadata collector.
+    """
+    assert dask_worker is not None, "Dask worker must be provided"
+    assert config_options.executor.name == "streaming", "Executor must be streaming"
+
+    # NOTE: The Dask-CUDA cluster must be bootstrapped
+    # ahead of time using bootstrap_dask_cluster
+    # (rapidsmpf.integrations.dask.bootstrap_dask_cluster).
+    # TODO: Automatically bootstrap the cluster if necessary.
+    options = Options(
+        {"num_streaming_threads": str(max(config_options.executor.max_io_threads, 1))}
+        | get_environment_variables()
+    )
+    dask_context = get_worker_context(dask_worker)
+    rmpf_context = Context(dask_context.comm, dask_context.br, options)
+
+    # IDs are already reserved by the caller, just pass them through
+    return callback(
+        ir,
+        partition_info,
+        config_options,
+        stats,
+        shuffle_id_map,
+        rmpf_context,
+        collect_metadata=collect_metadata,
+    )
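The NOTE in _evaluate_pipeline_dask assumes the Dask-CUDA cluster was bootstrapped with RapidsMPF ahead of time. A hedged setup sketch follows; only the bootstrap_dask_cluster import path is confirmed by the comment above, and its exact arguments (and the LocalCUDACluster configuration) are assumptions to check against the RapidsMPF documentation.

from dask_cuda import LocalCUDACluster
from distributed import Client
from rapidsmpf.integrations.dask import bootstrap_dask_cluster

# Assumed setup: one Dask-CUDA worker per GPU, then bootstrap RapidsMPF on the
# cluster before collecting with the distributed streaming runtime.
cluster = LocalCUDACluster()
client = Client(cluster)
bootstrap_dask_cluster(client)  # exact signature/options not confirmed by this diff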