cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/ir.py CHANGED
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  """
4
4
  DSL nodes for the LogicalPlan of polars.
@@ -13,46 +13,61 @@ can be considered as functions:
13
13
 
14
14
  from __future__ import annotations
15
15
 
16
+ import contextlib
16
17
  import itertools
17
18
  import json
18
19
  import random
19
20
  import time
21
+ from dataclasses import dataclass
20
22
  from functools import cache
21
23
  from pathlib import Path
22
- from typing import TYPE_CHECKING, Any, ClassVar
24
+ from typing import TYPE_CHECKING, Any, ClassVar, overload
23
25
 
24
26
  from typing_extensions import assert_never
25
27
 
26
28
  import polars as pl
27
29
 
28
30
  import pylibcudf as plc
31
+ from pylibcudf import expressions as plc_expr
29
32
 
30
33
  import cudf_polars.dsl.expr as expr
31
34
  from cudf_polars.containers import Column, DataFrame, DataType
35
+ from cudf_polars.containers.dataframe import NamedColumn
32
36
  from cudf_polars.dsl.expressions import rolling, unary
33
37
  from cudf_polars.dsl.expressions.base import ExecutionContext
34
38
  from cudf_polars.dsl.nodebase import Node
35
39
  from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
36
- from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
40
+ from cudf_polars.dsl.tracing import log_do_evaluate, nvtx_annotate_cudf_polars
37
41
  from cudf_polars.dsl.utils.reshape import broadcast
38
- from cudf_polars.dsl.utils.windows import range_window_bounds
42
+ from cudf_polars.dsl.utils.windows import (
43
+ offsets_to_windows,
44
+ range_window_bounds,
45
+ )
39
46
  from cudf_polars.utils import dtypes
40
- from cudf_polars.utils.versions import POLARS_VERSION_LT_131
47
+ from cudf_polars.utils.config import CUDAStreamPolicy
48
+ from cudf_polars.utils.cuda_stream import (
49
+ get_cuda_stream,
50
+ get_joined_cuda_stream,
51
+ get_new_cuda_stream,
52
+ join_cuda_streams,
53
+ )
54
+ from cudf_polars.utils.versions import POLARS_VERSION_LT_131, POLARS_VERSION_LT_134
41
55
 
42
56
  if TYPE_CHECKING:
43
- from collections.abc import Callable, Hashable, Iterable, Sequence
57
+ from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
44
58
  from typing import Literal
45
59
 
46
60
  from typing_extensions import Self
47
61
 
48
- from polars.polars import _expr_nodes as pl_expr
62
+ from polars import polars # type: ignore[attr-defined]
63
+
64
+ from rmm.pylibrmm.stream import Stream
49
65
 
50
66
  from cudf_polars.containers.dataframe import NamedColumn
51
67
  from cudf_polars.typing import CSECache, ClosedInterval, Schema, Slice as Zlice
52
- from cudf_polars.utils.config import ParquetOptions
68
+ from cudf_polars.utils.config import ConfigOptions, ParquetOptions
53
69
  from cudf_polars.utils.timer import Timer
54
70
 
55
-
56
71
  __all__ = [
57
72
  "IR",
58
73
  "Cache",
@@ -65,6 +80,7 @@ __all__ = [
65
80
  "GroupBy",
66
81
  "HConcat",
67
82
  "HStack",
83
+ "IRExecutionContext",
68
84
  "Join",
69
85
  "MapFunction",
70
86
  "MergeSorted",
@@ -81,6 +97,97 @@ __all__ = [
81
97
  ]
82
98
 
83
99
 
100
+ @dataclass(frozen=True)
101
+ class IRExecutionContext:
102
+ """
103
+ Runtime context for IR node execution.
104
+
105
+ This dataclass holds runtime information and configuration needed
106
+ during the evaluation of IR nodes.
107
+
108
+ Parameters
109
+ ----------
110
+ get_cuda_stream
111
+ A zero-argument callable that returns a CUDA stream.
112
+ """
113
+
114
+ get_cuda_stream: Callable[[], Stream]
115
+
116
+ @classmethod
117
+ def from_config_options(cls, config_options: ConfigOptions) -> IRExecutionContext:
118
+ """Create an IRExecutionContext from ConfigOptions."""
119
+ match config_options.cuda_stream_policy:
120
+ case CUDAStreamPolicy.DEFAULT:
121
+ return cls(get_cuda_stream=get_cuda_stream)
122
+ case CUDAStreamPolicy.NEW:
123
+ return cls(get_cuda_stream=get_new_cuda_stream)
124
+ case _: # pragma: no cover
125
+ raise ValueError(
126
+ f"Invalid CUDA stream policy: {config_options.cuda_stream_policy}"
127
+ )
128
+
129
+ @contextlib.contextmanager
130
+ def stream_ordered_after(self, *dfs: DataFrame) -> Generator[Stream, None, None]:
131
+ """
132
+ Get a joined CUDA stream with safe stream ordering for deallocation of inputs.
133
+
134
+ Parameters
135
+ ----------
136
+ dfs
137
+ The dataframes being provided to stream-ordered operations.
138
+
139
+ Yields
140
+ ------
141
+ A CUDA stream that is downstream of the given dataframes.
142
+
143
+ Notes
144
+ -----
145
+ This context manager provides two useful guarantees when working with
146
+ objects holding references to stream-ordered objects:
147
+
148
+ 1. The stream yield upon entering the context manager is *downstream* of
149
+ all the input dataframes. This ensures that you can safely perform
150
+ stream-ordered operations on any input using the yielded stream.
151
+ 2. The stream-ordered CUDA deallocation of the inputs happens *after* the
152
+ context manager exits. This ensures that all stream-ordered operations
153
+ submitted inside the context manager can complete before the memory
154
+ referenced by the inputs is deallocated.
155
+
156
+ Note that this does (deliberately) disconnect the dropping of the Python
157
+ object (by its refcount dropping to 0) from the actual stream-ordered
158
+ deallocation of the CUDA memory. This is precisely what we need to ensure
159
+ that the inputs are valid long enough for the stream-ordered operations to
160
+ complete.
161
+ """
162
+ result_stream = get_joined_cuda_stream(
163
+ self.get_cuda_stream, upstreams=[df.stream for df in dfs]
164
+ )
165
+
166
+ yield result_stream
167
+
168
+ # ensure that the inputs are downstream of result_stream (so that deallocation happens after the result is ready)
169
+ join_cuda_streams(
170
+ downstreams=[df.stream for df in dfs], upstreams=[result_stream]
171
+ )
172
+
173
+
174
+ _BINOPS = {
175
+ plc.binaryop.BinaryOperator.EQUAL,
176
+ plc.binaryop.BinaryOperator.NOT_EQUAL,
177
+ plc.binaryop.BinaryOperator.LESS,
178
+ plc.binaryop.BinaryOperator.LESS_EQUAL,
179
+ plc.binaryop.BinaryOperator.GREATER,
180
+ plc.binaryop.BinaryOperator.GREATER_EQUAL,
181
+ # TODO: Handle other binary operations as needed
182
+ }
183
+
184
+
185
+ _DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
186
+
187
+
188
+ _FLOAT_TYPES = {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}
189
+
190
+
84
191
  class IR(Node["IR"]):
85
192
  """Abstract plan node, representing an unevaluated dataframe."""
86
193
 
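The stream_ordered_after context manager added in the hunk above is easier to see with a toy model. The sketch below is a minimal, CPU-only illustration of its two guarantees; FakeStream, FakeDataFrame and join are illustrative stand-ins (not cudf-polars or rmm API) that record ordering edges in place of real CUDA stream/event synchronization.

```python
# CPU-only sketch of the stream_ordered_after() contract, with stand-in
# objects instead of real rmm CUDA streams.
from __future__ import annotations

import contextlib
from collections.abc import Generator
from dataclasses import dataclass, field


@dataclass
class FakeStream:
    name: str
    waits_on: list[str] = field(default_factory=list)  # recorded ordering edges


@dataclass
class FakeDataFrame:
    name: str
    stream: FakeStream


def join(downstreams: list[FakeStream], upstreams: list[FakeStream]) -> None:
    # Model of join_cuda_streams: every downstream waits for every upstream.
    for down in downstreams:
        down.waits_on.extend(up.name for up in upstreams)


@contextlib.contextmanager
def stream_ordered_after(*dfs: FakeDataFrame) -> Generator[FakeStream, None, None]:
    work = FakeStream("work")
    # Guarantee 1: the yielded stream runs after the producers of all inputs.
    join(downstreams=[work], upstreams=[df.stream for df in dfs])
    yield work
    # Guarantee 2: the inputs' streams now wait on the work stream, so their
    # stream-ordered deallocation cannot overtake the work enqueued inside.
    join(downstreams=[df.stream for df in dfs], upstreams=[work])


left = FakeDataFrame("left", FakeStream("s_left"))
right = FakeDataFrame("right", FakeStream("s_right"))
with stream_ordered_after(left, right) as stream:
    print("compute on", stream.name, "which waits on", stream.waits_on)
print("input streams now wait on:", left.stream.waits_on, right.stream.waits_on)
```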
@@ -134,7 +241,9 @@ class IR(Node["IR"]):
134
241
  translation phase should fail earlier.
135
242
  """
136
243
 
137
- def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
244
+ def evaluate(
245
+ self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
246
+ ) -> DataFrame:
138
247
  """
139
248
  Evaluate the node (recursively) and return a dataframe.
140
249
 
@@ -146,6 +255,8 @@ class IR(Node["IR"]):
146
255
  timer
147
256
  If not None, a Timer object to record timings for the
148
257
  evaluation of the node.
258
+ context
259
+ The execution context for the node.
149
260
 
150
261
  Notes
151
262
  -----
@@ -164,16 +275,19 @@ class IR(Node["IR"]):
164
275
  If evaluation fails. Ideally this should not occur, since the
165
276
  translation phase should fail earlier.
166
277
  """
167
- children = [child.evaluate(cache=cache, timer=timer) for child in self.children]
278
+ children = [
279
+ child.evaluate(cache=cache, timer=timer, context=context)
280
+ for child in self.children
281
+ ]
168
282
  if timer is not None:
169
283
  start = time.monotonic_ns()
170
- result = self.do_evaluate(*self._non_child_args, *children)
284
+ result = self.do_evaluate(*self._non_child_args, *children, context=context)
171
285
  end = time.monotonic_ns()
172
286
  # TODO: Set better names on each class object.
173
287
  timer.store(start, end, type(self).__name__)
174
288
  return result
175
289
  else:
176
- return self.do_evaluate(*self._non_child_args, *children)
290
+ return self.do_evaluate(*self._non_child_args, *children, context=context)
177
291
 
178
292
 
179
293
  class ErrorNode(IR):
@@ -212,29 +326,97 @@ class PythonScan(IR):
212
326
  raise NotImplementedError("PythonScan not implemented")
213
327
 
214
328
 
329
+ _DECIMAL_IDS = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
330
+
331
+ _COMPARISON_BINOPS = {
332
+ plc.binaryop.BinaryOperator.EQUAL,
333
+ plc.binaryop.BinaryOperator.NOT_EQUAL,
334
+ plc.binaryop.BinaryOperator.LESS,
335
+ plc.binaryop.BinaryOperator.LESS_EQUAL,
336
+ plc.binaryop.BinaryOperator.GREATER,
337
+ plc.binaryop.BinaryOperator.GREATER_EQUAL,
338
+ }
339
+
340
+
341
+ def _parquet_physical_types(
342
+ schema: Schema, paths: list[str], columns: list[str] | None, stream: Stream
343
+ ) -> dict[str, plc.DataType]:
344
+ # TODO: Read the physical types as cudf::data_type's using
345
+ # read_parquet_metadata or another parquet API
346
+ options = plc.io.parquet.ParquetReaderOptions.builder(
347
+ plc.io.SourceInfo(paths)
348
+ ).build()
349
+ if columns is not None:
350
+ options.set_columns(columns)
351
+ options.set_num_rows(0)
352
+ df = plc.io.parquet.read_parquet(options, stream=stream)
353
+ return dict(zip(schema.keys(), [c.type() for c in df.tbl.columns()], strict=True))
354
+
355
+
356
+ def _cast_literal_to_decimal(
357
+ side: expr.Expr, lit: expr.Literal, phys_type_map: dict[str, plc.DataType]
358
+ ) -> expr.Expr:
359
+ if isinstance(side, expr.Cast):
360
+ col = side.children[0]
361
+ assert isinstance(col, expr.Col)
362
+ name = col.name
363
+ else:
364
+ assert isinstance(side, expr.Col)
365
+ name = side.name
366
+ if (type_ := phys_type_map[name]).id() in _DECIMAL_IDS:
367
+ scale = abs(type_.scale())
368
+ return expr.Cast(
369
+ side.dtype,
370
+ True, # noqa: FBT003
371
+ expr.Cast(DataType(pl.Decimal(38, scale)), True, lit), # noqa: FBT003
372
+ )
373
+ return lit
374
+
375
+
376
+ def _cast_literals_to_physical_types(
377
+ node: expr.Expr, phys_type_map: dict[str, plc.DataType]
378
+ ) -> expr.Expr:
379
+ if isinstance(node, expr.BinOp):
380
+ left, right = node.children
381
+ left = _cast_literals_to_physical_types(left, phys_type_map)
382
+ right = _cast_literals_to_physical_types(right, phys_type_map)
383
+ if node.op in _COMPARISON_BINOPS:
384
+ if isinstance(left, (expr.Col, expr.Cast)) and isinstance(
385
+ right, expr.Literal
386
+ ):
387
+ right = _cast_literal_to_decimal(left, right, phys_type_map)
388
+ elif isinstance(right, (expr.Col, expr.Cast)) and isinstance(
389
+ left, expr.Literal
390
+ ):
391
+ left = _cast_literal_to_decimal(right, left, phys_type_map)
392
+
393
+ return node.reconstruct([left, right])
394
+ return node
395
+
396
+
215
397
  def _align_parquet_schema(df: DataFrame, schema: Schema) -> DataFrame:
216
398
  # TODO: Alternatively set the schema of the parquet reader to decimal128
217
- plc_decimals_ids = {
218
- plc.TypeId.DECIMAL32,
219
- plc.TypeId.DECIMAL64,
220
- plc.TypeId.DECIMAL128,
221
- }
222
399
  cast_list = []
223
400
 
224
401
  for name, col in df.column_map.items():
225
402
  src = col.obj.type()
226
- dst = schema[name].plc
403
+ dst = schema[name].plc_type
404
+
227
405
  if (
228
- src.id() in plc_decimals_ids
229
- and dst.id() in plc_decimals_ids
230
- and ((src.id() != dst.id()) or (src.scale != dst.scale))
406
+ plc.traits.is_fixed_point(src)
407
+ and plc.traits.is_fixed_point(dst)
408
+ and ((src.id() != dst.id()) or (src.scale() != dst.scale()))
231
409
  ):
232
410
  cast_list.append(
233
- Column(plc.unary.cast(col.obj, dst), name=name, dtype=schema[name])
411
+ Column(
412
+ plc.unary.cast(col.obj, dst, stream=df.stream),
413
+ name=name,
414
+ dtype=schema[name],
415
+ )
234
416
  )
235
417
 
236
418
  if cast_list:
237
- df = df.with_columns(cast_list)
419
+ df = df.with_columns(cast_list, stream=df.stream)
238
420
 
239
421
  return df
240
422
 
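The _cast_literals_to_physical_types helpers in the hunk above rewrite comparison literals so that a predicate pushed into the parquet reader is evaluated against a column's physical decimal type. A rough, self-contained sketch of the rescaling idea using only the standard-library decimal module (rescale_literal is a hypothetical helper, not cudf-polars API):

```python
# Standalone illustration of rescaling a predicate literal to a DECIMAL
# column's physical scale before comparison.
from decimal import Decimal


def rescale_literal(value: float | int | str, physical_scale: int) -> Decimal:
    """Quantize a literal to a column's physical scale, i.e. abs(type.scale())."""
    quantum = Decimal(1).scaleb(-physical_scale)  # scale 2 -> Decimal("0.01")
    return Decimal(str(value)).quantize(quantum)


# A float literal compared against a column stored as DECIMAL(38, 2):
print(rescale_literal(10.5, physical_scale=2))  # 10.50
print(rescale_literal("3", physical_scale=4))   # 3.0000
```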
@@ -460,13 +642,24 @@ class Scan(IR):
460
642
  Each path is repeated according to the number of rows read from it.
461
643
  """
462
644
  (filepaths,) = plc.filling.repeat(
463
- plc.Table([plc.Column.from_arrow(pl.Series(values=map(str, paths)))]),
645
+ plc.Table(
646
+ [
647
+ plc.Column.from_arrow(
648
+ pl.Series(values=map(str, paths)),
649
+ stream=df.stream,
650
+ )
651
+ ]
652
+ ),
464
653
  plc.Column.from_arrow(
465
- pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32())
654
+ pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32()),
655
+ stream=df.stream,
466
656
  ),
657
+ stream=df.stream,
467
658
  ).columns()
468
659
  dtype = DataType(pl.String())
469
- return df.with_columns([Column(filepaths, name=name, dtype=dtype)])
660
+ return df.with_columns(
661
+ [Column(filepaths, name=name, dtype=dtype)], stream=df.stream
662
+ )
470
663
 
471
664
  def fast_count(self) -> int: # pragma: no cover
472
665
  """Get the number of rows in a Parquet Scan."""
@@ -479,6 +672,7 @@ class Scan(IR):
479
672
  return max(total_rows, 0)
480
673
 
481
674
  @classmethod
675
+ @log_do_evaluate
482
676
  @nvtx_annotate_cudf_polars(message="Scan")
483
677
  def do_evaluate(
484
678
  cls,
@@ -493,8 +687,11 @@ class Scan(IR):
493
687
  include_file_paths: str | None,
494
688
  predicate: expr.NamedExpr | None,
495
689
  parquet_options: ParquetOptions,
690
+ *,
691
+ context: IRExecutionContext,
496
692
  ) -> DataFrame:
497
693
  """Evaluate and return a dataframe."""
694
+ stream = context.get_cuda_stream()
498
695
  if typ == "csv":
499
696
 
500
697
  def read_csv_header(
@@ -551,6 +748,7 @@ class Scan(IR):
551
748
  plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo([path]))
552
749
  .nrows(n_rows)
553
750
  .skiprows(skiprows + skip_rows)
751
+ .skip_blank_lines(skip_blank_lines=False)
554
752
  .lineterminator(str(eol))
555
753
  .quotechar(str(quote))
556
754
  .decimal(decimal)
@@ -567,13 +765,15 @@ class Scan(IR):
567
765
  column_names = read_csv_header(path, str(sep))
568
766
  options.set_names(column_names)
569
767
  options.set_header(header)
570
- options.set_dtypes({name: dtype.plc for name, dtype in schema.items()})
768
+ options.set_dtypes(
769
+ {name: dtype.plc_type for name, dtype in schema.items()}
770
+ )
571
771
  if usecols is not None:
572
772
  options.set_use_cols_names([str(name) for name in usecols])
573
773
  options.set_na_values(null_values)
574
774
  if comment is not None:
575
775
  options.set_comment(comment)
576
- tbl_w_meta = plc.io.csv.read_csv(options)
776
+ tbl_w_meta = plc.io.csv.read_csv(options, stream=stream)
577
777
  pieces.append(tbl_w_meta)
578
778
  if include_file_paths is not None:
579
779
  seen_paths.append(p)
@@ -589,9 +789,10 @@ class Scan(IR):
589
789
  strict=True,
590
790
  )
591
791
  df = DataFrame.from_table(
592
- plc.concatenate.concatenate(list(tables)),
792
+ plc.concatenate.concatenate(list(tables), stream=stream),
593
793
  colnames,
594
794
  [schema[colname] for colname in colnames],
795
+ stream=stream,
595
796
  )
596
797
  if include_file_paths is not None:
597
798
  df = Scan.add_file_paths(
@@ -604,42 +805,50 @@ class Scan(IR):
604
805
  filters = None
605
806
  if predicate is not None and row_index is None:
606
807
  # Can't apply filters during read if we have a row index.
607
- filters = to_parquet_filter(predicate.value)
608
- options = plc.io.parquet.ParquetReaderOptions.builder(
808
+ filters = to_parquet_filter(
809
+ _cast_literals_to_physical_types(
810
+ predicate.value,
811
+ _parquet_physical_types(
812
+ schema, paths, with_columns or list(schema.keys()), stream
813
+ ),
814
+ ),
815
+ stream=stream,
816
+ )
817
+ parquet_reader_options = plc.io.parquet.ParquetReaderOptions.builder(
609
818
  plc.io.SourceInfo(paths)
610
819
  ).build()
611
820
  if with_columns is not None:
612
- options.set_columns(with_columns)
821
+ parquet_reader_options.set_columns(with_columns)
613
822
  if filters is not None:
614
- options.set_filter(filters)
823
+ parquet_reader_options.set_filter(filters)
615
824
  if n_rows != -1:
616
- options.set_num_rows(n_rows)
825
+ parquet_reader_options.set_num_rows(n_rows)
617
826
  if skip_rows != 0:
618
- options.set_skip_rows(skip_rows)
827
+ parquet_reader_options.set_skip_rows(skip_rows)
619
828
  if parquet_options.chunked:
620
829
  reader = plc.io.parquet.ChunkedParquetReader(
621
- options,
830
+ parquet_reader_options,
622
831
  chunk_read_limit=parquet_options.chunk_read_limit,
623
832
  pass_read_limit=parquet_options.pass_read_limit,
833
+ stream=stream,
624
834
  )
625
835
  chunk = reader.read_chunk()
626
- tbl = chunk.tbl
627
836
  # TODO: Nested column names
628
837
  names = chunk.column_names(include_children=False)
629
- concatenated_columns = tbl.columns()
838
+ concatenated_columns = chunk.tbl.columns()
630
839
  while reader.has_next():
631
- chunk = reader.read_chunk()
632
- tbl = chunk.tbl
633
- for i in range(tbl.num_columns()):
840
+ columns = reader.read_chunk().tbl.columns()
841
+ # Discard columns while concatenating to reduce memory footprint.
842
+ # Reverse order to avoid O(n^2) list popping cost.
843
+ for i in range(len(concatenated_columns) - 1, -1, -1):
634
844
  concatenated_columns[i] = plc.concatenate.concatenate(
635
- [concatenated_columns[i], tbl._columns[i]]
845
+ [concatenated_columns[i], columns.pop()], stream=stream
636
846
  )
637
- # Drop residual columns to save memory
638
- tbl._columns[i] = None
639
847
  df = DataFrame.from_table(
640
848
  plc.Table(concatenated_columns),
641
849
  names=names,
642
850
  dtypes=[schema[name] for name in names],
851
+ stream=stream,
643
852
  )
644
853
  df = _align_parquet_schema(df, schema)
645
854
  if include_file_paths is not None:
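The chunked-reader loop above folds each new chunk into the running result and discards chunk columns as it goes; iterating the result columns in reverse and consuming the chunk with pop() keeps each pop O(1) instead of the O(n) cost of popping from the front. A small host-side sketch of that pattern, with plain lists standing in for libcudf columns:

```python
# Host-side model of the "discard while concatenating" chunk fold.
def concatenate(parts: list[list[int]]) -> list[int]:
    out: list[int] = []
    for part in parts:
        out.extend(part)
    return out


def fold_chunk(concatenated: list[list[int]], chunk: list[list[int]]) -> None:
    # Consume the chunk from the back: list.pop() from the end is O(1), and
    # each popped column is released as soon as it has been concatenated.
    for i in range(len(concatenated) - 1, -1, -1):
        concatenated[i] = concatenate([concatenated[i], chunk.pop()])


concatenated = [[1, 2], [10, 20]]          # two columns from the first chunk
fold_chunk(concatenated, [[3, 4], [30]])   # next chunk, same column order
print(concatenated)                        # [[1, 2, 3, 4], [10, 20, 30]]
```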
@@ -647,13 +856,16 @@ class Scan(IR):
647
856
  include_file_paths, paths, chunk.num_rows_per_source, df
648
857
  )
649
858
  else:
650
- tbl_w_meta = plc.io.parquet.read_parquet(options)
859
+ tbl_w_meta = plc.io.parquet.read_parquet(
860
+ parquet_reader_options, stream=stream
861
+ )
651
862
  # TODO: consider nested column names?
652
863
  col_names = tbl_w_meta.column_names(include_children=False)
653
864
  df = DataFrame.from_table(
654
865
  tbl_w_meta.tbl,
655
866
  col_names,
656
867
  [schema[name] for name in col_names],
868
+ stream=stream,
657
869
  )
658
870
  df = _align_parquet_schema(df, schema)
659
871
  if include_file_paths is not None:
@@ -665,16 +877,16 @@ class Scan(IR):
665
877
  return df
666
878
  elif typ == "ndjson":
667
879
  json_schema: list[plc.io.json.NameAndType] = [
668
- (name, typ.plc, []) for name, typ in schema.items()
880
+ (name, typ.plc_type, []) for name, typ in schema.items()
669
881
  ]
670
- plc_tbl_w_meta = plc.io.json.read_json(
671
- plc.io.json._setup_json_reader_options(
672
- plc.io.SourceInfo(paths),
673
- lines=True,
674
- dtypes=json_schema,
675
- prune_columns=True,
676
- )
882
+ json_reader_options = (
883
+ plc.io.json.JsonReaderOptions.builder(plc.io.SourceInfo(paths))
884
+ .lines(val=True)
885
+ .dtypes(json_schema)
886
+ .prune_columns(val=True)
887
+ .build()
677
888
  )
889
+ plc_tbl_w_meta = plc.io.json.read_json(json_reader_options, stream=stream)
678
890
  # TODO: I don't think cudf-polars supports nested types in general right now
679
891
  # (but when it does, we should pass child column names from nested columns in)
680
892
  col_names = plc_tbl_w_meta.column_names(include_children=False)
@@ -682,6 +894,7 @@ class Scan(IR):
682
894
  plc_tbl_w_meta.tbl,
683
895
  col_names,
684
896
  [schema[name] for name in col_names],
897
+ stream=stream,
685
898
  )
686
899
  col_order = list(schema.keys())
687
900
  if row_index is not None:
@@ -695,26 +908,28 @@ class Scan(IR):
695
908
  name, offset = row_index
696
909
  offset += skip_rows
697
910
  dtype = schema[name]
698
- step = plc.Scalar.from_py(1, dtype.plc)
699
- init = plc.Scalar.from_py(offset, dtype.plc)
911
+ step = plc.Scalar.from_py(1, dtype.plc_type, stream=stream)
912
+ init = plc.Scalar.from_py(offset, dtype.plc_type, stream=stream)
700
913
  index_col = Column(
701
- plc.filling.sequence(df.num_rows, init, step),
914
+ plc.filling.sequence(df.num_rows, init, step, stream=stream),
702
915
  is_sorted=plc.types.Sorted.YES,
703
916
  order=plc.types.Order.ASCENDING,
704
917
  null_order=plc.types.NullOrder.AFTER,
705
918
  name=name,
706
919
  dtype=dtype,
707
920
  )
708
- df = DataFrame([index_col, *df.columns])
921
+ df = DataFrame([index_col, *df.columns], stream=df.stream)
709
922
  if next(iter(schema)) != name:
710
923
  df = df.select(schema)
711
924
  assert all(
712
- c.obj.type() == schema[name].plc for name, c in df.column_map.items()
925
+ c.obj.type() == schema[name].plc_type for name, c in df.column_map.items()
713
926
  )
714
927
  if predicate is None:
715
928
  return df
716
929
  else:
717
- (mask,) = broadcast(predicate.evaluate(df), target_length=df.num_rows)
930
+ (mask,) = broadcast(
931
+ predicate.evaluate(df), target_length=df.num_rows, stream=df.stream
932
+ )
718
933
  return df.filter(mask)
719
934
 
720
935
 
@@ -775,7 +990,8 @@ class Sink(IR):
775
990
  child_schema = df.schema.values()
776
991
  if kind == "Csv":
777
992
  if not all(
778
- plc.io.csv.is_supported_write_csv(dtype.plc) for dtype in child_schema
993
+ plc.io.csv.is_supported_write_csv(dtype.plc_type)
994
+ for dtype in child_schema
779
995
  ):
780
996
  # Nested types are unsupported in polars and libcudf
781
997
  raise NotImplementedError(
@@ -826,7 +1042,8 @@ class Sink(IR):
826
1042
  kind == "Json"
827
1043
  ): # pragma: no cover; options are validated on the polars side
828
1044
  if not all(
829
- plc.io.json.is_supported_write_json(dtype.plc) for dtype in child_schema
1045
+ plc.io.json.is_supported_write_json(dtype.plc_type)
1046
+ for dtype in child_schema
830
1047
  ):
831
1048
  # Nested types are unsupported in polars and libcudf
832
1049
  raise NotImplementedError(
@@ -863,7 +1080,7 @@ class Sink(IR):
863
1080
  ) -> None:
864
1081
  """Write CSV data to a sink."""
865
1082
  serialize = options["serialize_options"]
866
- options = (
1083
+ csv_writer_options = (
867
1084
  plc.io.csv.CsvWriterOptions.builder(target, df.table)
868
1085
  .include_header(options["include_header"])
869
1086
  .names(df.column_names if options["include_header"] else [])
@@ -872,7 +1089,7 @@ class Sink(IR):
872
1089
  .inter_column_delimiter(chr(serialize["separator"]))
873
1090
  .build()
874
1091
  )
875
- plc.io.csv.write_csv(options)
1092
+ plc.io.csv.write_csv(csv_writer_options, stream=df.stream)
876
1093
 
877
1094
  @classmethod
878
1095
  def _write_json(cls, target: plc.io.SinkInfo, df: DataFrame) -> None:
@@ -889,7 +1106,7 @@ class Sink(IR):
889
1106
  .utf8_escaped(val=False)
890
1107
  .build()
891
1108
  )
892
- plc.io.json.write_json(options)
1109
+ plc.io.json.write_json(options, stream=df.stream)
893
1110
 
894
1111
  @staticmethod
895
1112
  def _make_parquet_metadata(df: DataFrame) -> plc.io.types.TableInputMetadata:
@@ -899,6 +1116,20 @@ class Sink(IR):
899
1116
  metadata.column_metadata[i].set_name(name)
900
1117
  return metadata
901
1118
 
1119
+ @overload
1120
+ @staticmethod
1121
+ def _apply_parquet_writer_options(
1122
+ builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder,
1123
+ options: dict[str, Any],
1124
+ ) -> plc.io.parquet.ChunkedParquetWriterOptionsBuilder: ...
1125
+
1126
+ @overload
1127
+ @staticmethod
1128
+ def _apply_parquet_writer_options(
1129
+ builder: plc.io.parquet.ParquetWriterOptionsBuilder,
1130
+ options: dict[str, Any],
1131
+ ) -> plc.io.parquet.ParquetWriterOptionsBuilder: ...
1132
+
902
1133
  @staticmethod
903
1134
  def _apply_parquet_writer_options(
904
1135
  builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder
@@ -944,12 +1175,16 @@ class Sink(IR):
944
1175
  and parquet_options.n_output_chunks != 1
945
1176
  and df.table.num_rows() != 0
946
1177
  ):
947
- builder = plc.io.parquet.ChunkedParquetWriterOptions.builder(
1178
+ chunked_builder = plc.io.parquet.ChunkedParquetWriterOptions.builder(
948
1179
  target
949
1180
  ).metadata(metadata)
950
- builder = cls._apply_parquet_writer_options(builder, options)
951
- writer_options = builder.build()
952
- writer = plc.io.parquet.ChunkedParquetWriter.from_options(writer_options)
1181
+ chunked_builder = cls._apply_parquet_writer_options(
1182
+ chunked_builder, options
1183
+ )
1184
+ chunked_writer_options = chunked_builder.build()
1185
+ writer = plc.io.parquet.ChunkedParquetWriter.from_options(
1186
+ chunked_writer_options, stream=df.stream
1187
+ )
953
1188
 
954
1189
  # TODO: Can be based on a heuristic that estimates chunk size
955
1190
  # from the input table size and available GPU memory.
@@ -957,6 +1192,7 @@ class Sink(IR):
957
1192
  table_chunks = plc.copying.split(
958
1193
  df.table,
959
1194
  [i * df.table.num_rows() // num_chunks for i in range(1, num_chunks)],
1195
+ stream=df.stream,
960
1196
  )
961
1197
  for chunk in table_chunks:
962
1198
  writer.write(chunk)
@@ -968,9 +1204,10 @@ class Sink(IR):
968
1204
  ).metadata(metadata)
969
1205
  builder = cls._apply_parquet_writer_options(builder, options)
970
1206
  writer_options = builder.build()
971
- plc.io.parquet.write_parquet(writer_options)
1207
+ plc.io.parquet.write_parquet(writer_options, stream=df.stream)
972
1208
 
973
1209
  @classmethod
1210
+ @log_do_evaluate
974
1211
  @nvtx_annotate_cudf_polars(message="Sink")
975
1212
  def do_evaluate(
976
1213
  cls,
@@ -980,6 +1217,8 @@ class Sink(IR):
980
1217
  parquet_options: ParquetOptions,
981
1218
  options: dict[str, Any],
982
1219
  df: DataFrame,
1220
+ *,
1221
+ context: IRExecutionContext,
983
1222
  ) -> DataFrame:
984
1223
  """Write the dataframe to a file."""
985
1224
  target = plc.io.SinkInfo([path])
@@ -993,7 +1232,7 @@ class Sink(IR):
993
1232
  elif kind == "Json":
994
1233
  cls._write_json(target, df)
995
1234
 
996
- return DataFrame([])
1235
+ return DataFrame([], stream=df.stream)
997
1236
 
998
1237
 
999
1238
  class Cache(IR):
@@ -1030,16 +1269,24 @@ class Cache(IR):
1030
1269
  return False
1031
1270
 
1032
1271
  @classmethod
1272
+ @log_do_evaluate
1033
1273
  @nvtx_annotate_cudf_polars(message="Cache")
1034
1274
  def do_evaluate(
1035
- cls, key: int, refcount: int | None, df: DataFrame
1275
+ cls,
1276
+ key: int,
1277
+ refcount: int | None,
1278
+ df: DataFrame,
1279
+ *,
1280
+ context: IRExecutionContext,
1036
1281
  ) -> DataFrame: # pragma: no cover; basic evaluation never calls this
1037
1282
  """Evaluate and return a dataframe."""
1038
1283
  # Our value has already been computed for us, so let's just
1039
1284
  # return it.
1040
1285
  return df
1041
1286
 
1042
- def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
1287
+ def evaluate(
1288
+ self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
1289
+ ) -> DataFrame:
1043
1290
  """Evaluate and return a dataframe."""
1044
1291
  # We must override the recursion scheme because we don't want
1045
1292
  # to recurse if we're in the cache.
@@ -1047,7 +1294,7 @@ class Cache(IR):
1047
1294
  (result, hits) = cache[self.key]
1048
1295
  except KeyError:
1049
1296
  (value,) = self.children
1050
- result = value.evaluate(cache=cache, timer=timer)
1297
+ result = value.evaluate(cache=cache, timer=timer, context=context)
1051
1298
  cache[self.key] = (result, 0)
1052
1299
  return result
1053
1300
  else:
@@ -1093,6 +1340,42 @@ class DataFrameScan(IR):
1093
1340
  self.children = ()
1094
1341
  self._id_for_hash = random.randint(0, 2**64 - 1)
1095
1342
 
1343
+ @staticmethod
1344
+ def _reconstruct(
1345
+ schema: Schema,
1346
+ pl_df: pl.DataFrame,
1347
+ projection: Sequence[str] | None,
1348
+ id_for_hash: int,
1349
+ ) -> DataFrameScan: # pragma: no cover
1350
+ """
1351
+ Reconstruct a DataFrameScan from pickled data.
1352
+
1353
+ Parameters
1354
+ ----------
1355
+ schema: Schema
1356
+ The schema of the DataFrameScan.
1357
+ pl_df: pl.DataFrame
1358
+ The underlying polars DataFrame.
1359
+ projection: Sequence[str] | None
1360
+ The projection of the DataFrameScan.
1361
+ id_for_hash: int
1362
+ The id for hash of the DataFrameScan.
1363
+
1364
+ Returns
1365
+ -------
1366
+ The reconstructed DataFrameScan.
1367
+ """
1368
+ node = DataFrameScan(schema, pl_df._df, projection)
1369
+ node._id_for_hash = id_for_hash
1370
+ return node
1371
+
1372
+ def __reduce__(self) -> tuple[Any, ...]: # pragma: no cover
1373
+ """Pickle a DataFrameScan object."""
1374
+ return (
1375
+ self._reconstruct,
1376
+ (*self._non_child_args, self._id_for_hash),
1377
+ )
1378
+
1096
1379
  def get_hashable(self) -> Hashable:
1097
1380
  """
1098
1381
  Hashable representation of the node.
@@ -1109,20 +1392,34 @@ class DataFrameScan(IR):
1109
1392
  self.projection,
1110
1393
  )
1111
1394
 
1395
+ def is_equal(self, other: Self) -> bool:
1396
+ """Equality of DataFrameScan nodes."""
1397
+ return self is other or (
1398
+ self._id_for_hash == other._id_for_hash
1399
+ and self.schema == other.schema
1400
+ and self.projection == other.projection
1401
+ and pl.DataFrame._from_pydf(self.df).equals(
1402
+ pl.DataFrame._from_pydf(other.df)
1403
+ )
1404
+ )
1405
+
1112
1406
  @classmethod
1407
+ @log_do_evaluate
1113
1408
  @nvtx_annotate_cudf_polars(message="DataFrameScan")
1114
1409
  def do_evaluate(
1115
1410
  cls,
1116
1411
  schema: Schema,
1117
1412
  df: Any,
1118
1413
  projection: tuple[str, ...] | None,
1414
+ *,
1415
+ context: IRExecutionContext,
1119
1416
  ) -> DataFrame:
1120
1417
  """Evaluate and return a dataframe."""
1121
1418
  if projection is not None:
1122
1419
  df = df.select(projection)
1123
- df = DataFrame.from_polars(df)
1420
+ df = DataFrame.from_polars(df, stream=context.get_cuda_stream())
1124
1421
  assert all(
1125
- c.obj.type() == dtype.plc
1422
+ c.obj.type() == dtype.plc_type
1126
1423
  for c, dtype in zip(df.columns, schema.values(), strict=True)
1127
1424
  )
1128
1425
  return df
@@ -1169,21 +1466,26 @@ class Select(IR):
1169
1466
  return False
1170
1467
 
1171
1468
  @classmethod
1469
+ @log_do_evaluate
1172
1470
  @nvtx_annotate_cudf_polars(message="Select")
1173
1471
  def do_evaluate(
1174
1472
  cls,
1175
1473
  exprs: tuple[expr.NamedExpr, ...],
1176
1474
  should_broadcast: bool, # noqa: FBT001
1177
1475
  df: DataFrame,
1476
+ *,
1477
+ context: IRExecutionContext,
1178
1478
  ) -> DataFrame:
1179
1479
  """Evaluate and return a dataframe."""
1180
1480
  # Handle any broadcasting
1181
1481
  columns = [e.evaluate(df) for e in exprs]
1182
1482
  if should_broadcast:
1183
- columns = broadcast(*columns)
1184
- return DataFrame(columns)
1483
+ columns = broadcast(*columns, stream=df.stream)
1484
+ return DataFrame(columns, stream=df.stream)
1185
1485
 
1186
- def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
1486
+ def evaluate(
1487
+ self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
1488
+ ) -> DataFrame:
1187
1489
  """
1188
1490
  Evaluate the Select node with special handling for fast count queries.
1189
1491
 
@@ -1195,6 +1497,8 @@ class Select(IR):
1195
1497
  timer
1196
1498
  If not None, a Timer object to record timings for the
1197
1499
  evaluation of the node.
1500
+ context
1501
+ The execution context for the node.
1198
1502
 
1199
1503
  Returns
1200
1504
  -------
@@ -1214,21 +1518,23 @@ class Select(IR):
1214
1518
  and Select._is_len_expr(self.exprs)
1215
1519
  and self.children[0].typ == "parquet"
1216
1520
  and self.children[0].predicate is None
1217
- ):
1218
- scan = self.children[0] # pragma: no cover
1219
- effective_rows = scan.fast_count() # pragma: no cover
1220
- dtype = DataType(pl.UInt32()) # pragma: no cover
1521
+ ): # pragma: no cover
1522
+ stream = context.get_cuda_stream()
1523
+ scan = self.children[0]
1524
+ effective_rows = scan.fast_count()
1525
+ dtype = DataType(pl.UInt32())
1221
1526
  col = Column(
1222
1527
  plc.Column.from_scalar(
1223
- plc.Scalar.from_py(effective_rows, dtype.plc),
1528
+ plc.Scalar.from_py(effective_rows, dtype.plc_type, stream=stream),
1224
1529
  1,
1530
+ stream=stream,
1225
1531
  ),
1226
1532
  name=self.exprs[0].name or "len",
1227
1533
  dtype=dtype,
1228
- ) # pragma: no cover
1229
- return DataFrame([col]) # pragma: no cover
1534
+ )
1535
+ return DataFrame([col], stream=stream)
1230
1536
 
1231
- return super().evaluate(cache=cache, timer=timer)
1537
+ return super().evaluate(cache=cache, timer=timer, context=context)
1232
1538
 
1233
1539
 
1234
1540
  class Reduce(IR):
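The fast-count branch in the hunk above answers a bare length query over a parquet scan from file metadata instead of materialising any columns. The query shape that can hit that path looks roughly like this (hypothetical file path, GPU engine assumed to be available):

```python
# Roughly the query shape eligible for the fast-count path: a bare len() over
# a parquet scan with no predicate and no row index. "data/*.parquet" is a
# placeholder path.
import polars as pl

q = pl.scan_parquet("data/*.parquet").select(pl.len())
print(q.collect(engine="gpu"))
```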
@@ -1252,16 +1558,19 @@ class Reduce(IR):
1252
1558
  self._non_child_args = (self.exprs,)
1253
1559
 
1254
1560
  @classmethod
1561
+ @log_do_evaluate
1255
1562
  @nvtx_annotate_cudf_polars(message="Reduce")
1256
1563
  def do_evaluate(
1257
1564
  cls,
1258
1565
  exprs: tuple[expr.NamedExpr, ...],
1259
1566
  df: DataFrame,
1567
+ *,
1568
+ context: IRExecutionContext,
1260
1569
  ) -> DataFrame: # pragma: no cover; not exposed by polars yet
1261
1570
  """Evaluate and return a dataframe."""
1262
- columns = broadcast(*(e.evaluate(df) for e in exprs))
1571
+ columns = broadcast(*(e.evaluate(df) for e in exprs), stream=df.stream)
1263
1572
  assert all(column.size == 1 for column in columns)
1264
- return DataFrame(columns)
1573
+ return DataFrame(columns, stream=df.stream)
1265
1574
 
1266
1575
 
1267
1576
  class Rolling(IR):
@@ -1270,17 +1579,19 @@ class Rolling(IR):
1270
1579
  __slots__ = (
1271
1580
  "agg_requests",
1272
1581
  "closed_window",
1273
- "following",
1582
+ "following_ordinal",
1274
1583
  "index",
1584
+ "index_dtype",
1275
1585
  "keys",
1276
- "preceding",
1586
+ "preceding_ordinal",
1277
1587
  "zlice",
1278
1588
  )
1279
1589
  _non_child = (
1280
1590
  "schema",
1281
1591
  "index",
1282
- "preceding",
1283
- "following",
1592
+ "index_dtype",
1593
+ "preceding_ordinal",
1594
+ "following_ordinal",
1284
1595
  "closed_window",
1285
1596
  "keys",
1286
1597
  "agg_requests",
@@ -1288,10 +1599,12 @@ class Rolling(IR):
1288
1599
  )
1289
1600
  index: expr.NamedExpr
1290
1601
  """Column being rolled over."""
1291
- preceding: plc.Scalar
1292
- """Preceding window extent defining start of window."""
1293
- following: plc.Scalar
1294
- """Following window extent defining end of window."""
1602
+ index_dtype: plc.DataType
1603
+ """Datatype of the index column."""
1604
+ preceding_ordinal: int
1605
+ """Preceding window extent defining start of window as a host integer."""
1606
+ following_ordinal: int
1607
+ """Following window extent defining end of window as a host integer."""
1295
1608
  closed_window: ClosedInterval
1296
1609
  """Treatment of window endpoints."""
1297
1610
  keys: tuple[expr.NamedExpr, ...]
@@ -1305,8 +1618,9 @@ class Rolling(IR):
1305
1618
  self,
1306
1619
  schema: Schema,
1307
1620
  index: expr.NamedExpr,
1308
- preceding: plc.Scalar,
1309
- following: plc.Scalar,
1621
+ index_dtype: plc.DataType,
1622
+ preceding_ordinal: int,
1623
+ following_ordinal: int,
1310
1624
  closed_window: ClosedInterval,
1311
1625
  keys: Sequence[expr.NamedExpr],
1312
1626
  agg_requests: Sequence[expr.NamedExpr],
@@ -1315,14 +1629,15 @@ class Rolling(IR):
1315
1629
  ):
1316
1630
  self.schema = schema
1317
1631
  self.index = index
1318
- self.preceding = preceding
1319
- self.following = following
1632
+ self.index_dtype = index_dtype
1633
+ self.preceding_ordinal = preceding_ordinal
1634
+ self.following_ordinal = following_ordinal
1320
1635
  self.closed_window = closed_window
1321
1636
  self.keys = tuple(keys)
1322
1637
  self.agg_requests = tuple(agg_requests)
1323
1638
  if not all(
1324
1639
  plc.rolling.is_valid_rolling_aggregation(
1325
- agg.value.dtype.plc, agg.value.agg_request
1640
+ agg.value.dtype.plc_type, agg.value.agg_request
1326
1641
  )
1327
1642
  for agg in self.agg_requests
1328
1643
  ):
@@ -1339,8 +1654,9 @@ class Rolling(IR):
1339
1654
  self.children = (df,)
1340
1655
  self._non_child_args = (
1341
1656
  index,
1342
- preceding,
1343
- following,
1657
+ index_dtype,
1658
+ preceding_ordinal,
1659
+ following_ordinal,
1344
1660
  closed_window,
1345
1661
  keys,
1346
1662
  agg_requests,
@@ -1348,31 +1664,46 @@ class Rolling(IR):
1348
1664
  )
1349
1665
 
1350
1666
  @classmethod
1667
+ @log_do_evaluate
1351
1668
  @nvtx_annotate_cudf_polars(message="Rolling")
1352
1669
  def do_evaluate(
1353
1670
  cls,
1354
1671
  index: expr.NamedExpr,
1355
- preceding: plc.Scalar,
1356
- following: plc.Scalar,
1672
+ index_dtype: plc.DataType,
1673
+ preceding_ordinal: int,
1674
+ following_ordinal: int,
1357
1675
  closed_window: ClosedInterval,
1358
1676
  keys_in: Sequence[expr.NamedExpr],
1359
1677
  aggs: Sequence[expr.NamedExpr],
1360
1678
  zlice: Zlice | None,
1361
1679
  df: DataFrame,
1680
+ *,
1681
+ context: IRExecutionContext,
1362
1682
  ) -> DataFrame:
1363
1683
  """Evaluate and return a dataframe."""
1364
- keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
1684
+ keys = broadcast(
1685
+ *(k.evaluate(df) for k in keys_in),
1686
+ target_length=df.num_rows,
1687
+ stream=df.stream,
1688
+ )
1365
1689
  orderby = index.evaluate(df)
1366
1690
  # Polars casts integral orderby to int64, but only for calculating window bounds
1367
1691
  if (
1368
1692
  plc.traits.is_integral(orderby.obj.type())
1369
1693
  and orderby.obj.type().id() != plc.TypeId.INT64
1370
1694
  ):
1371
- orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
1695
+ orderby_obj = plc.unary.cast(
1696
+ orderby.obj, plc.DataType(plc.TypeId.INT64), stream=df.stream
1697
+ )
1372
1698
  else:
1373
1699
  orderby_obj = orderby.obj
1700
+
1701
+ preceding_scalar, following_scalar = offsets_to_windows(
1702
+ index_dtype, preceding_ordinal, following_ordinal, stream=df.stream
1703
+ )
1704
+
1374
1705
  preceding_window, following_window = range_window_bounds(
1375
- preceding, following, closed_window
1706
+ preceding_scalar, following_scalar, closed_window
1376
1707
  )
1377
1708
  if orderby.obj.null_count() != 0:
1378
1709
  raise RuntimeError(
@@ -1383,12 +1714,17 @@ class Rolling(IR):
1383
1714
  table = plc.Table([*(k.obj for k in keys), orderby_obj])
1384
1715
  n = table.num_columns()
1385
1716
  if not plc.sorting.is_sorted(
1386
- table, [plc.types.Order.ASCENDING] * n, [plc.types.NullOrder.BEFORE] * n
1717
+ table,
1718
+ [plc.types.Order.ASCENDING] * n,
1719
+ [plc.types.NullOrder.BEFORE] * n,
1720
+ stream=df.stream,
1387
1721
  ):
1388
1722
  raise RuntimeError("Input for grouped rolling is not sorted")
1389
1723
  else:
1390
1724
  if not orderby.check_sorted(
1391
- order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
1725
+ order=plc.types.Order.ASCENDING,
1726
+ null_order=plc.types.NullOrder.BEFORE,
1727
+ stream=df.stream,
1392
1728
  ):
1393
1729
  raise RuntimeError(
1394
1730
  f"Index column '{index.name}' in rolling is not sorted, please sort first"
@@ -1401,6 +1737,7 @@ class Rolling(IR):
1401
1737
  preceding_window,
1402
1738
  following_window,
1403
1739
  [rolling.to_request(request.value, orderby, df) for request in aggs],
1740
+ stream=df.stream,
1404
1741
  )
1405
1742
  return DataFrame(
1406
1743
  itertools.chain(
@@ -1410,7 +1747,8 @@ class Rolling(IR):
1410
1747
  Column(col, name=request.name, dtype=request.value.dtype)
1411
1748
  for col, request in zip(values.columns(), aggs, strict=True)
1412
1749
  ),
1413
- )
1750
+ ),
1751
+ stream=df.stream,
1414
1752
  ).slice(zlice)
1415
1753
 
1416
1754
 
@@ -1472,6 +1810,7 @@ class GroupBy(IR):
1472
1810
  )
1473
1811
 
1474
1812
  @classmethod
1813
+ @log_do_evaluate
1475
1814
  @nvtx_annotate_cudf_polars(message="GroupBy")
1476
1815
  def do_evaluate(
1477
1816
  cls,
@@ -1481,9 +1820,15 @@ class GroupBy(IR):
1481
1820
  maintain_order: bool, # noqa: FBT001
1482
1821
  zlice: Zlice | None,
1483
1822
  df: DataFrame,
1823
+ *,
1824
+ context: IRExecutionContext,
1484
1825
  ) -> DataFrame:
1485
1826
  """Evaluate and return a dataframe."""
1486
- keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
1827
+ keys = broadcast(
1828
+ *(k.evaluate(df) for k in keys_in),
1829
+ target_length=df.num_rows,
1830
+ stream=df.stream,
1831
+ )
1487
1832
  sorted = (
1488
1833
  plc.types.Sorted.YES
1489
1834
  if all(k.is_sorted for k in keys)
@@ -1512,10 +1857,15 @@ class GroupBy(IR):
1512
1857
  col = child.evaluate(df, context=ExecutionContext.GROUPBY).obj
1513
1858
  else:
1514
1859
  # Anything else, we pre-evaluate
1515
- col = value.evaluate(df, context=ExecutionContext.GROUPBY).obj
1860
+ column = value.evaluate(df, context=ExecutionContext.GROUPBY)
1861
+ if column.size != keys[0].size:
1862
+ column = broadcast(
1863
+ column, target_length=keys[0].size, stream=df.stream
1864
+ )[0]
1865
+ col = column.obj
1516
1866
  requests.append(plc.groupby.GroupByRequest(col, [value.agg_request]))
1517
1867
  names.append(name)
1518
- group_keys, raw_tables = grouper.aggregate(requests)
1868
+ group_keys, raw_tables = grouper.aggregate(requests, stream=df.stream)
1519
1869
  results = [
1520
1870
  Column(column, name=name, dtype=schema[name])
1521
1871
  for name, column, request in zip(
@@ -1529,7 +1879,7 @@ class GroupBy(IR):
1529
1879
  Column(grouped_key, name=key.name, dtype=key.dtype)
1530
1880
  for key, grouped_key in zip(keys, group_keys.columns(), strict=True)
1531
1881
  ]
1532
- broadcasted = broadcast(*result_keys, *results)
1882
+ broadcasted = broadcast(*result_keys, *results, stream=df.stream)
1533
1883
  # Handle order preservation of groups
1534
1884
  if maintain_order and not sorted:
1535
1885
  # The order we want
@@ -1539,6 +1889,7 @@ class GroupBy(IR):
1539
1889
  plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
1540
1890
  plc.types.NullEquality.EQUAL,
1541
1891
  plc.types.NanEquality.ALL_EQUAL,
1892
+ stream=df.stream,
1542
1893
  )
1543
1894
  # The order we have
1544
1895
  have = plc.Table([key.obj for key in broadcasted[: len(keys)]])
@@ -1546,7 +1897,7 @@ class GroupBy(IR):
1546
1897
  # We know an inner join is OK because by construction
1547
1898
  # want and have are permutations of each other.
1548
1899
  left_order, right_order = plc.join.inner_join(
1549
- want, have, plc.types.NullEquality.EQUAL
1900
+ want, have, plc.types.NullEquality.EQUAL, stream=df.stream
1550
1901
  )
1551
1902
  # Now left_order is an arbitrary permutation of the ordering we
1552
1903
  # want, and right_order is a matching permutation of the ordering
@@ -1559,11 +1910,13 @@ class GroupBy(IR):
1559
1910
  plc.Table([left_order]),
1560
1911
  [plc.types.Order.ASCENDING],
1561
1912
  [plc.types.NullOrder.AFTER],
1913
+ stream=df.stream,
1562
1914
  ).columns()
1563
1915
  ordered_table = plc.copying.gather(
1564
1916
  plc.Table([col.obj for col in broadcasted]),
1565
1917
  right_order,
1566
1918
  plc.copying.OutOfBoundsPolicy.DONT_CHECK,
1919
+ stream=df.stream,
1567
1920
  )
1568
1921
  broadcasted = [
1569
1922
  Column(reordered, name=old.name, dtype=old.dtype)
@@ -1571,7 +1924,126 @@ class GroupBy(IR):
1571
1924
  ordered_table.columns(), broadcasted, strict=True
1572
1925
  )
1573
1926
  ]
1574
- return DataFrame(broadcasted).slice(zlice)
1927
+ return DataFrame(broadcasted, stream=df.stream).slice(zlice)
1928
+
1929
+
1930
+ def _strip_predicate_casts(node: expr.Expr) -> expr.Expr:
1931
+ if isinstance(node, expr.Cast):
1932
+ (child,) = node.children
1933
+ child = _strip_predicate_casts(child)
1934
+
1935
+ src = child.dtype
1936
+ dst = node.dtype
1937
+
1938
+ if plc.traits.is_fixed_point(src.plc_type) or plc.traits.is_fixed_point(
1939
+ dst.plc_type
1940
+ ):
1941
+ return child
1942
+
1943
+ if (
1944
+ not POLARS_VERSION_LT_134
1945
+ and isinstance(child, expr.ColRef)
1946
+ and (
1947
+ (
1948
+ plc.traits.is_floating_point(src.plc_type)
1949
+ and plc.traits.is_floating_point(dst.plc_type)
1950
+ )
1951
+ or (
1952
+ plc.traits.is_integral(src.plc_type)
1953
+ and plc.traits.is_integral(dst.plc_type)
1954
+ and src.plc_type.id() == dst.plc_type.id()
1955
+ )
1956
+ )
1957
+ ):
1958
+ return child
1959
+
1960
+ if not node.children:
1961
+ return node
1962
+ return node.reconstruct([_strip_predicate_casts(child) for child in node.children])
1963
+
1964
+
1965
+ def _add_cast(
1966
+ target: DataType,
1967
+ side: expr.ColRef,
1968
+ left_casts: dict[str, DataType],
1969
+ right_casts: dict[str, DataType],
1970
+ ) -> None:
1971
+ (col,) = side.children
1972
+ assert isinstance(col, expr.Col)
1973
+ casts = (
1974
+ left_casts if side.table_ref == plc_expr.TableReference.LEFT else right_casts
1975
+ )
1976
+ casts[col.name] = target
1977
+
1978
+
1979
+ def _align_decimal_binop_types(
1980
+ left_expr: expr.ColRef,
1981
+ right_expr: expr.ColRef,
1982
+ left_casts: dict[str, DataType],
1983
+ right_casts: dict[str, DataType],
1984
+ ) -> None:
1985
+ left_type, right_type = left_expr.dtype, right_expr.dtype
1986
+
1987
+ if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point(
1988
+ right_type.plc_type
1989
+ ):
1990
+ target = DataType.common_decimal_dtype(left_type, right_type)
1991
+
1992
+ if left_type.id() != target.id() or left_type.scale() != target.scale():
1993
+ _add_cast(target, left_expr, left_casts, right_casts)
1994
+
1995
+ if right_type.id() != target.id() or right_type.scale() != target.scale():
1996
+ _add_cast(target, right_expr, left_casts, right_casts)
1997
+
1998
+ elif (
1999
+ plc.traits.is_fixed_point(left_type.plc_type)
2000
+ and plc.traits.is_floating_point(right_type.plc_type)
2001
+ ) or (
2002
+ plc.traits.is_fixed_point(right_type.plc_type)
2003
+ and plc.traits.is_floating_point(left_type.plc_type)
2004
+ ):
2005
+ is_decimal_left = plc.traits.is_fixed_point(left_type.plc_type)
2006
+ decimal_expr, float_expr = (
2007
+ (left_expr, right_expr) if is_decimal_left else (right_expr, left_expr)
2008
+ )
2009
+ _add_cast(decimal_expr.dtype, float_expr, left_casts, right_casts)
2010
+
2011
+
2012
+ def _collect_decimal_binop_casts(
2013
+ predicate: expr.Expr,
2014
+ ) -> tuple[dict[str, DataType], dict[str, DataType]]:
2015
+ left_casts: dict[str, DataType] = {}
2016
+ right_casts: dict[str, DataType] = {}
2017
+
2018
+ def _walk(node: expr.Expr) -> None:
2019
+ if isinstance(node, expr.BinOp) and node.op in _BINOPS:
2020
+ left_expr, right_expr = node.children
2021
+ if isinstance(left_expr, expr.ColRef) and isinstance(
2022
+ right_expr, expr.ColRef
2023
+ ):
2024
+ _align_decimal_binop_types(
2025
+ left_expr, right_expr, left_casts, right_casts
2026
+ )
2027
+ for child in node.children:
2028
+ _walk(child)
2029
+
2030
+ _walk(predicate)
2031
+ return left_casts, right_casts
2032
+
2033
+
2034
+ def _apply_casts(df: DataFrame, casts: dict[str, DataType]) -> DataFrame:
2035
+ if not casts:
2036
+ return df
2037
+
2038
+ columns = []
2039
+ for col in df.columns:
2040
+ target = casts.get(col.name)
2041
+ if target is None:
2042
+ columns.append(Column(col.obj, dtype=col.dtype, name=col.name))
2043
+ else:
2044
+ casted = col.astype(target, stream=df.stream)
2045
+ columns.append(Column(casted.obj, dtype=casted.dtype, name=col.name))
2046
+ return DataFrame(columns, stream=df.stream)
1575
2047
 
1576
2048
 
1577
2049
  class ConditionalJoin(IR):
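The helpers above align operand types for comparisons inside a conditional-join predicate: two fixed-point operands are brought to a common decimal type, and a float operand compared against a decimal column is cast to the decimal side's type. A toy, host-side sketch of that rule using Python decimals (align_for_compare is illustrative only, not the library's implementation):

```python
# Toy model of the cast-alignment rule, with Decimal/float standing in for
# pylibcudf fixed-point and floating-point columns.
from decimal import Decimal


def align_for_compare(
    left: Decimal | float, right: Decimal | float
) -> tuple[Decimal, Decimal]:
    if isinstance(left, Decimal) and isinstance(right, Decimal):
        # Common decimal type: keep the finer (more negative) exponent.
        exponent = min(left.as_tuple().exponent, right.as_tuple().exponent)
        quantum = Decimal(1).scaleb(exponent)
        return left.quantize(quantum), right.quantize(quantum)
    if isinstance(left, Decimal):
        # The float side is cast to the decimal side's scale.
        return left, Decimal(str(right)).quantize(left)
    if isinstance(right, Decimal):
        return Decimal(str(left)).quantize(right), right
    raise TypeError("at least one operand must be a Decimal")


l, r = align_for_compare(Decimal("1.25"), 1.3)
print(l, r, l < r)                                          # 1.25 1.30 True
print(align_for_compare(Decimal("2.5"), Decimal("2.500")))  # (2.500, 2.500)
```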
@@ -1585,7 +2057,14 @@ class ConditionalJoin(IR):
1585
2057
 
1586
2058
  def __init__(self, predicate: expr.Expr):
1587
2059
  self.predicate = predicate
1588
- self.ast = to_ast(predicate)
2060
+ stream = get_cuda_stream()
2061
+ ast_result = to_ast(predicate, stream=stream)
2062
+ stream.synchronize()
2063
+ if ast_result is None:
2064
+ raise NotImplementedError(
2065
+ f"Conditional join with predicate {predicate}"
2066
+ ) # pragma: no cover; polars never delivers expressions we can't handle
2067
+ self.ast = ast_result
1589
2068
 
1590
2069
  def __reduce__(self) -> tuple[Any, ...]:
1591
2070
  """Pickle a Predicate object."""
@@ -1598,8 +2077,9 @@ class ConditionalJoin(IR):
1598
2077
  options: tuple[
1599
2078
  tuple[
1600
2079
  str,
1601
- pl_expr.Operator | Iterable[pl_expr.Operator],
1602
- ],
2080
+ polars._expr_nodes.Operator | Iterable[polars._expr_nodes.Operator],
2081
+ ]
2082
+ | None,
1603
2083
  bool,
1604
2084
  Zlice | None,
1605
2085
  str,
@@ -1620,7 +2100,14 @@ class ConditionalJoin(IR):
1620
2100
  self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
1621
2101
  ) -> None:
1622
2102
  self.schema = schema
2103
+ predicate = _strip_predicate_casts(predicate)
1623
2104
  self.predicate = predicate
2105
+ # options[0] is a tuple[str, Operator, ...]
2106
+ # The Operator class can't be pickled, but we don't use it anyway so
2107
+ # just throw that away
2108
+ if options[0] is not None:
2109
+ options = (None, *options[1:])
2110
+
1624
2111
  self.options = options
1625
2112
  self.children = (left, right)
1626
2113
  predicate_wrapper = self.Predicate(predicate)
@@ -1629,51 +2116,64 @@ class ConditionalJoin(IR):
1629
2116
  assert not nulls_equal
1630
2117
  assert not coalesce
1631
2118
  assert maintain_order == "none"
1632
- if predicate_wrapper.ast is None:
1633
- raise NotImplementedError(
1634
- f"Conditional join with predicate {predicate}"
1635
- ) # pragma: no cover; polars never delivers expressions we can't handle
1636
- self._non_child_args = (predicate_wrapper, zlice, suffix, maintain_order)
2119
+ self._non_child_args = (predicate_wrapper, options)
1637
2120
 
1638
2121
  @classmethod
2122
+ @log_do_evaluate
1639
2123
  @nvtx_annotate_cudf_polars(message="ConditionalJoin")
1640
2124
  def do_evaluate(
1641
2125
  cls,
1642
2126
  predicate_wrapper: Predicate,
1643
- zlice: Zlice | None,
1644
- suffix: str,
1645
- maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
2127
+ options: tuple,
1646
2128
  left: DataFrame,
1647
2129
  right: DataFrame,
2130
+ *,
2131
+ context: IRExecutionContext,
1648
2132
  ) -> DataFrame:
1649
2133
  """Evaluate and return a dataframe."""
1650
- lg, rg = plc.join.conditional_inner_join(
1651
- left.table,
1652
- right.table,
1653
- predicate_wrapper.ast,
1654
- )
1655
- left = DataFrame.from_table(
1656
- plc.copying.gather(
1657
- left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
1658
- ),
1659
- left.column_names,
1660
- left.dtypes,
1661
- )
1662
- right = DataFrame.from_table(
1663
- plc.copying.gather(
1664
- right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
1665
- ),
1666
- right.column_names,
1667
- right.dtypes,
1668
- )
1669
- right = right.rename_columns(
1670
- {
1671
- name: f"{name}{suffix}"
1672
- for name in right.column_names
1673
- if name in left.column_names_set
1674
- }
1675
- )
1676
- result = left.with_columns(right.columns)
2134
+ with context.stream_ordered_after(left, right) as stream:
2135
+ left_casts, right_casts = _collect_decimal_binop_casts(
2136
+ predicate_wrapper.predicate
2137
+ )
2138
+ _, _, zlice, suffix, _, _ = options
2139
+
2140
+ lg, rg = plc.join.conditional_inner_join(
2141
+ _apply_casts(left, left_casts).table,
2142
+ _apply_casts(right, right_casts).table,
2143
+ predicate_wrapper.ast,
2144
+ stream=stream,
2145
+ )
2146
+ left_result = DataFrame.from_table(
2147
+ plc.copying.gather(
2148
+ left.table,
2149
+ lg,
2150
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
2151
+ stream=stream,
2152
+ ),
2153
+ left.column_names,
2154
+ left.dtypes,
2155
+ stream=stream,
2156
+ )
2157
+ right_result = DataFrame.from_table(
2158
+ plc.copying.gather(
2159
+ right.table,
2160
+ rg,
2161
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
2162
+ stream=stream,
2163
+ ),
2164
+ right.column_names,
2165
+ right.dtypes,
2166
+ stream=stream,
2167
+ )
2168
+ right_result = right_result.rename_columns(
2169
+ {
2170
+ name: f"{name}{suffix}"
2171
+ for name in right.column_names
2172
+ if name in left.column_names_set
2173
+ }
2174
+ )
2175
+ result = left_result.with_columns(right_result.columns, stream=stream)
2176
+
1677
2177
  return result.slice(zlice)
1678
2178
 
1679
2179
 
@@ -1704,6 +2204,19 @@ class Join(IR):
1704
2204
  - maintain_order: which DataFrame row order to preserve, if any
1705
2205
  """
1706
2206
 
2207
+ SWAPPED_ORDER: ClassVar[
2208
+ dict[
2209
+ Literal["none", "left", "right", "left_right", "right_left"],
2210
+ Literal["none", "left", "right", "left_right", "right_left"],
2211
+ ]
2212
+ ] = {
2213
+ "none": "none",
2214
+ "left": "right",
2215
+ "right": "left",
2216
+ "left_right": "right_left",
2217
+ "right_left": "left_right",
2218
+ }
2219
+
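Because a right join is executed as a left join with the inputs swapped, the requested maintain_order value has to be mirrored along with the tables; SWAPPED_ORDER above is that mirror. A small illustration of the lookup (values taken from the mapping above; the local variables are illustrative):

    SWAPPED_ORDER = {
        "none": "none",
        "left": "right",
        "right": "left",
        "left_right": "right_left",
        "right_left": "left_right",
    }

    how, maintain_order = "Right", "left_right"
    if how == "Right":
        # The left/right tables trade places, so the order request flips with them.
        maintain_order = SWAPPED_ORDER[maintain_order]
    assert maintain_order == "right_left"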
1707
2220
  def __init__(
1708
2221
  self,
1709
2222
  schema: Schema,
@@ -1719,9 +2232,6 @@ class Join(IR):
1719
2232
  self.options = options
1720
2233
  self.children = (left, right)
1721
2234
  self._non_child_args = (self.left_on, self.right_on, self.options)
1722
- # TODO: Implement maintain_order
1723
- if options[5] != "none":
1724
- raise NotImplementedError("maintain_order not implemented yet")
1725
2235
 
1726
2236
  @staticmethod
1727
2237
  @cache
@@ -1770,6 +2280,9 @@ class Join(IR):
1770
2280
  right_rows: int,
1771
2281
  rg: plc.Column,
1772
2282
  right_policy: plc.copying.OutOfBoundsPolicy,
2283
+ *,
2284
+ left_primary: bool = True,
2285
+ stream: Stream,
1773
2286
  ) -> list[plc.Column]:
1774
2287
  """
1775
2288
  Reorder gather maps to satisfy polars join order restrictions.
@@ -1788,30 +2301,70 @@ class Join(IR):
1788
2301
  Right gather map
1789
2302
  right_policy
1790
2303
  Nullify policy for right map
2304
+ left_primary
2305
+ Whether the left input's row order is used as the primary sort key,
2306
+ with ties broken by the right input's row order.
2307
+ Defaults to True.
2308
+ stream
2309
+ CUDA stream used for device memory operations and kernel launches.
1791
2310
 
1792
2311
  Returns
1793
2312
  -------
1794
- list of reordered left and right gather maps.
2313
+ list[plc.Column]
2314
+ Reordered left and right gather maps.
1795
2315
 
1796
2316
  Notes
1797
2317
  -----
1798
- For a left join, the polars result preserves the order of the
1799
- left keys, and is stable wrt the right keys. For all other
1800
- joins, there is no order obligation.
2318
+ When ``left_primary`` is True, the pair of gather maps is stably sorted by
2319
+ the original row order of the left side, breaking ties by the right side.
2320
+ The roles are reversed when ``left_primary`` is False.
1801
2321
  """
1802
- init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE)
1803
- step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE)
1804
- left_order = plc.copying.gather(
1805
- plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
1806
- )
1807
- right_order = plc.copying.gather(
1808
- plc.Table([plc.filling.sequence(right_rows, init, step)]), rg, right_policy
2322
+ init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE, stream=stream)
2323
+ step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE, stream=stream)
2324
+
2325
+ (left_order_col,) = plc.copying.gather(
2326
+ plc.Table(
2327
+ [
2328
+ plc.filling.sequence(
2329
+ left_rows,
2330
+ init,
2331
+ step,
2332
+ stream=stream,
2333
+ )
2334
+ ]
2335
+ ),
2336
+ lg,
2337
+ left_policy,
2338
+ stream=stream,
2339
+ ).columns()
2340
+ (right_order_col,) = plc.copying.gather(
2341
+ plc.Table(
2342
+ [
2343
+ plc.filling.sequence(
2344
+ right_rows,
2345
+ init,
2346
+ step,
2347
+ stream=stream,
2348
+ )
2349
+ ]
2350
+ ),
2351
+ rg,
2352
+ right_policy,
2353
+ stream=stream,
2354
+ ).columns()
2355
+
2356
+ keys = (
2357
+ plc.Table([left_order_col, right_order_col])
2358
+ if left_primary
2359
+ else plc.Table([right_order_col, left_order_col])
1809
2360
  )
2361
+
1810
2362
  return plc.sorting.stable_sort_by_key(
1811
2363
  plc.Table([lg, rg]),
1812
- plc.Table([*left_order.columns(), *right_order.columns()]),
2364
+ keys,
1813
2365
  [plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
1814
2366
  [plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
2367
+ stream=stream,
1815
2368
  ).columns()
1816
2369
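_reorder_maps gathers an ascending row-index sequence through each gather map and then stably sorts the (left, right) map pair by those positions, with the primary key chosen by left_primary and nullified (non-matching) entries ordered last. A CPU-side sketch of the same idea with Python lists, where None stands in for a nullified map entry (the data is an illustrative assumption):

    def reorder_maps(lg, rg, *, left_primary=True):
        # None marks a nullified (non-matching) entry; it sorts after real row
        # indices, matching plc.types.NullOrder.AFTER.
        def rank(idx):
            return (idx is None, idx if idx is not None else 0)

        def key(pair):
            primary, secondary = pair if left_primary else pair[::-1]
            return (rank(primary), rank(secondary))

        # sorted() is stable, mirroring plc.sorting.stable_sort_by_key.
        new_lg, new_rg = zip(*sorted(zip(lg, rg), key=key))
        return list(new_lg), list(new_rg)

    lg, rg = [2, 0, 1, 0], [0, 1, None, 3]
    assert reorder_maps(lg, rg) == ([0, 0, 1, 2], [1, 3, None, 0])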
 
1817
2370
  @staticmethod
@@ -1822,31 +2375,35 @@ class Join(IR):
1822
2375
  left: bool = True,
1823
2376
  empty: bool = False,
1824
2377
  rename: Callable[[str], str] = lambda name: name,
2378
+ stream: Stream,
1825
2379
  ) -> list[Column]:
1826
2380
  if empty:
1827
2381
  return [
1828
2382
  Column(
1829
- plc.column_factories.make_empty_column(col.dtype.plc),
2383
+ plc.column_factories.make_empty_column(
2384
+ col.dtype.plc_type, stream=stream
2385
+ ),
1830
2386
  col.dtype,
1831
2387
  name=rename(col.name),
1832
2388
  )
1833
2389
  for col in template
1834
2390
  ]
1835
2391
 
1836
- columns = [
2392
+ result = [
1837
2393
  Column(new, col.dtype, name=rename(col.name))
1838
2394
  for new, col in zip(columns, template, strict=True)
1839
2395
  ]
1840
2396
 
1841
2397
  if left:
1842
- columns = [
2398
+ result = [
1843
2399
  col.sorted_like(orig)
1844
- for col, orig in zip(columns, template, strict=True)
2400
+ for col, orig in zip(result, template, strict=True)
1845
2401
  ]
1846
2402
 
1847
- return columns
2403
+ return result
1848
2404
 
1849
2405
  @classmethod
2406
+ @log_do_evaluate
1850
2407
  @nvtx_annotate_cudf_polars(message="Join")
1851
2408
  def do_evaluate(
1852
2409
  cls,
@@ -1862,112 +2419,168 @@ class Join(IR):
1862
2419
  ],
1863
2420
  left: DataFrame,
1864
2421
  right: DataFrame,
2422
+ *,
2423
+ context: IRExecutionContext,
1865
2424
  ) -> DataFrame:
1866
2425
  """Evaluate and return a dataframe."""
1867
- how, nulls_equal, zlice, suffix, coalesce, _ = options
1868
- if how == "Cross":
1869
- # Separate implementation, since cross_join returns the
1870
- # result, not the gather maps
1871
- if right.num_rows == 0:
1872
- left_cols = Join._build_columns([], left.columns, empty=True)
1873
- right_cols = Join._build_columns(
1874
- [],
1875
- right.columns,
1876
- left=False,
1877
- empty=True,
1878
- rename=lambda name: name
1879
- if name not in left.column_names_set
1880
- else f"{name}{suffix}",
1881
- )
1882
- return DataFrame([*left_cols, *right_cols])
2426
+ with context.stream_ordered_after(left, right) as stream:
2427
+ how, nulls_equal, zlice, suffix, coalesce, maintain_order = options
2428
+ if how == "Cross":
2429
+ # Separate implementation, since cross_join returns the
2430
+ # result, not the gather maps
2431
+ if right.num_rows == 0:
2432
+ left_cols = Join._build_columns(
2433
+ [], left.columns, empty=True, stream=stream
2434
+ )
2435
+ right_cols = Join._build_columns(
2436
+ [],
2437
+ right.columns,
2438
+ left=False,
2439
+ empty=True,
2440
+ rename=lambda name: name
2441
+ if name not in left.column_names_set
2442
+ else f"{name}{suffix}",
2443
+ stream=stream,
2444
+ )
2445
+ result = DataFrame([*left_cols, *right_cols], stream=stream)
2446
+ else:
2447
+ columns = plc.join.cross_join(
2448
+ left.table, right.table, stream=stream
2449
+ ).columns()
2450
+ left_cols = Join._build_columns(
2451
+ columns[: left.num_columns], left.columns, stream=stream
2452
+ )
2453
+ right_cols = Join._build_columns(
2454
+ columns[left.num_columns :],
2455
+ right.columns,
2456
+ rename=lambda name: name
2457
+ if name not in left.column_names_set
2458
+ else f"{name}{suffix}",
2459
+ left=False,
2460
+ stream=stream,
2461
+ )
2462
+ result = DataFrame([*left_cols, *right_cols], stream=stream).slice(
2463
+ zlice
2464
+ )
1883
2465
 
1884
- columns = plc.join.cross_join(left.table, right.table).columns()
1885
- left_cols = Join._build_columns(
1886
- columns[: left.num_columns],
1887
- left.columns,
1888
- )
1889
- right_cols = Join._build_columns(
1890
- columns[left.num_columns :],
1891
- right.columns,
1892
- rename=lambda name: name
1893
- if name not in left.column_names_set
1894
- else f"{name}{suffix}",
1895
- left=False,
1896
- )
1897
- return DataFrame([*left_cols, *right_cols]).slice(zlice)
1898
- # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
1899
- left_on = DataFrame(broadcast(*(e.evaluate(left) for e in left_on_exprs)))
1900
- right_on = DataFrame(broadcast(*(e.evaluate(right) for e in right_on_exprs)))
1901
- null_equality = (
1902
- plc.types.NullEquality.EQUAL
1903
- if nulls_equal
1904
- else plc.types.NullEquality.UNEQUAL
1905
- )
1906
- join_fn, left_policy, right_policy = cls._joiners(how)
1907
- if right_policy is None:
1908
- # Semi join
1909
- lg = join_fn(left_on.table, right_on.table, null_equality)
1910
- table = plc.copying.gather(left.table, lg, left_policy)
1911
- result = DataFrame.from_table(table, left.column_names, left.dtypes)
1912
- else:
1913
- if how == "Right":
1914
- # Right join is a left join with the tables swapped
1915
- left, right = right, left
1916
- left_on, right_on = right_on, left_on
1917
- lg, rg = join_fn(left_on.table, right_on.table, null_equality)
1918
- if how == "Left" or how == "Right":
1919
- # Order of left table is preserved
1920
- lg, rg = cls._reorder_maps(
1921
- left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
2466
+ else:
2467
+ # how != "Cross"
2468
+ # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
2469
+ left_on = DataFrame(
2470
+ broadcast(
2471
+ *(e.evaluate(left) for e in left_on_exprs), stream=stream
2472
+ ),
2473
+ stream=stream,
2474
+ )
2475
+ right_on = DataFrame(
2476
+ broadcast(
2477
+ *(e.evaluate(right) for e in right_on_exprs), stream=stream
2478
+ ),
2479
+ stream=stream,
2480
+ )
2481
+ null_equality = (
2482
+ plc.types.NullEquality.EQUAL
2483
+ if nulls_equal
2484
+ else plc.types.NullEquality.UNEQUAL
1922
2485
  )
1923
- if coalesce:
1924
- if how == "Full":
1925
- # In this case, keys must be column references,
1926
- # possibly with dtype casting. We should use them in
1927
- # preference to the columns from the original tables.
1928
- left = left.with_columns(left_on.columns, replace_only=True)
1929
- right = right.with_columns(right_on.columns, replace_only=True)
2486
+ join_fn, left_policy, right_policy = cls._joiners(how)
2487
+ if right_policy is None:
2488
+ # Semi join
2489
+ lg = join_fn(left_on.table, right_on.table, null_equality, stream)
2490
+ table = plc.copying.gather(
2491
+ left.table, lg, left_policy, stream=stream
2492
+ )
2493
+ result = DataFrame.from_table(
2494
+ table, left.column_names, left.dtypes, stream=stream
2495
+ )
1930
2496
  else:
1931
- right = right.discard_columns(right_on.column_names_set)
1932
- left = DataFrame.from_table(
1933
- plc.copying.gather(left.table, lg, left_policy),
1934
- left.column_names,
1935
- left.dtypes,
1936
- )
1937
- right = DataFrame.from_table(
1938
- plc.copying.gather(right.table, rg, right_policy),
1939
- right.column_names,
1940
- right.dtypes,
1941
- )
1942
- if coalesce and how == "Full":
1943
- left = left.with_columns(
1944
- (
1945
- Column(
1946
- plc.replace.replace_nulls(left_col.obj, right_col.obj),
1947
- name=left_col.name,
1948
- dtype=left_col.dtype,
2497
+ if how == "Right":
2498
+ # Right join is a left join with the tables swapped
2499
+ left, right = right, left
2500
+ left_on, right_on = right_on, left_on
2501
+ maintain_order = Join.SWAPPED_ORDER[maintain_order]
2502
+
2503
+ lg, rg = join_fn(
2504
+ left_on.table, right_on.table, null_equality, stream=stream
2505
+ )
2506
+ if (
2507
+ how in ("Inner", "Left", "Right", "Full")
2508
+ and maintain_order != "none"
2509
+ ):
2510
+ lg, rg = cls._reorder_maps(
2511
+ left.num_rows,
2512
+ lg,
2513
+ left_policy,
2514
+ right.num_rows,
2515
+ rg,
2516
+ right_policy,
2517
+ left_primary=maintain_order.startswith("left"),
2518
+ stream=stream,
1949
2519
  )
1950
- for left_col, right_col in zip(
1951
- left.select_columns(left_on.column_names_set),
1952
- right.select_columns(right_on.column_names_set),
1953
- strict=True,
2520
+ if coalesce:
2521
+ if how == "Full":
2522
+ # In this case, keys must be column references,
2523
+ # possibly with dtype casting. We should use them in
2524
+ # preference to the columns from the original tables.
2525
+
2526
+ # We need to specify `stream` here. We know that `{left,right}_on`
2527
+ # is valid on `stream`, which is ordered after `{left,right}.stream`.
2528
+ left = left.with_columns(
2529
+ left_on.columns, replace_only=True, stream=stream
2530
+ )
2531
+ right = right.with_columns(
2532
+ right_on.columns, replace_only=True, stream=stream
2533
+ )
2534
+ else:
2535
+ right = right.discard_columns(right_on.column_names_set)
2536
+ left = DataFrame.from_table(
2537
+ plc.copying.gather(left.table, lg, left_policy, stream=stream),
2538
+ left.column_names,
2539
+ left.dtypes,
2540
+ stream=stream,
2541
+ )
2542
+ right = DataFrame.from_table(
2543
+ plc.copying.gather(
2544
+ right.table, rg, right_policy, stream=stream
2545
+ ),
2546
+ right.column_names,
2547
+ right.dtypes,
2548
+ stream=stream,
2549
+ )
2550
+ if coalesce and how == "Full":
2551
+ left = left.with_columns(
2552
+ (
2553
+ Column(
2554
+ plc.replace.replace_nulls(
2555
+ left_col.obj, right_col.obj, stream=stream
2556
+ ),
2557
+ name=left_col.name,
2558
+ dtype=left_col.dtype,
2559
+ )
2560
+ for left_col, right_col in zip(
2561
+ left.select_columns(left_on.column_names_set),
2562
+ right.select_columns(right_on.column_names_set),
2563
+ strict=True,
2564
+ )
2565
+ ),
2566
+ replace_only=True,
2567
+ stream=stream,
1954
2568
  )
1955
- ),
1956
- replace_only=True,
1957
- )
1958
- right = right.discard_columns(right_on.column_names_set)
1959
- if how == "Right":
1960
- # Undo the swap for right join before gluing together.
1961
- left, right = right, left
1962
- right = right.rename_columns(
1963
- {
1964
- name: f"{name}{suffix}"
1965
- for name in right.column_names
1966
- if name in left.column_names_set
1967
- }
1968
- )
1969
- result = left.with_columns(right.columns)
1970
- return result.slice(zlice)
2569
+ right = right.discard_columns(right_on.column_names_set)
2570
+ if how == "Right":
2571
+ # Undo the swap for right join before gluing together.
2572
+ left, right = right, left
2573
+ right = right.rename_columns(
2574
+ {
2575
+ name: f"{name}{suffix}"
2576
+ for name in right.column_names
2577
+ if name in left.column_names_set
2578
+ }
2579
+ )
2580
+ result = left.with_columns(right.columns, stream=stream)
2581
+ result = result.slice(zlice)
2582
+
2583
+ return result
1971
2584
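One detail worth calling out from the branch above: for a coalesced full join, the gathered left key columns are merged with the gathered right key columns, so each null left entry (an unmatched row) is filled from the right side; that is what the plc.replace.replace_nulls call performs. A plain-Python sketch with illustrative data:

    left_key = [1, None, 3, None]   # gathered left key; None where the left had no match
    right_key = [None, 2, None, 4]  # gathered right key
    coalesced = [lv if lv is not None else rv for lv, rv in zip(left_key, right_key)]
    assert coalesced == [1, 2, 3, 4]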
 
1972
2585
 
1973
2586
  class HStack(IR):
@@ -1992,18 +2605,23 @@ class HStack(IR):
1992
2605
  self.children = (df,)
1993
2606
 
1994
2607
  @classmethod
2608
+ @log_do_evaluate
1995
2609
  @nvtx_annotate_cudf_polars(message="HStack")
1996
2610
  def do_evaluate(
1997
2611
  cls,
1998
2612
  exprs: Sequence[expr.NamedExpr],
1999
2613
  should_broadcast: bool, # noqa: FBT001
2000
2614
  df: DataFrame,
2615
+ *,
2616
+ context: IRExecutionContext,
2001
2617
  ) -> DataFrame:
2002
2618
  """Evaluate and return a dataframe."""
2003
2619
  columns = [c.evaluate(df) for c in exprs]
2004
2620
  if should_broadcast:
2005
2621
  columns = broadcast(
2006
- *columns, target_length=df.num_rows if df.num_columns != 0 else None
2622
+ *columns,
2623
+ target_length=df.num_rows if df.num_columns != 0 else None,
2624
+ stream=df.stream,
2007
2625
  )
2008
2626
  else:
2009
2627
  # Polars ensures this is true, but let's make sure nothing
@@ -2014,7 +2632,7 @@ class HStack(IR):
2014
2632
  # never be turned into a pylibcudf Table with all columns
2015
2633
  # by the Select, which is why this is safe.
2016
2634
  assert all(e.name.startswith("__POLARS_CSER_0x") for e in exprs)
2017
- return df.with_columns(columns)
2635
+ return df.with_columns(columns, stream=df.stream)
2018
2636
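HStack evaluates the new expressions and, when broadcasting is requested, stretches length-1 results to the frame's row count before attaching them. A CPU-side sketch of that broadcast rule with illustrative data (the real broadcast also validates that every length is either 1 or the target):

    target_length = 3
    columns = {"a": [1, 2, 3], "b": [7]}  # "b" is a scalar-like, length-1 result
    broadcasted = {
        name: col if len(col) == target_length else col * target_length
        for name, col in columns.items()
    }
    assert broadcasted == {"a": [1, 2, 3], "b": [7, 7, 7]}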
 
2019
2637
 
2020
2638
  class Distinct(IR):
@@ -2057,6 +2675,7 @@ class Distinct(IR):
2057
2675
  }
2058
2676
 
2059
2677
  @classmethod
2678
+ @log_do_evaluate
2060
2679
  @nvtx_annotate_cudf_polars(message="Distinct")
2061
2680
  def do_evaluate(
2062
2681
  cls,
@@ -2065,6 +2684,8 @@ class Distinct(IR):
2065
2684
  zlice: Zlice | None,
2066
2685
  stable: bool, # noqa: FBT001
2067
2686
  df: DataFrame,
2687
+ *,
2688
+ context: IRExecutionContext,
2068
2689
  ) -> DataFrame:
2069
2690
  """Evaluate and return a dataframe."""
2070
2691
  if subset is None:
@@ -2079,6 +2700,7 @@ class Distinct(IR):
2079
2700
  indices,
2080
2701
  keep,
2081
2702
  plc.types.NullEquality.EQUAL,
2703
+ stream=df.stream,
2082
2704
  )
2083
2705
  else:
2084
2706
  distinct = (
@@ -2092,13 +2714,15 @@ class Distinct(IR):
2092
2714
  keep,
2093
2715
  plc.types.NullEquality.EQUAL,
2094
2716
  plc.types.NanEquality.ALL_EQUAL,
2717
+ df.stream,
2095
2718
  )
2096
2719
  # TODO: Is this sortedness setting correct
2097
2720
  result = DataFrame(
2098
2721
  [
2099
2722
  Column(new, name=old.name, dtype=old.dtype).sorted_like(old)
2100
2723
  for new, old in zip(table.columns(), df.columns, strict=True)
2101
- ]
2724
+ ],
2725
+ stream=df.stream,
2102
2726
  )
2103
2727
  if keys_sorted or stable:
2104
2728
  result = result.sorted_like(df)
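Distinct drops duplicate rows judged on the subset columns, with the keep policy deciding which duplicate survives; the stable variant above additionally preserves the input row order. A CPU-side sketch assuming a keep-first policy (both the data and the policy choice are illustrative):

    rows = [("a", 1), ("b", 2), ("a", 3)]  # duplicates decided by the first field
    seen = set()
    distinct_rows = []
    for key, value in rows:
        if key not in seen:  # keep the first occurrence of each key
            seen.add(key)
            distinct_rows.append((key, value))
    assert distinct_rows == [("a", 1), ("b", 2)]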
@@ -2147,6 +2771,7 @@ class Sort(IR):
2147
2771
  self.children = (df,)
2148
2772
 
2149
2773
  @classmethod
2774
+ @log_do_evaluate
2150
2775
  @nvtx_annotate_cudf_polars(message="Sort")
2151
2776
  def do_evaluate(
2152
2777
  cls,
@@ -2156,17 +2781,24 @@ class Sort(IR):
2156
2781
  stable: bool, # noqa: FBT001
2157
2782
  zlice: Zlice | None,
2158
2783
  df: DataFrame,
2784
+ *,
2785
+ context: IRExecutionContext,
2159
2786
  ) -> DataFrame:
2160
2787
  """Evaluate and return a dataframe."""
2161
- sort_keys = broadcast(*(k.evaluate(df) for k in by), target_length=df.num_rows)
2788
+ sort_keys = broadcast(
2789
+ *(k.evaluate(df) for k in by), target_length=df.num_rows, stream=df.stream
2790
+ )
2162
2791
  do_sort = plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
2163
2792
  table = do_sort(
2164
2793
  df.table,
2165
2794
  plc.Table([k.obj for k in sort_keys]),
2166
2795
  list(order),
2167
2796
  list(null_order),
2797
+ stream=df.stream,
2798
+ )
2799
+ result = DataFrame.from_table(
2800
+ table, df.column_names, df.dtypes, stream=df.stream
2168
2801
  )
2169
- result = DataFrame.from_table(table, df.column_names, df.dtypes)
2170
2802
  first_key = sort_keys[0]
2171
2803
  name = by[0].name
2172
2804
  first_key_in_result = (
@@ -2197,8 +2829,11 @@ class Slice(IR):
2197
2829
  self.children = (df,)
2198
2830
 
2199
2831
  @classmethod
2832
+ @log_do_evaluate
2200
2833
  @nvtx_annotate_cudf_polars(message="Slice")
2201
- def do_evaluate(cls, offset: int, length: int, df: DataFrame) -> DataFrame:
2834
+ def do_evaluate(
2835
+ cls, offset: int, length: int, df: DataFrame, *, context: IRExecutionContext
2836
+ ) -> DataFrame:
2202
2837
  """Evaluate and return a dataframe."""
2203
2838
  return df.slice((offset, length))
2204
2839
 
@@ -2218,10 +2853,15 @@ class Filter(IR):
2218
2853
  self.children = (df,)
2219
2854
 
2220
2855
  @classmethod
2856
+ @log_do_evaluate
2221
2857
  @nvtx_annotate_cudf_polars(message="Filter")
2222
- def do_evaluate(cls, mask_expr: expr.NamedExpr, df: DataFrame) -> DataFrame:
2858
+ def do_evaluate(
2859
+ cls, mask_expr: expr.NamedExpr, df: DataFrame, *, context: IRExecutionContext
2860
+ ) -> DataFrame:
2223
2861
  """Evaluate and return a dataframe."""
2224
- (mask,) = broadcast(mask_expr.evaluate(df), target_length=df.num_rows)
2862
+ (mask,) = broadcast(
2863
+ mask_expr.evaluate(df), target_length=df.num_rows, stream=df.stream
2864
+ )
2225
2865
  return df.filter(mask)
2226
2866
 
2227
2867
 
@@ -2237,14 +2877,19 @@ class Projection(IR):
2237
2877
  self.children = (df,)
2238
2878
 
2239
2879
  @classmethod
2880
+ @log_do_evaluate
2240
2881
  @nvtx_annotate_cudf_polars(message="Projection")
2241
- def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame:
2882
+ def do_evaluate(
2883
+ cls, schema: Schema, df: DataFrame, *, context: IRExecutionContext
2884
+ ) -> DataFrame:
2242
2885
  """Evaluate and return a dataframe."""
2243
2886
  # This can reorder things.
2244
2887
  columns = broadcast(
2245
- *(df.column_map[name] for name in schema), target_length=df.num_rows
2888
+ *(df.column_map[name] for name in schema),
2889
+ target_length=df.num_rows,
2890
+ stream=df.stream,
2246
2891
  )
2247
- return DataFrame(columns)
2892
+ return DataFrame(columns, stream=df.stream)
2248
2893
 
2249
2894
 
2250
2895
  class MergeSorted(IR):
@@ -2270,23 +2915,31 @@ class MergeSorted(IR):
2270
2915
  self._non_child_args = (key,)
2271
2916
 
2272
2917
  @classmethod
2918
+ @log_do_evaluate
2273
2919
  @nvtx_annotate_cudf_polars(message="MergeSorted")
2274
- def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
2920
+ def do_evaluate(
2921
+ cls, key: str, *dfs: DataFrame, context: IRExecutionContext
2922
+ ) -> DataFrame:
2275
2923
  """Evaluate and return a dataframe."""
2276
- left, right = dfs
2277
- right = right.discard_columns(right.column_names_set - left.column_names_set)
2278
- on_col_left = left.select_columns({key})[0]
2279
- on_col_right = right.select_columns({key})[0]
2280
- return DataFrame.from_table(
2281
- plc.merge.merge(
2282
- [right.table, left.table],
2283
- [left.column_names.index(key), right.column_names.index(key)],
2284
- [on_col_left.order, on_col_right.order],
2285
- [on_col_left.null_order, on_col_right.null_order],
2286
- ),
2287
- left.column_names,
2288
- left.dtypes,
2289
- )
2924
+ with context.stream_ordered_after(*dfs) as stream:
2925
+ left, right = dfs
2926
+ right = right.discard_columns(
2927
+ right.column_names_set - left.column_names_set
2928
+ )
2929
+ on_col_left = left.select_columns({key})[0]
2930
+ on_col_right = right.select_columns({key})[0]
2931
+ return DataFrame.from_table(
2932
+ plc.merge.merge(
2933
+ [right.table, left.table],
2934
+ [left.column_names.index(key), right.column_names.index(key)],
2935
+ [on_col_left.order, on_col_right.order],
2936
+ [on_col_left.null_order, on_col_right.null_order],
2937
+ stream=stream,
2938
+ ),
2939
+ left.column_names,
2940
+ left.dtypes,
2941
+ stream=stream,
2942
+ )
2290
2943
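MergeSorted interleaves two inputs that are already sorted on the key column, honouring the key's recorded order and null order; plc.merge.merge does the interleaving on the device. A CPU-side sketch of the semantics using heapq.merge (illustrative data):

    import heapq

    left = [(1, "a"), (3, "b")]   # (key, value) rows, already sorted on key
    right = [(2, "c"), (4, "d")]
    merged = list(heapq.merge(left, right, key=lambda row: row[0]))
    assert [key for key, _ in merged] == [1, 2, 3, 4]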
 
2291
2944
 
2292
2945
  class MapFunction(IR):
@@ -2347,7 +3000,7 @@ class MapFunction(IR):
2347
3000
  index = frozenset(indices)
2348
3001
  pivotees = [name for name in df.schema if name not in index]
2349
3002
  if not all(
2350
- dtypes.can_cast(df.schema[p].plc, self.schema[value_name].plc)
3003
+ dtypes.can_cast(df.schema[p].plc_type, self.schema[value_name].plc_type)
2351
3004
  for p in pivotees
2352
3005
  ):
2353
3006
  raise NotImplementedError(
@@ -2362,6 +3015,8 @@ class MapFunction(IR):
2362
3015
  )
2363
3016
  elif self.name == "row_index":
2364
3017
  col_name, offset = options
3018
+ if col_name in df.schema:
3019
+ raise NotImplementedError("Duplicate row index name")
2365
3020
  self.options = (col_name, offset)
2366
3021
  elif self.name == "fast_count":
2367
3022
  # TODO: Remove this once all scan types support projections
@@ -2390,9 +3045,16 @@ class MapFunction(IR):
2390
3045
  )
2391
3046
 
2392
3047
  @classmethod
3048
+ @log_do_evaluate
2393
3049
  @nvtx_annotate_cudf_polars(message="MapFunction")
2394
3050
  def do_evaluate(
2395
- cls, schema: Schema, name: str, options: Any, df: DataFrame
3051
+ cls,
3052
+ schema: Schema,
3053
+ name: str,
3054
+ options: Any,
3055
+ df: DataFrame,
3056
+ *,
3057
+ context: IRExecutionContext,
2396
3058
  ) -> DataFrame:
2397
3059
  """Evaluate and return a dataframe."""
2398
3060
  if name == "rechunk":
@@ -2409,7 +3071,10 @@ class MapFunction(IR):
2409
3071
  index = df.column_names.index(to_explode)
2410
3072
  subset = df.column_names_set - {to_explode}
2411
3073
  return DataFrame.from_table(
2412
- plc.lists.explode_outer(df.table, index), df.column_names, df.dtypes
3074
+ plc.lists.explode_outer(df.table, index, stream=df.stream),
3075
+ df.column_names,
3076
+ df.dtypes,
3077
+ stream=df.stream,
2413
3078
  ).sorted_like(df, subset=subset)
2414
3079
  elif name == "unpivot":
2415
3080
  (
@@ -2423,7 +3088,7 @@ class MapFunction(IR):
2423
3088
  index_columns = [
2424
3089
  Column(tiled, name=name, dtype=old.dtype)
2425
3090
  for tiled, name, old in zip(
2426
- plc.reshape.tile(selected.table, npiv).columns(),
3091
+ plc.reshape.tile(selected.table, npiv, stream=df.stream).columns(),
2427
3092
  indices,
2428
3093
  selected.columns,
2429
3094
  strict=True,
@@ -2434,18 +3099,23 @@ class MapFunction(IR):
2434
3099
  [
2435
3100
  plc.Column.from_arrow(
2436
3101
  pl.Series(
2437
- values=pivotees, dtype=schema[variable_name].polars
2438
- )
3102
+ values=pivotees, dtype=schema[variable_name].polars_type
3103
+ ),
3104
+ stream=df.stream,
2439
3105
  )
2440
3106
  ]
2441
3107
  ),
2442
3108
  df.num_rows,
3109
+ stream=df.stream,
2443
3110
  ).columns()
2444
3111
  value_column = plc.concatenate.concatenate(
2445
3112
  [
2446
- df.column_map[pivotee].astype(schema[value_name]).obj
3113
+ df.column_map[pivotee]
3114
+ .astype(schema[value_name], stream=df.stream)
3115
+ .obj
2447
3116
  for pivotee in pivotees
2448
- ]
3117
+ ],
3118
+ stream=df.stream,
2449
3119
  )
2450
3120
  return DataFrame(
2451
3121
  [
@@ -2454,22 +3124,23 @@ class MapFunction(IR):
2454
3124
  variable_column, name=variable_name, dtype=schema[variable_name]
2455
3125
  ),
2456
3126
  Column(value_column, name=value_name, dtype=schema[value_name]),
2457
- ]
3127
+ ],
3128
+ stream=df.stream,
2458
3129
  )
2459
3130
  elif name == "row_index":
2460
3131
  col_name, offset = options
2461
3132
  dtype = schema[col_name]
2462
- step = plc.Scalar.from_py(1, dtype.plc)
2463
- init = plc.Scalar.from_py(offset, dtype.plc)
3133
+ step = plc.Scalar.from_py(1, dtype.plc_type, stream=df.stream)
3134
+ init = plc.Scalar.from_py(offset, dtype.plc_type, stream=df.stream)
2464
3135
  index_col = Column(
2465
- plc.filling.sequence(df.num_rows, init, step),
3136
+ plc.filling.sequence(df.num_rows, init, step, stream=df.stream),
2466
3137
  is_sorted=plc.types.Sorted.YES,
2467
3138
  order=plc.types.Order.ASCENDING,
2468
3139
  null_order=plc.types.NullOrder.AFTER,
2469
3140
  name=col_name,
2470
3141
  dtype=dtype,
2471
3142
  )
2472
- return DataFrame([index_col, *df.columns])
3143
+ return DataFrame([index_col, *df.columns], stream=df.stream)
2473
3144
  else:
2474
3145
  raise AssertionError("Should never be reached") # pragma: no cover
2475
3146
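The unpivot branch above assembles its output from three pieces: the index columns tiled once per pivoted column, a variable column that repeats each pivotee name num_rows times, and a value column that concatenates the pivotee values. A CPU-side sketch of that construction with illustrative data:

    index = {"id": [1, 2]}
    pivotees = {"a": [10, 20], "b": [30, 40]}
    num_rows, npiv = 2, len(pivotees)

    tiled_id = index["id"] * npiv                                    # plc.reshape.tile
    variable = [name for name in pivotees for _ in range(num_rows)]  # repeated names
    value = [v for col in pivotees.values() for v in col]            # plc.concatenate
    assert tiled_id == [1, 2, 1, 2]
    assert variable == ["a", "a", "b", "b"]
    assert value == [10, 20, 30, 40]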
 
@@ -2490,15 +3161,20 @@ class Union(IR):
2490
3161
  schema = self.children[0].schema
2491
3162
 
2492
3163
  @classmethod
3164
+ @log_do_evaluate
2493
3165
  @nvtx_annotate_cudf_polars(message="Union")
2494
- def do_evaluate(cls, zlice: Zlice | None, *dfs: DataFrame) -> DataFrame:
3166
+ def do_evaluate(
3167
+ cls, zlice: Zlice | None, *dfs: DataFrame, context: IRExecutionContext
3168
+ ) -> DataFrame:
2495
3169
  """Evaluate and return a dataframe."""
2496
- # TODO: only evaluate what we need if we have a slice?
2497
- return DataFrame.from_table(
2498
- plc.concatenate.concatenate([df.table for df in dfs]),
2499
- dfs[0].column_names,
2500
- dfs[0].dtypes,
2501
- ).slice(zlice)
3170
+ with context.stream_ordered_after(*dfs) as stream:
3171
+ # TODO: only evaluate what we need if we have a slice?
3172
+ return DataFrame.from_table(
3173
+ plc.concatenate.concatenate([df.table for df in dfs], stream=stream),
3174
+ dfs[0].column_names,
3175
+ dfs[0].dtypes,
3176
+ stream=stream,
3177
+ ).slice(zlice)
2502
3178
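Union concatenates its inputs vertically on a stream ordered after all of them and then applies the optional slice. A plain-Python sketch of the concatenate-then-slice behaviour (illustrative data):

    dfs = [[1, 2], [3], [4, 5]]
    zlice = (1, 3)  # (offset, length)
    concatenated = [value for df in dfs for value in df]
    offset, length = zlice
    assert concatenated[offset:offset + length] == [2, 3, 4]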
 
2503
3179
 
2504
3180
  class HConcat(IR):
@@ -2519,7 +3195,9 @@ class HConcat(IR):
2519
3195
  self.children = children
2520
3196
 
2521
3197
  @staticmethod
2522
- def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table:
3198
+ def _extend_with_nulls(
3199
+ table: plc.Table, *, nrows: int, stream: Stream
3200
+ ) -> plc.Table:
2523
3201
  """
2524
3202
  Extend a table with nulls.
2525
3203
 
@@ -2529,6 +3207,8 @@ class HConcat(IR):
2529
3207
  Table to extend
2530
3208
  nrows
2531
3209
  Number of additional rows
3210
+ stream
3211
+ CUDA stream used for device memory operations and kernel launches
2532
3212
 
2533
3213
  Returns
2534
3214
  -------
@@ -2539,45 +3219,61 @@ class HConcat(IR):
2539
3219
  table,
2540
3220
  plc.Table(
2541
3221
  [
2542
- plc.Column.all_null_like(column, nrows)
3222
+ plc.Column.all_null_like(column, nrows, stream=stream)
2543
3223
  for column in table.columns()
2544
3224
  ]
2545
3225
  ),
2546
- ]
3226
+ ],
3227
+ stream=stream,
2547
3228
  )
2548
3229
 
2549
3230
  @classmethod
3231
+ @log_do_evaluate
2550
3232
  @nvtx_annotate_cudf_polars(message="HConcat")
2551
3233
  def do_evaluate(
2552
3234
  cls,
2553
3235
  should_broadcast: bool, # noqa: FBT001
2554
3236
  *dfs: DataFrame,
3237
+ context: IRExecutionContext,
2555
3238
  ) -> DataFrame:
2556
3239
  """Evaluate and return a dataframe."""
2557
- # Special should_broadcast case.
2558
- # Used to recombine decomposed expressions
2559
- if should_broadcast:
2560
- return DataFrame(
2561
- broadcast(*itertools.chain.from_iterable(df.columns for df in dfs))
2562
- )
2563
-
2564
- max_rows = max(df.num_rows for df in dfs)
2565
- # Horizontal concatenation extends shorter tables with nulls
2566
- return DataFrame(
2567
- itertools.chain.from_iterable(
2568
- df.columns
2569
- for df in (
2570
- df
2571
- if df.num_rows == max_rows
2572
- else DataFrame.from_table(
2573
- cls._extend_with_nulls(df.table, nrows=max_rows - df.num_rows),
2574
- df.column_names,
2575
- df.dtypes,
2576
- )
2577
- for df in dfs
3240
+ with context.stream_ordered_after(*dfs) as stream:
3241
+ # Special should_broadcast case.
3242
+ # Used to recombine decomposed expressions
3243
+ if should_broadcast:
3244
+ result = DataFrame(
3245
+ broadcast(
3246
+ *itertools.chain.from_iterable(df.columns for df in dfs),
3247
+ stream=stream,
3248
+ ),
3249
+ stream=stream,
3250
+ )
3251
+ else:
3252
+ max_rows = max(df.num_rows for df in dfs)
3253
+ # Horizontal concatenation extends shorter tables with nulls
3254
+ result = DataFrame(
3255
+ itertools.chain.from_iterable(
3256
+ df.columns
3257
+ for df in (
3258
+ df
3259
+ if df.num_rows == max_rows
3260
+ else DataFrame.from_table(
3261
+ cls._extend_with_nulls(
3262
+ df.table,
3263
+ nrows=max_rows - df.num_rows,
3264
+ stream=stream,
3265
+ ),
3266
+ df.column_names,
3267
+ df.dtypes,
3268
+ stream=stream,
3269
+ )
3270
+ for df in dfs
3271
+ )
3272
+ ),
3273
+ stream=stream,
2578
3274
  )
2579
- )
2580
- )
3275
+
3276
+ return result
2581
3277
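For horizontal concatenation, shorter inputs are first extended with nulls up to the longest input's row count (the _extend_with_nulls helper above), so every column ends up with the same length. A CPU-side sketch of that padding, with None standing in for nulls (illustrative data):

    frames = {"x": [1, 2, 3], "y": [10]}
    max_rows = max(len(col) for col in frames.values())
    padded = {
        name: col + [None] * (max_rows - len(col))  # all_null_like + concatenate
        for name, col in frames.items()
    }
    assert padded == {"x": [1, 2, 3], "y": [10, None, None]}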
 
2582
3278
 
2583
3279
  class Empty(IR):
@@ -2592,16 +3288,23 @@ class Empty(IR):
2592
3288
  self.children = ()
2593
3289
 
2594
3290
  @classmethod
3291
+ @log_do_evaluate
2595
3292
  @nvtx_annotate_cudf_polars(message="Empty")
2596
- def do_evaluate(cls, schema: Schema) -> DataFrame: # pragma: no cover
3293
+ def do_evaluate(
3294
+ cls, schema: Schema, *, context: IRExecutionContext
3295
+ ) -> DataFrame: # pragma: no cover
2597
3296
  """Evaluate and return a dataframe."""
3297
+ stream = context.get_cuda_stream()
2598
3298
  return DataFrame(
2599
3299
  [
2600
3300
  Column(
2601
- plc.column_factories.make_empty_column(dtype.plc),
3301
+ plc.column_factories.make_empty_column(
3302
+ dtype.plc_type, stream=stream
3303
+ ),
2602
3304
  dtype=dtype,
2603
3305
  name=name,
2604
3306
  )
2605
3307
  for name, dtype in schema.items()
2606
- ]
3308
+ ],
3309
+ stream=stream,
2607
3310
  )