cudf-polars-cu13 25.10.0__py3-none-any.whl → 25.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +32 -8
  4. cudf_polars/containers/column.py +94 -59
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +235 -102
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +9 -3
  19. cudf_polars/dsl/expressions/unary.py +117 -58
  20. cudf_polars/dsl/ir.py +923 -290
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +294 -97
  24. cudf_polars/dsl/utils/aggregations.py +34 -26
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +45 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +791 -1
  31. cudf_polars/experimental/benchmarks/utils.py +515 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/explain.py +15 -2
  35. cudf_polars/experimental/expressions.py +22 -10
  36. cudf_polars/experimental/groupby.py +23 -4
  37. cudf_polars/experimental/io.py +93 -83
  38. cudf_polars/experimental/join.py +39 -22
  39. cudf_polars/experimental/parallel.py +60 -14
  40. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  41. cudf_polars/experimental/rapidsmpf/core.py +361 -0
  42. cudf_polars/experimental/rapidsmpf/dispatch.py +150 -0
  43. cudf_polars/experimental/rapidsmpf/io.py +604 -0
  44. cudf_polars/experimental/rapidsmpf/join.py +237 -0
  45. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  46. cudf_polars/experimental/rapidsmpf/nodes.py +494 -0
  47. cudf_polars/experimental/rapidsmpf/repartition.py +151 -0
  48. cudf_polars/experimental/rapidsmpf/shuffle.py +277 -0
  49. cudf_polars/experimental/rapidsmpf/union.py +96 -0
  50. cudf_polars/experimental/rapidsmpf/utils.py +162 -0
  51. cudf_polars/experimental/repartition.py +9 -2
  52. cudf_polars/experimental/select.py +177 -14
  53. cudf_polars/experimental/shuffle.py +28 -8
  54. cudf_polars/experimental/sort.py +92 -25
  55. cudf_polars/experimental/statistics.py +24 -5
  56. cudf_polars/experimental/utils.py +25 -7
  57. cudf_polars/testing/asserts.py +13 -8
  58. cudf_polars/testing/io.py +2 -1
  59. cudf_polars/testing/plugin.py +88 -15
  60. cudf_polars/typing/__init__.py +86 -32
  61. cudf_polars/utils/config.py +406 -58
  62. cudf_polars/utils/cuda_stream.py +70 -0
  63. cudf_polars/utils/versions.py +3 -2
  64. cudf_polars_cu13-25.12.0.dist-info/METADATA +182 -0
  65. cudf_polars_cu13-25.12.0.dist-info/RECORD +104 -0
  66. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  67. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  68. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-25.12.0.dist-info}/WHEEL +0 -0
  69. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-25.12.0.dist-info}/licenses/LICENSE +0 -0
  70. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-25.12.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/ir.py CHANGED
@@ -17,27 +17,40 @@ import itertools
  import json
  import random
  import time
+ from dataclasses import dataclass
  from functools import cache
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, ClassVar
+ from typing import TYPE_CHECKING, Any, ClassVar, overload
  
  from typing_extensions import assert_never
  
  import polars as pl
  
  import pylibcudf as plc
+ from pylibcudf import expressions as plc_expr
  
  import cudf_polars.dsl.expr as expr
  from cudf_polars.containers import Column, DataFrame, DataType
+ from cudf_polars.containers.dataframe import NamedColumn
  from cudf_polars.dsl.expressions import rolling, unary
  from cudf_polars.dsl.expressions.base import ExecutionContext
  from cudf_polars.dsl.nodebase import Node
  from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
- from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
+ from cudf_polars.dsl.tracing import log_do_evaluate, nvtx_annotate_cudf_polars
  from cudf_polars.dsl.utils.reshape import broadcast
- from cudf_polars.dsl.utils.windows import range_window_bounds
+ from cudf_polars.dsl.utils.windows import (
+     offsets_to_windows,
+     range_window_bounds,
+ )
  from cudf_polars.utils import dtypes
- from cudf_polars.utils.versions import POLARS_VERSION_LT_131
+ from cudf_polars.utils.config import CUDAStreamPolicy
+ from cudf_polars.utils.cuda_stream import (
+     get_cuda_stream,
+     get_joined_cuda_stream,
+     get_new_cuda_stream,
+     join_cuda_streams,
+ )
+ from cudf_polars.utils.versions import POLARS_VERSION_LT_131, POLARS_VERSION_LT_134
  
  if TYPE_CHECKING:
      from collections.abc import Callable, Hashable, Iterable, Sequence
@@ -45,14 +58,15 @@ if TYPE_CHECKING:
  
      from typing_extensions import Self
  
-     from polars.polars import _expr_nodes as pl_expr
+     from polars import polars  # type: ignore[attr-defined]
+ 
+     from rmm.pylibrmm.stream import Stream
  
      from cudf_polars.containers.dataframe import NamedColumn
      from cudf_polars.typing import CSECache, ClosedInterval, Schema, Slice as Zlice
-     from cudf_polars.utils.config import ParquetOptions
+     from cudf_polars.utils.config import ConfigOptions, ParquetOptions
      from cudf_polars.utils.timer import Timer
  
- 
  __all__ = [
      "IR",
      "Cache",
@@ -65,6 +79,7 @@ __all__ = [
      "GroupBy",
      "HConcat",
      "HStack",
+     "IRExecutionContext",
      "Join",
      "MapFunction",
      "MergeSorted",
@@ -81,6 +96,53 @@ __all__ = [
  ]
  
  
+ @dataclass(frozen=True)
+ class IRExecutionContext:
+     """
+     Runtime context for IR node execution.
+ 
+     This dataclass holds runtime information and configuration needed
+     during the evaluation of IR nodes.
+ 
+     Parameters
+     ----------
+     get_cuda_stream
+         A zero-argument callable that returns a CUDA stream.
+     """
+ 
+     get_cuda_stream: Callable[[], Stream]
+ 
+     @classmethod
+     def from_config_options(cls, config_options: ConfigOptions) -> IRExecutionContext:
+         """Create an IRExecutionContext from ConfigOptions."""
+         match config_options.cuda_stream_policy:
+             case CUDAStreamPolicy.DEFAULT:
+                 return cls(get_cuda_stream=get_cuda_stream)
+             case CUDAStreamPolicy.NEW:
+                 return cls(get_cuda_stream=get_new_cuda_stream)
+             case _:  # pragma: no cover
+                 raise ValueError(
+                     f"Invalid CUDA stream policy: {config_options.cuda_stream_policy}"
+                 )
+ 
+ 
+ _BINOPS = {
+     plc.binaryop.BinaryOperator.EQUAL,
+     plc.binaryop.BinaryOperator.NOT_EQUAL,
+     plc.binaryop.BinaryOperator.LESS,
+     plc.binaryop.BinaryOperator.LESS_EQUAL,
+     plc.binaryop.BinaryOperator.GREATER,
+     plc.binaryop.BinaryOperator.GREATER_EQUAL,
+     # TODO: Handle other binary operations as needed
+ }
+ 
+ 
+ _DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
+ 
+ 
+ _FLOAT_TYPES = {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}
+ 
+ 
  class IR(Node["IR"]):
      """Abstract plan node, representing an unevaluated dataframe."""
  
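The new IRExecutionContext above is the mechanism by which 25.12.0 threads an explicit CUDA stream through plan evaluation. A minimal sketch of how it is constructed and consumed, assuming only the names introduced in this diff (the exact semantics of the two stream factories live in the new cudf_polars/utils/cuda_stream.py module and are not shown here):

    # Sketch only; mirrors the pattern in this diff rather than copying library code.
    from cudf_polars.dsl.ir import IRExecutionContext
    from cudf_polars.utils.cuda_stream import get_cuda_stream, get_new_cuda_stream

    # CUDAStreamPolicy.DEFAULT wires in get_cuda_stream (presumably a shared default
    # stream), while CUDAStreamPolicy.NEW wires in get_new_cuda_stream (presumably a
    # fresh stream per call), as the match statement above shows.
    ctx = IRExecutionContext(get_cuda_stream=get_cuda_stream)

    # Inside a node's do_evaluate, the context is consumed once per node:
    stream = ctx.get_cuda_stream()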
@@ -134,7 +196,9 @@ class IR(Node["IR"]):
          translation phase should fail earlier.
          """
  
-     def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
+     def evaluate(
+         self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
+     ) -> DataFrame:
          """
          Evaluate the node (recursively) and return a dataframe.
  
@@ -146,6 +210,8 @@
          timer
              If not None, a Timer object to record timings for the
              evaluation of the node.
+         context
+             The execution context for the node.
  
          Notes
          -----
@@ -164,16 +230,19 @@
              If evaluation fails. Ideally this should not occur, since the
              translation phase should fail earlier.
          """
-         children = [child.evaluate(cache=cache, timer=timer) for child in self.children]
+         children = [
+             child.evaluate(cache=cache, timer=timer, context=context)
+             for child in self.children
+         ]
          if timer is not None:
              start = time.monotonic_ns()
-             result = self.do_evaluate(*self._non_child_args, *children)
+             result = self.do_evaluate(*self._non_child_args, *children, context=context)
              end = time.monotonic_ns()
              # TODO: Set better names on each class object.
              timer.store(start, end, type(self).__name__)
              return result
          else:
-             return self.do_evaluate(*self._non_child_args, *children)
+             return self.do_evaluate(*self._non_child_args, *children, context=context)
  
  
  class ErrorNode(IR):
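Correspondingly, IR.evaluate now requires the context keyword. A hedged driver sketch, under the assumptions that ir is an already-translated IR node and that the CSECache argument accepts a plain dict:

    # Sketch: evaluating a translated plan under the new 25.12.0 signature.
    from cudf_polars.dsl.ir import IR, IRExecutionContext
    from cudf_polars.utils.cuda_stream import get_cuda_stream

    def evaluate_plan(ir: IR):
        context = IRExecutionContext(get_cuda_stream=get_cuda_stream)
        # timer=None skips per-node timing; pass a Timer to record per-node durations.
        return ir.evaluate(cache={}, timer=None, context=context)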
@@ -212,29 +281,93 @@ class PythonScan(IR):
212
281
  raise NotImplementedError("PythonScan not implemented")
213
282
 
214
283
 
284
+ _DECIMAL_IDS = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
285
+
286
+ _COMPARISON_BINOPS = {
287
+ plc.binaryop.BinaryOperator.EQUAL,
288
+ plc.binaryop.BinaryOperator.NOT_EQUAL,
289
+ plc.binaryop.BinaryOperator.LESS,
290
+ plc.binaryop.BinaryOperator.LESS_EQUAL,
291
+ plc.binaryop.BinaryOperator.GREATER,
292
+ plc.binaryop.BinaryOperator.GREATER_EQUAL,
293
+ }
294
+
295
+
296
+ def _parquet_physical_types(
297
+ schema: Schema, paths: list[str], columns: list[str] | None, stream: Stream
298
+ ) -> dict[str, plc.DataType]:
299
+ # TODO: Read the physical types as cudf::data_type's using
300
+ # read_parquet_metadata or another parquet API
301
+ options = plc.io.parquet.ParquetReaderOptions.builder(
302
+ plc.io.SourceInfo(paths)
303
+ ).build()
304
+ if columns is not None:
305
+ options.set_columns(columns)
306
+ options.set_num_rows(0)
307
+ df = plc.io.parquet.read_parquet(options, stream=stream)
308
+ return dict(zip(schema.keys(), [c.type() for c in df.tbl.columns()], strict=True))
309
+
310
+
311
+ def _cast_literal_to_decimal(
312
+ side: expr.Expr, lit: expr.Literal, phys_type_map: dict[str, plc.DataType]
313
+ ) -> expr.Expr:
314
+ if isinstance(side, expr.Cast):
315
+ col = side.children[0]
316
+ assert isinstance(col, expr.Col)
317
+ name = col.name
318
+ else:
319
+ assert isinstance(side, expr.Col)
320
+ name = side.name
321
+ if (type_ := phys_type_map[name]).id() in _DECIMAL_IDS:
322
+ scale = abs(type_.scale())
323
+ return expr.Cast(side.dtype, expr.Cast(DataType(pl.Decimal(38, scale)), lit))
324
+ return lit
325
+
326
+
327
+ def _cast_literals_to_physical_types(
328
+ node: expr.Expr, phys_type_map: dict[str, plc.DataType]
329
+ ) -> expr.Expr:
330
+ if isinstance(node, expr.BinOp):
331
+ left, right = node.children
332
+ left = _cast_literals_to_physical_types(left, phys_type_map)
333
+ right = _cast_literals_to_physical_types(right, phys_type_map)
334
+ if node.op in _COMPARISON_BINOPS:
335
+ if isinstance(left, (expr.Col, expr.Cast)) and isinstance(
336
+ right, expr.Literal
337
+ ):
338
+ right = _cast_literal_to_decimal(left, right, phys_type_map)
339
+ elif isinstance(right, (expr.Col, expr.Cast)) and isinstance(
340
+ left, expr.Literal
341
+ ):
342
+ left = _cast_literal_to_decimal(right, left, phys_type_map)
343
+
344
+ return node.reconstruct([left, right])
345
+ return node
346
+
347
+
215
348
  def _align_parquet_schema(df: DataFrame, schema: Schema) -> DataFrame:
216
349
  # TODO: Alternatively set the schema of the parquet reader to decimal128
217
- plc_decimals_ids = {
218
- plc.TypeId.DECIMAL32,
219
- plc.TypeId.DECIMAL64,
220
- plc.TypeId.DECIMAL128,
221
- }
222
350
  cast_list = []
223
351
 
224
352
  for name, col in df.column_map.items():
225
353
  src = col.obj.type()
226
- dst = schema[name].plc
354
+ dst = schema[name].plc_type
355
+
227
356
  if (
228
- src.id() in plc_decimals_ids
229
- and dst.id() in plc_decimals_ids
230
- and ((src.id() != dst.id()) or (src.scale != dst.scale))
357
+ plc.traits.is_fixed_point(src)
358
+ and plc.traits.is_fixed_point(dst)
359
+ and ((src.id() != dst.id()) or (src.scale() != dst.scale()))
231
360
  ):
232
361
  cast_list.append(
233
- Column(plc.unary.cast(col.obj, dst), name=name, dtype=schema[name])
362
+ Column(
363
+ plc.unary.cast(col.obj, dst, stream=df.stream),
364
+ name=name,
365
+ dtype=schema[name],
366
+ )
234
367
  )
235
368
 
236
369
  if cast_list:
237
- df = df.with_columns(cast_list)
370
+ df = df.with_columns(cast_list, stream=df.stream)
238
371
 
239
372
  return df
240
373
 
@@ -460,13 +593,24 @@ class Scan(IR):
460
593
  Each path is repeated according to the number of rows read from it.
461
594
  """
462
595
  (filepaths,) = plc.filling.repeat(
463
- plc.Table([plc.Column.from_arrow(pl.Series(values=map(str, paths)))]),
596
+ plc.Table(
597
+ [
598
+ plc.Column.from_arrow(
599
+ pl.Series(values=map(str, paths)),
600
+ stream=df.stream,
601
+ )
602
+ ]
603
+ ),
464
604
  plc.Column.from_arrow(
465
- pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32())
605
+ pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32()),
606
+ stream=df.stream,
466
607
  ),
608
+ stream=df.stream,
467
609
  ).columns()
468
610
  dtype = DataType(pl.String())
469
- return df.with_columns([Column(filepaths, name=name, dtype=dtype)])
611
+ return df.with_columns(
612
+ [Column(filepaths, name=name, dtype=dtype)], stream=df.stream
613
+ )
470
614
 
471
615
  def fast_count(self) -> int: # pragma: no cover
472
616
  """Get the number of rows in a Parquet Scan."""
@@ -479,6 +623,7 @@ class Scan(IR):
479
623
  return max(total_rows, 0)
480
624
 
481
625
  @classmethod
626
+ @log_do_evaluate
482
627
  @nvtx_annotate_cudf_polars(message="Scan")
483
628
  def do_evaluate(
484
629
  cls,
@@ -493,8 +638,11 @@ class Scan(IR):
493
638
  include_file_paths: str | None,
494
639
  predicate: expr.NamedExpr | None,
495
640
  parquet_options: ParquetOptions,
641
+ *,
642
+ context: IRExecutionContext,
496
643
  ) -> DataFrame:
497
644
  """Evaluate and return a dataframe."""
645
+ stream = context.get_cuda_stream()
498
646
  if typ == "csv":
499
647
 
500
648
  def read_csv_header(
@@ -551,6 +699,7 @@ class Scan(IR):
551
699
  plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo([path]))
552
700
  .nrows(n_rows)
553
701
  .skiprows(skiprows + skip_rows)
702
+ .skip_blank_lines(skip_blank_lines=False)
554
703
  .lineterminator(str(eol))
555
704
  .quotechar(str(quote))
556
705
  .decimal(decimal)
@@ -567,13 +716,15 @@ class Scan(IR):
567
716
  column_names = read_csv_header(path, str(sep))
568
717
  options.set_names(column_names)
569
718
  options.set_header(header)
570
- options.set_dtypes({name: dtype.plc for name, dtype in schema.items()})
719
+ options.set_dtypes(
720
+ {name: dtype.plc_type for name, dtype in schema.items()}
721
+ )
571
722
  if usecols is not None:
572
723
  options.set_use_cols_names([str(name) for name in usecols])
573
724
  options.set_na_values(null_values)
574
725
  if comment is not None:
575
726
  options.set_comment(comment)
576
- tbl_w_meta = plc.io.csv.read_csv(options)
727
+ tbl_w_meta = plc.io.csv.read_csv(options, stream=stream)
577
728
  pieces.append(tbl_w_meta)
578
729
  if include_file_paths is not None:
579
730
  seen_paths.append(p)
@@ -589,9 +740,10 @@ class Scan(IR):
589
740
  strict=True,
590
741
  )
591
742
  df = DataFrame.from_table(
592
- plc.concatenate.concatenate(list(tables)),
743
+ plc.concatenate.concatenate(list(tables), stream=stream),
593
744
  colnames,
594
745
  [schema[colname] for colname in colnames],
746
+ stream=stream,
595
747
  )
596
748
  if include_file_paths is not None:
597
749
  df = Scan.add_file_paths(
@@ -604,42 +756,50 @@ class Scan(IR):
604
756
  filters = None
605
757
  if predicate is not None and row_index is None:
606
758
  # Can't apply filters during read if we have a row index.
607
- filters = to_parquet_filter(predicate.value)
608
- options = plc.io.parquet.ParquetReaderOptions.builder(
759
+ filters = to_parquet_filter(
760
+ _cast_literals_to_physical_types(
761
+ predicate.value,
762
+ _parquet_physical_types(
763
+ schema, paths, with_columns or list(schema.keys()), stream
764
+ ),
765
+ ),
766
+ stream=stream,
767
+ )
768
+ parquet_reader_options = plc.io.parquet.ParquetReaderOptions.builder(
609
769
  plc.io.SourceInfo(paths)
610
770
  ).build()
611
771
  if with_columns is not None:
612
- options.set_columns(with_columns)
772
+ parquet_reader_options.set_columns(with_columns)
613
773
  if filters is not None:
614
- options.set_filter(filters)
774
+ parquet_reader_options.set_filter(filters)
615
775
  if n_rows != -1:
616
- options.set_num_rows(n_rows)
776
+ parquet_reader_options.set_num_rows(n_rows)
617
777
  if skip_rows != 0:
618
- options.set_skip_rows(skip_rows)
778
+ parquet_reader_options.set_skip_rows(skip_rows)
619
779
  if parquet_options.chunked:
620
780
  reader = plc.io.parquet.ChunkedParquetReader(
621
- options,
781
+ parquet_reader_options,
622
782
  chunk_read_limit=parquet_options.chunk_read_limit,
623
783
  pass_read_limit=parquet_options.pass_read_limit,
784
+ stream=stream,
624
785
  )
625
786
  chunk = reader.read_chunk()
626
- tbl = chunk.tbl
627
787
  # TODO: Nested column names
628
788
  names = chunk.column_names(include_children=False)
629
- concatenated_columns = tbl.columns()
789
+ concatenated_columns = chunk.tbl.columns()
630
790
  while reader.has_next():
631
- chunk = reader.read_chunk()
632
- tbl = chunk.tbl
633
- for i in range(tbl.num_columns()):
791
+ columns = reader.read_chunk().tbl.columns()
792
+ # Discard columns while concatenating to reduce memory footprint.
793
+ # Reverse order to avoid O(n^2) list popping cost.
794
+ for i in range(len(concatenated_columns) - 1, -1, -1):
634
795
  concatenated_columns[i] = plc.concatenate.concatenate(
635
- [concatenated_columns[i], tbl._columns[i]]
796
+ [concatenated_columns[i], columns.pop()], stream=stream
636
797
  )
637
- # Drop residual columns to save memory
638
- tbl._columns[i] = None
639
798
  df = DataFrame.from_table(
640
799
  plc.Table(concatenated_columns),
641
800
  names=names,
642
801
  dtypes=[schema[name] for name in names],
802
+ stream=stream,
643
803
  )
644
804
  df = _align_parquet_schema(df, schema)
645
805
  if include_file_paths is not None:
@@ -647,13 +807,16 @@ class Scan(IR):
647
807
  include_file_paths, paths, chunk.num_rows_per_source, df
648
808
  )
649
809
  else:
650
- tbl_w_meta = plc.io.parquet.read_parquet(options)
810
+ tbl_w_meta = plc.io.parquet.read_parquet(
811
+ parquet_reader_options, stream=stream
812
+ )
651
813
  # TODO: consider nested column names?
652
814
  col_names = tbl_w_meta.column_names(include_children=False)
653
815
  df = DataFrame.from_table(
654
816
  tbl_w_meta.tbl,
655
817
  col_names,
656
818
  [schema[name] for name in col_names],
819
+ stream=stream,
657
820
  )
658
821
  df = _align_parquet_schema(df, schema)
659
822
  if include_file_paths is not None:
@@ -665,16 +828,16 @@ class Scan(IR):
665
828
  return df
666
829
  elif typ == "ndjson":
667
830
  json_schema: list[plc.io.json.NameAndType] = [
668
- (name, typ.plc, []) for name, typ in schema.items()
831
+ (name, typ.plc_type, []) for name, typ in schema.items()
669
832
  ]
670
- plc_tbl_w_meta = plc.io.json.read_json(
671
- plc.io.json._setup_json_reader_options(
672
- plc.io.SourceInfo(paths),
673
- lines=True,
674
- dtypes=json_schema,
675
- prune_columns=True,
676
- )
833
+ json_reader_options = (
834
+ plc.io.json.JsonReaderOptions.builder(plc.io.SourceInfo(paths))
835
+ .lines(val=True)
836
+ .dtypes(json_schema)
837
+ .prune_columns(val=True)
838
+ .build()
677
839
  )
840
+ plc_tbl_w_meta = plc.io.json.read_json(json_reader_options, stream=stream)
678
841
  # TODO: I don't think cudf-polars supports nested types in general right now
679
842
  # (but when it does, we should pass child column names from nested columns in)
680
843
  col_names = plc_tbl_w_meta.column_names(include_children=False)
@@ -682,6 +845,7 @@ class Scan(IR):
682
845
  plc_tbl_w_meta.tbl,
683
846
  col_names,
684
847
  [schema[name] for name in col_names],
848
+ stream=stream,
685
849
  )
686
850
  col_order = list(schema.keys())
687
851
  if row_index is not None:
@@ -695,26 +859,28 @@ class Scan(IR):
695
859
  name, offset = row_index
696
860
  offset += skip_rows
697
861
  dtype = schema[name]
698
- step = plc.Scalar.from_py(1, dtype.plc)
699
- init = plc.Scalar.from_py(offset, dtype.plc)
862
+ step = plc.Scalar.from_py(1, dtype.plc_type, stream=stream)
863
+ init = plc.Scalar.from_py(offset, dtype.plc_type, stream=stream)
700
864
  index_col = Column(
701
- plc.filling.sequence(df.num_rows, init, step),
865
+ plc.filling.sequence(df.num_rows, init, step, stream=stream),
702
866
  is_sorted=plc.types.Sorted.YES,
703
867
  order=plc.types.Order.ASCENDING,
704
868
  null_order=plc.types.NullOrder.AFTER,
705
869
  name=name,
706
870
  dtype=dtype,
707
871
  )
708
- df = DataFrame([index_col, *df.columns])
872
+ df = DataFrame([index_col, *df.columns], stream=df.stream)
709
873
  if next(iter(schema)) != name:
710
874
  df = df.select(schema)
711
875
  assert all(
712
- c.obj.type() == schema[name].plc for name, c in df.column_map.items()
876
+ c.obj.type() == schema[name].plc_type for name, c in df.column_map.items()
713
877
  )
714
878
  if predicate is None:
715
879
  return df
716
880
  else:
717
- (mask,) = broadcast(predicate.evaluate(df), target_length=df.num_rows)
881
+ (mask,) = broadcast(
882
+ predicate.evaluate(df), target_length=df.num_rows, stream=df.stream
883
+ )
718
884
  return df.filter(mask)
719
885
 
720
886
 
@@ -775,7 +941,8 @@ class Sink(IR):
775
941
  child_schema = df.schema.values()
776
942
  if kind == "Csv":
777
943
  if not all(
778
- plc.io.csv.is_supported_write_csv(dtype.plc) for dtype in child_schema
944
+ plc.io.csv.is_supported_write_csv(dtype.plc_type)
945
+ for dtype in child_schema
779
946
  ):
780
947
  # Nested types are unsupported in polars and libcudf
781
948
  raise NotImplementedError(
@@ -826,7 +993,8 @@ class Sink(IR):
826
993
  kind == "Json"
827
994
  ): # pragma: no cover; options are validated on the polars side
828
995
  if not all(
829
- plc.io.json.is_supported_write_json(dtype.plc) for dtype in child_schema
996
+ plc.io.json.is_supported_write_json(dtype.plc_type)
997
+ for dtype in child_schema
830
998
  ):
831
999
  # Nested types are unsupported in polars and libcudf
832
1000
  raise NotImplementedError(
@@ -863,7 +1031,7 @@ class Sink(IR):
863
1031
  ) -> None:
864
1032
  """Write CSV data to a sink."""
865
1033
  serialize = options["serialize_options"]
866
- options = (
1034
+ csv_writer_options = (
867
1035
  plc.io.csv.CsvWriterOptions.builder(target, df.table)
868
1036
  .include_header(options["include_header"])
869
1037
  .names(df.column_names if options["include_header"] else [])
@@ -872,7 +1040,7 @@ class Sink(IR):
872
1040
  .inter_column_delimiter(chr(serialize["separator"]))
873
1041
  .build()
874
1042
  )
875
- plc.io.csv.write_csv(options)
1043
+ plc.io.csv.write_csv(csv_writer_options, stream=df.stream)
876
1044
 
877
1045
  @classmethod
878
1046
  def _write_json(cls, target: plc.io.SinkInfo, df: DataFrame) -> None:
@@ -889,7 +1057,7 @@ class Sink(IR):
889
1057
  .utf8_escaped(val=False)
890
1058
  .build()
891
1059
  )
892
- plc.io.json.write_json(options)
1060
+ plc.io.json.write_json(options, stream=df.stream)
893
1061
 
894
1062
  @staticmethod
895
1063
  def _make_parquet_metadata(df: DataFrame) -> plc.io.types.TableInputMetadata:
@@ -899,6 +1067,20 @@ class Sink(IR):
899
1067
  metadata.column_metadata[i].set_name(name)
900
1068
  return metadata
901
1069
 
1070
+ @overload
1071
+ @staticmethod
1072
+ def _apply_parquet_writer_options(
1073
+ builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder,
1074
+ options: dict[str, Any],
1075
+ ) -> plc.io.parquet.ChunkedParquetWriterOptionsBuilder: ...
1076
+
1077
+ @overload
1078
+ @staticmethod
1079
+ def _apply_parquet_writer_options(
1080
+ builder: plc.io.parquet.ParquetWriterOptionsBuilder,
1081
+ options: dict[str, Any],
1082
+ ) -> plc.io.parquet.ParquetWriterOptionsBuilder: ...
1083
+
902
1084
  @staticmethod
903
1085
  def _apply_parquet_writer_options(
904
1086
  builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder
@@ -944,12 +1126,16 @@ class Sink(IR):
944
1126
  and parquet_options.n_output_chunks != 1
945
1127
  and df.table.num_rows() != 0
946
1128
  ):
947
- builder = plc.io.parquet.ChunkedParquetWriterOptions.builder(
1129
+ chunked_builder = plc.io.parquet.ChunkedParquetWriterOptions.builder(
948
1130
  target
949
1131
  ).metadata(metadata)
950
- builder = cls._apply_parquet_writer_options(builder, options)
951
- writer_options = builder.build()
952
- writer = plc.io.parquet.ChunkedParquetWriter.from_options(writer_options)
1132
+ chunked_builder = cls._apply_parquet_writer_options(
1133
+ chunked_builder, options
1134
+ )
1135
+ chunked_writer_options = chunked_builder.build()
1136
+ writer = plc.io.parquet.ChunkedParquetWriter.from_options(
1137
+ chunked_writer_options, stream=df.stream
1138
+ )
953
1139
 
954
1140
  # TODO: Can be based on a heuristic that estimates chunk size
955
1141
  # from the input table size and available GPU memory.
@@ -957,6 +1143,7 @@ class Sink(IR):
957
1143
  table_chunks = plc.copying.split(
958
1144
  df.table,
959
1145
  [i * df.table.num_rows() // num_chunks for i in range(1, num_chunks)],
1146
+ stream=df.stream,
960
1147
  )
961
1148
  for chunk in table_chunks:
962
1149
  writer.write(chunk)
@@ -968,9 +1155,10 @@ class Sink(IR):
968
1155
  ).metadata(metadata)
969
1156
  builder = cls._apply_parquet_writer_options(builder, options)
970
1157
  writer_options = builder.build()
971
- plc.io.parquet.write_parquet(writer_options)
1158
+ plc.io.parquet.write_parquet(writer_options, stream=df.stream)
972
1159
 
973
1160
  @classmethod
1161
+ @log_do_evaluate
974
1162
  @nvtx_annotate_cudf_polars(message="Sink")
975
1163
  def do_evaluate(
976
1164
  cls,
@@ -980,6 +1168,8 @@ class Sink(IR):
980
1168
  parquet_options: ParquetOptions,
981
1169
  options: dict[str, Any],
982
1170
  df: DataFrame,
1171
+ *,
1172
+ context: IRExecutionContext,
983
1173
  ) -> DataFrame:
984
1174
  """Write the dataframe to a file."""
985
1175
  target = plc.io.SinkInfo([path])
@@ -993,7 +1183,7 @@ class Sink(IR):
993
1183
  elif kind == "Json":
994
1184
  cls._write_json(target, df)
995
1185
 
996
- return DataFrame([])
1186
+ return DataFrame([], stream=df.stream)
997
1187
 
998
1188
 
999
1189
  class Cache(IR):
@@ -1030,16 +1220,24 @@ class Cache(IR):
1030
1220
  return False
1031
1221
 
1032
1222
  @classmethod
1223
+ @log_do_evaluate
1033
1224
  @nvtx_annotate_cudf_polars(message="Cache")
1034
1225
  def do_evaluate(
1035
- cls, key: int, refcount: int | None, df: DataFrame
1226
+ cls,
1227
+ key: int,
1228
+ refcount: int | None,
1229
+ df: DataFrame,
1230
+ *,
1231
+ context: IRExecutionContext,
1036
1232
  ) -> DataFrame: # pragma: no cover; basic evaluation never calls this
1037
1233
  """Evaluate and return a dataframe."""
1038
1234
  # Our value has already been computed for us, so let's just
1039
1235
  # return it.
1040
1236
  return df
1041
1237
 
1042
- def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
1238
+ def evaluate(
1239
+ self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
1240
+ ) -> DataFrame:
1043
1241
  """Evaluate and return a dataframe."""
1044
1242
  # We must override the recursion scheme because we don't want
1045
1243
  # to recurse if we're in the cache.
@@ -1047,7 +1245,7 @@ class Cache(IR):
1047
1245
  (result, hits) = cache[self.key]
1048
1246
  except KeyError:
1049
1247
  (value,) = self.children
1050
- result = value.evaluate(cache=cache, timer=timer)
1248
+ result = value.evaluate(cache=cache, timer=timer, context=context)
1051
1249
  cache[self.key] = (result, 0)
1052
1250
  return result
1053
1251
  else:
@@ -1110,19 +1308,22 @@ class DataFrameScan(IR):
1110
1308
  )
1111
1309
 
1112
1310
  @classmethod
1311
+ @log_do_evaluate
1113
1312
  @nvtx_annotate_cudf_polars(message="DataFrameScan")
1114
1313
  def do_evaluate(
1115
1314
  cls,
1116
1315
  schema: Schema,
1117
1316
  df: Any,
1118
1317
  projection: tuple[str, ...] | None,
1318
+ *,
1319
+ context: IRExecutionContext,
1119
1320
  ) -> DataFrame:
1120
1321
  """Evaluate and return a dataframe."""
1121
1322
  if projection is not None:
1122
1323
  df = df.select(projection)
1123
- df = DataFrame.from_polars(df)
1324
+ df = DataFrame.from_polars(df, stream=context.get_cuda_stream())
1124
1325
  assert all(
1125
- c.obj.type() == dtype.plc
1326
+ c.obj.type() == dtype.plc_type
1126
1327
  for c, dtype in zip(df.columns, schema.values(), strict=True)
1127
1328
  )
1128
1329
  return df
@@ -1169,21 +1370,26 @@ class Select(IR):
1169
1370
  return False
1170
1371
 
1171
1372
  @classmethod
1373
+ @log_do_evaluate
1172
1374
  @nvtx_annotate_cudf_polars(message="Select")
1173
1375
  def do_evaluate(
1174
1376
  cls,
1175
1377
  exprs: tuple[expr.NamedExpr, ...],
1176
1378
  should_broadcast: bool, # noqa: FBT001
1177
1379
  df: DataFrame,
1380
+ *,
1381
+ context: IRExecutionContext,
1178
1382
  ) -> DataFrame:
1179
1383
  """Evaluate and return a dataframe."""
1180
1384
  # Handle any broadcasting
1181
1385
  columns = [e.evaluate(df) for e in exprs]
1182
1386
  if should_broadcast:
1183
- columns = broadcast(*columns)
1184
- return DataFrame(columns)
1387
+ columns = broadcast(*columns, stream=df.stream)
1388
+ return DataFrame(columns, stream=df.stream)
1185
1389
 
1186
- def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
1390
+ def evaluate(
1391
+ self, *, cache: CSECache, timer: Timer | None, context: IRExecutionContext
1392
+ ) -> DataFrame:
1187
1393
  """
1188
1394
  Evaluate the Select node with special handling for fast count queries.
1189
1395
 
@@ -1195,6 +1401,8 @@ class Select(IR):
1195
1401
  timer
1196
1402
  If not None, a Timer object to record timings for the
1197
1403
  evaluation of the node.
1404
+ context
1405
+ The execution context for the node.
1198
1406
 
1199
1407
  Returns
1200
1408
  -------
@@ -1214,21 +1422,23 @@ class Select(IR):
1214
1422
  and Select._is_len_expr(self.exprs)
1215
1423
  and self.children[0].typ == "parquet"
1216
1424
  and self.children[0].predicate is None
1217
- ):
1218
- scan = self.children[0] # pragma: no cover
1219
- effective_rows = scan.fast_count() # pragma: no cover
1220
- dtype = DataType(pl.UInt32()) # pragma: no cover
1425
+ ): # pragma: no cover
1426
+ stream = context.get_cuda_stream()
1427
+ scan = self.children[0]
1428
+ effective_rows = scan.fast_count()
1429
+ dtype = DataType(pl.UInt32())
1221
1430
  col = Column(
1222
1431
  plc.Column.from_scalar(
1223
- plc.Scalar.from_py(effective_rows, dtype.plc),
1432
+ plc.Scalar.from_py(effective_rows, dtype.plc_type, stream=stream),
1224
1433
  1,
1434
+ stream=stream,
1225
1435
  ),
1226
1436
  name=self.exprs[0].name or "len",
1227
1437
  dtype=dtype,
1228
- ) # pragma: no cover
1229
- return DataFrame([col]) # pragma: no cover
1438
+ )
1439
+ return DataFrame([col], stream=stream)
1230
1440
 
1231
- return super().evaluate(cache=cache, timer=timer)
1441
+ return super().evaluate(cache=cache, timer=timer, context=context)
1232
1442
 
1233
1443
 
1234
1444
  class Reduce(IR):
@@ -1252,16 +1462,19 @@ class Reduce(IR):
1252
1462
  self._non_child_args = (self.exprs,)
1253
1463
 
1254
1464
  @classmethod
1465
+ @log_do_evaluate
1255
1466
  @nvtx_annotate_cudf_polars(message="Reduce")
1256
1467
  def do_evaluate(
1257
1468
  cls,
1258
1469
  exprs: tuple[expr.NamedExpr, ...],
1259
1470
  df: DataFrame,
1471
+ *,
1472
+ context: IRExecutionContext,
1260
1473
  ) -> DataFrame: # pragma: no cover; not exposed by polars yet
1261
1474
  """Evaluate and return a dataframe."""
1262
- columns = broadcast(*(e.evaluate(df) for e in exprs))
1475
+ columns = broadcast(*(e.evaluate(df) for e in exprs), stream=df.stream)
1263
1476
  assert all(column.size == 1 for column in columns)
1264
- return DataFrame(columns)
1477
+ return DataFrame(columns, stream=df.stream)
1265
1478
 
1266
1479
 
1267
1480
  class Rolling(IR):
@@ -1270,17 +1483,19 @@ class Rolling(IR):
1270
1483
  __slots__ = (
1271
1484
  "agg_requests",
1272
1485
  "closed_window",
1273
- "following",
1486
+ "following_ordinal",
1274
1487
  "index",
1488
+ "index_dtype",
1275
1489
  "keys",
1276
- "preceding",
1490
+ "preceding_ordinal",
1277
1491
  "zlice",
1278
1492
  )
1279
1493
  _non_child = (
1280
1494
  "schema",
1281
1495
  "index",
1282
- "preceding",
1283
- "following",
1496
+ "index_dtype",
1497
+ "preceding_ordinal",
1498
+ "following_ordinal",
1284
1499
  "closed_window",
1285
1500
  "keys",
1286
1501
  "agg_requests",
@@ -1288,10 +1503,12 @@ class Rolling(IR):
1288
1503
  )
1289
1504
  index: expr.NamedExpr
1290
1505
  """Column being rolled over."""
1291
- preceding: plc.Scalar
1292
- """Preceding window extent defining start of window."""
1293
- following: plc.Scalar
1294
- """Following window extent defining end of window."""
1506
+ index_dtype: plc.DataType
1507
+ """Datatype of the index column."""
1508
+ preceding_ordinal: int
1509
+ """Preceding window extent defining start of window as a host integer."""
1510
+ following_ordinal: int
1511
+ """Following window extent defining end of window as a host integer."""
1295
1512
  closed_window: ClosedInterval
1296
1513
  """Treatment of window endpoints."""
1297
1514
  keys: tuple[expr.NamedExpr, ...]
@@ -1305,8 +1522,9 @@ class Rolling(IR):
1305
1522
  self,
1306
1523
  schema: Schema,
1307
1524
  index: expr.NamedExpr,
1308
- preceding: plc.Scalar,
1309
- following: plc.Scalar,
1525
+ index_dtype: plc.DataType,
1526
+ preceding_ordinal: int,
1527
+ following_ordinal: int,
1310
1528
  closed_window: ClosedInterval,
1311
1529
  keys: Sequence[expr.NamedExpr],
1312
1530
  agg_requests: Sequence[expr.NamedExpr],
@@ -1315,14 +1533,15 @@ class Rolling(IR):
1315
1533
  ):
1316
1534
  self.schema = schema
1317
1535
  self.index = index
1318
- self.preceding = preceding
1319
- self.following = following
1536
+ self.index_dtype = index_dtype
1537
+ self.preceding_ordinal = preceding_ordinal
1538
+ self.following_ordinal = following_ordinal
1320
1539
  self.closed_window = closed_window
1321
1540
  self.keys = tuple(keys)
1322
1541
  self.agg_requests = tuple(agg_requests)
1323
1542
  if not all(
1324
1543
  plc.rolling.is_valid_rolling_aggregation(
1325
- agg.value.dtype.plc, agg.value.agg_request
1544
+ agg.value.dtype.plc_type, agg.value.agg_request
1326
1545
  )
1327
1546
  for agg in self.agg_requests
1328
1547
  ):
@@ -1339,8 +1558,9 @@ class Rolling(IR):
1339
1558
  self.children = (df,)
1340
1559
  self._non_child_args = (
1341
1560
  index,
1342
- preceding,
1343
- following,
1561
+ index_dtype,
1562
+ preceding_ordinal,
1563
+ following_ordinal,
1344
1564
  closed_window,
1345
1565
  keys,
1346
1566
  agg_requests,
@@ -1348,31 +1568,46 @@ class Rolling(IR):
1348
1568
  )
1349
1569
 
1350
1570
  @classmethod
1571
+ @log_do_evaluate
1351
1572
  @nvtx_annotate_cudf_polars(message="Rolling")
1352
1573
  def do_evaluate(
1353
1574
  cls,
1354
1575
  index: expr.NamedExpr,
1355
- preceding: plc.Scalar,
1356
- following: plc.Scalar,
1576
+ index_dtype: plc.DataType,
1577
+ preceding_ordinal: int,
1578
+ following_ordinal: int,
1357
1579
  closed_window: ClosedInterval,
1358
1580
  keys_in: Sequence[expr.NamedExpr],
1359
1581
  aggs: Sequence[expr.NamedExpr],
1360
1582
  zlice: Zlice | None,
1361
1583
  df: DataFrame,
1584
+ *,
1585
+ context: IRExecutionContext,
1362
1586
  ) -> DataFrame:
1363
1587
  """Evaluate and return a dataframe."""
1364
- keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
1588
+ keys = broadcast(
1589
+ *(k.evaluate(df) for k in keys_in),
1590
+ target_length=df.num_rows,
1591
+ stream=df.stream,
1592
+ )
1365
1593
  orderby = index.evaluate(df)
1366
1594
  # Polars casts integral orderby to int64, but only for calculating window bounds
1367
1595
  if (
1368
1596
  plc.traits.is_integral(orderby.obj.type())
1369
1597
  and orderby.obj.type().id() != plc.TypeId.INT64
1370
1598
  ):
1371
- orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
1599
+ orderby_obj = plc.unary.cast(
1600
+ orderby.obj, plc.DataType(plc.TypeId.INT64), stream=df.stream
1601
+ )
1372
1602
  else:
1373
1603
  orderby_obj = orderby.obj
1604
+
1605
+ preceding_scalar, following_scalar = offsets_to_windows(
1606
+ index_dtype, preceding_ordinal, following_ordinal, stream=df.stream
1607
+ )
1608
+
1374
1609
  preceding_window, following_window = range_window_bounds(
1375
- preceding, following, closed_window
1610
+ preceding_scalar, following_scalar, closed_window
1376
1611
  )
1377
1612
  if orderby.obj.null_count() != 0:
1378
1613
  raise RuntimeError(
@@ -1383,12 +1618,17 @@ class Rolling(IR):
1383
1618
  table = plc.Table([*(k.obj for k in keys), orderby_obj])
1384
1619
  n = table.num_columns()
1385
1620
  if not plc.sorting.is_sorted(
1386
- table, [plc.types.Order.ASCENDING] * n, [plc.types.NullOrder.BEFORE] * n
1621
+ table,
1622
+ [plc.types.Order.ASCENDING] * n,
1623
+ [plc.types.NullOrder.BEFORE] * n,
1624
+ stream=df.stream,
1387
1625
  ):
1388
1626
  raise RuntimeError("Input for grouped rolling is not sorted")
1389
1627
  else:
1390
1628
  if not orderby.check_sorted(
1391
- order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
1629
+ order=plc.types.Order.ASCENDING,
1630
+ null_order=plc.types.NullOrder.BEFORE,
1631
+ stream=df.stream,
1392
1632
  ):
1393
1633
  raise RuntimeError(
1394
1634
  f"Index column '{index.name}' in rolling is not sorted, please sort first"
@@ -1401,6 +1641,7 @@ class Rolling(IR):
1401
1641
  preceding_window,
1402
1642
  following_window,
1403
1643
  [rolling.to_request(request.value, orderby, df) for request in aggs],
1644
+ stream=df.stream,
1404
1645
  )
1405
1646
  return DataFrame(
1406
1647
  itertools.chain(
@@ -1410,7 +1651,8 @@ class Rolling(IR):
1410
1651
  Column(col, name=request.name, dtype=request.value.dtype)
1411
1652
  for col, request in zip(values.columns(), aggs, strict=True)
1412
1653
  ),
1413
- )
1654
+ ),
1655
+ stream=df.stream,
1414
1656
  ).slice(zlice)
1415
1657
 
1416
1658
 
@@ -1472,6 +1714,7 @@ class GroupBy(IR):
1472
1714
  )
1473
1715
 
1474
1716
  @classmethod
1717
+ @log_do_evaluate
1475
1718
  @nvtx_annotate_cudf_polars(message="GroupBy")
1476
1719
  def do_evaluate(
1477
1720
  cls,
@@ -1481,9 +1724,15 @@ class GroupBy(IR):
1481
1724
  maintain_order: bool, # noqa: FBT001
1482
1725
  zlice: Zlice | None,
1483
1726
  df: DataFrame,
1727
+ *,
1728
+ context: IRExecutionContext,
1484
1729
  ) -> DataFrame:
1485
1730
  """Evaluate and return a dataframe."""
1486
- keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
1731
+ keys = broadcast(
1732
+ *(k.evaluate(df) for k in keys_in),
1733
+ target_length=df.num_rows,
1734
+ stream=df.stream,
1735
+ )
1487
1736
  sorted = (
1488
1737
  plc.types.Sorted.YES
1489
1738
  if all(k.is_sorted for k in keys)
@@ -1515,7 +1764,7 @@ class GroupBy(IR):
1515
1764
  col = value.evaluate(df, context=ExecutionContext.GROUPBY).obj
1516
1765
  requests.append(plc.groupby.GroupByRequest(col, [value.agg_request]))
1517
1766
  names.append(name)
1518
- group_keys, raw_tables = grouper.aggregate(requests)
1767
+ group_keys, raw_tables = grouper.aggregate(requests, stream=df.stream)
1519
1768
  results = [
1520
1769
  Column(column, name=name, dtype=schema[name])
1521
1770
  for name, column, request in zip(
@@ -1529,7 +1778,7 @@ class GroupBy(IR):
1529
1778
  Column(grouped_key, name=key.name, dtype=key.dtype)
1530
1779
  for key, grouped_key in zip(keys, group_keys.columns(), strict=True)
1531
1780
  ]
1532
- broadcasted = broadcast(*result_keys, *results)
1781
+ broadcasted = broadcast(*result_keys, *results, stream=df.stream)
1533
1782
  # Handle order preservation of groups
1534
1783
  if maintain_order and not sorted:
1535
1784
  # The order we want
@@ -1539,6 +1788,7 @@ class GroupBy(IR):
1539
1788
  plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
1540
1789
  plc.types.NullEquality.EQUAL,
1541
1790
  plc.types.NanEquality.ALL_EQUAL,
1791
+ stream=df.stream,
1542
1792
  )
1543
1793
  # The order we have
1544
1794
  have = plc.Table([key.obj for key in broadcasted[: len(keys)]])
@@ -1546,7 +1796,7 @@ class GroupBy(IR):
1546
1796
  # We know an inner join is OK because by construction
1547
1797
  # want and have are permutations of each other.
1548
1798
  left_order, right_order = plc.join.inner_join(
1549
- want, have, plc.types.NullEquality.EQUAL
1799
+ want, have, plc.types.NullEquality.EQUAL, stream=df.stream
1550
1800
  )
1551
1801
  # Now left_order is an arbitrary permutation of the ordering we
1552
1802
  # want, and right_order is a matching permutation of the ordering
@@ -1559,11 +1809,13 @@ class GroupBy(IR):
1559
1809
  plc.Table([left_order]),
1560
1810
  [plc.types.Order.ASCENDING],
1561
1811
  [plc.types.NullOrder.AFTER],
1812
+ stream=df.stream,
1562
1813
  ).columns()
1563
1814
  ordered_table = plc.copying.gather(
1564
1815
  plc.Table([col.obj for col in broadcasted]),
1565
1816
  right_order,
1566
1817
  plc.copying.OutOfBoundsPolicy.DONT_CHECK,
1818
+ stream=df.stream,
1567
1819
  )
1568
1820
  broadcasted = [
1569
1821
  Column(reordered, name=old.name, dtype=old.dtype)
@@ -1571,7 +1823,126 @@ class GroupBy(IR):
1571
1823
  ordered_table.columns(), broadcasted, strict=True
1572
1824
  )
1573
1825
  ]
1574
- return DataFrame(broadcasted).slice(zlice)
1826
+ return DataFrame(broadcasted, stream=df.stream).slice(zlice)
1827
+
1828
+
1829
+ def _strip_predicate_casts(node: expr.Expr) -> expr.Expr:
1830
+ if isinstance(node, expr.Cast):
1831
+ (child,) = node.children
1832
+ child = _strip_predicate_casts(child)
1833
+
1834
+ src = child.dtype
1835
+ dst = node.dtype
1836
+
1837
+ if plc.traits.is_fixed_point(src.plc_type) or plc.traits.is_fixed_point(
1838
+ dst.plc_type
1839
+ ):
1840
+ return child
1841
+
1842
+ if (
1843
+ not POLARS_VERSION_LT_134
1844
+ and isinstance(child, expr.ColRef)
1845
+ and (
1846
+ (
1847
+ plc.traits.is_floating_point(src.plc_type)
1848
+ and plc.traits.is_floating_point(dst.plc_type)
1849
+ )
1850
+ or (
1851
+ plc.traits.is_integral(src.plc_type)
1852
+ and plc.traits.is_integral(dst.plc_type)
1853
+ and src.plc_type.id() == dst.plc_type.id()
1854
+ )
1855
+ )
1856
+ ):
1857
+ return child
1858
+
1859
+ if not node.children:
1860
+ return node
1861
+ return node.reconstruct([_strip_predicate_casts(child) for child in node.children])
1862
+
1863
+
1864
+ def _add_cast(
1865
+ target: DataType,
1866
+ side: expr.ColRef,
1867
+ left_casts: dict[str, DataType],
1868
+ right_casts: dict[str, DataType],
1869
+ ) -> None:
1870
+ (col,) = side.children
1871
+ assert isinstance(col, expr.Col)
1872
+ casts = (
1873
+ left_casts if side.table_ref == plc_expr.TableReference.LEFT else right_casts
1874
+ )
1875
+ casts[col.name] = target
1876
+
1877
+
1878
+ def _align_decimal_binop_types(
1879
+ left_expr: expr.ColRef,
1880
+ right_expr: expr.ColRef,
1881
+ left_casts: dict[str, DataType],
1882
+ right_casts: dict[str, DataType],
1883
+ ) -> None:
1884
+ left_type, right_type = left_expr.dtype, right_expr.dtype
1885
+
1886
+ if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point(
1887
+ right_type.plc_type
1888
+ ):
1889
+ target = DataType.common_decimal_dtype(left_type, right_type)
1890
+
1891
+ if left_type.id() != target.id() or left_type.scale() != target.scale():
1892
+ _add_cast(target, left_expr, left_casts, right_casts)
1893
+
1894
+ if right_type.id() != target.id() or right_type.scale() != target.scale():
1895
+ _add_cast(target, right_expr, left_casts, right_casts)
1896
+
1897
+ elif (
1898
+ plc.traits.is_fixed_point(left_type.plc_type)
1899
+ and plc.traits.is_floating_point(right_type.plc_type)
1900
+ ) or (
1901
+ plc.traits.is_fixed_point(right_type.plc_type)
1902
+ and plc.traits.is_floating_point(left_type.plc_type)
1903
+ ):
1904
+ is_decimal_left = plc.traits.is_fixed_point(left_type.plc_type)
1905
+ decimal_expr, float_expr = (
1906
+ (left_expr, right_expr) if is_decimal_left else (right_expr, left_expr)
1907
+ )
1908
+ _add_cast(decimal_expr.dtype, float_expr, left_casts, right_casts)
1909
+
1910
+
1911
+ def _collect_decimal_binop_casts(
1912
+ predicate: expr.Expr,
1913
+ ) -> tuple[dict[str, DataType], dict[str, DataType]]:
1914
+ left_casts: dict[str, DataType] = {}
1915
+ right_casts: dict[str, DataType] = {}
1916
+
1917
+ def _walk(node: expr.Expr) -> None:
1918
+ if isinstance(node, expr.BinOp) and node.op in _BINOPS:
1919
+ left_expr, right_expr = node.children
1920
+ if isinstance(left_expr, expr.ColRef) and isinstance(
1921
+ right_expr, expr.ColRef
1922
+ ):
1923
+ _align_decimal_binop_types(
1924
+ left_expr, right_expr, left_casts, right_casts
1925
+ )
1926
+ for child in node.children:
1927
+ _walk(child)
1928
+
1929
+ _walk(predicate)
1930
+ return left_casts, right_casts
1931
+
1932
+
1933
+ def _apply_casts(df: DataFrame, casts: dict[str, DataType]) -> DataFrame:
1934
+ if not casts:
1935
+ return df
1936
+
1937
+ columns = []
1938
+ for col in df.columns:
1939
+ target = casts.get(col.name)
1940
+ if target is None:
1941
+ columns.append(Column(col.obj, dtype=col.dtype, name=col.name))
1942
+ else:
1943
+ casted = col.astype(target, stream=df.stream)
1944
+ columns.append(Column(casted.obj, dtype=casted.dtype, name=col.name))
1945
+ return DataFrame(columns, stream=df.stream)
1575
1946
 
1576
1947
 
1577
1948
  class ConditionalJoin(IR):
@@ -1585,7 +1956,14 @@ class ConditionalJoin(IR):
1585
1956
 
1586
1957
  def __init__(self, predicate: expr.Expr):
1587
1958
  self.predicate = predicate
1588
- self.ast = to_ast(predicate)
1959
+ stream = get_cuda_stream()
1960
+ ast_result = to_ast(predicate, stream=stream)
1961
+ stream.synchronize()
1962
+ if ast_result is None:
1963
+ raise NotImplementedError(
1964
+ f"Conditional join with predicate {predicate}"
1965
+ ) # pragma: no cover; polars never delivers expressions we can't handle
1966
+ self.ast = ast_result
1589
1967
 
1590
1968
  def __reduce__(self) -> tuple[Any, ...]:
1591
1969
  """Pickle a Predicate object."""
@@ -1598,8 +1976,9 @@ class ConditionalJoin(IR):
1598
1976
  options: tuple[
1599
1977
  tuple[
1600
1978
  str,
1601
- pl_expr.Operator | Iterable[pl_expr.Operator],
1602
- ],
1979
+ polars._expr_nodes.Operator | Iterable[polars._expr_nodes.Operator],
1980
+ ]
1981
+ | None,
1603
1982
  bool,
1604
1983
  Zlice | None,
1605
1984
  str,
@@ -1620,7 +1999,14 @@ class ConditionalJoin(IR):
1620
1999
  self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
1621
2000
  ) -> None:
1622
2001
  self.schema = schema
2002
+ predicate = _strip_predicate_casts(predicate)
1623
2003
  self.predicate = predicate
2004
+ # options[0] is a tuple[str, Operator, ...]
2005
+ # The Operator class can't be pickled, but we don't use it anyway so
2006
+ # just throw that away
2007
+ if options[0] is not None:
2008
+ options = (None, *options[1:])
2009
+
1624
2010
  self.options = options
1625
2011
  self.children = (left, right)
1626
2012
  predicate_wrapper = self.Predicate(predicate)
@@ -1629,51 +2015,70 @@ class ConditionalJoin(IR):
1629
2015
  assert not nulls_equal
1630
2016
  assert not coalesce
1631
2017
  assert maintain_order == "none"
1632
- if predicate_wrapper.ast is None:
1633
- raise NotImplementedError(
1634
- f"Conditional join with predicate {predicate}"
1635
- ) # pragma: no cover; polars never delivers expressions we can't handle
1636
- self._non_child_args = (predicate_wrapper, zlice, suffix, maintain_order)
2018
+ self._non_child_args = (predicate_wrapper, options)
1637
2019
 
1638
2020
  @classmethod
2021
+ @log_do_evaluate
1639
2022
  @nvtx_annotate_cudf_polars(message="ConditionalJoin")
1640
2023
  def do_evaluate(
1641
2024
  cls,
1642
2025
  predicate_wrapper: Predicate,
1643
- zlice: Zlice | None,
1644
- suffix: str,
1645
- maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
2026
+ options: tuple,
1646
2027
  left: DataFrame,
1647
2028
  right: DataFrame,
2029
+ *,
2030
+ context: IRExecutionContext,
1648
2031
  ) -> DataFrame:
1649
2032
  """Evaluate and return a dataframe."""
2033
+ stream = get_joined_cuda_stream(
2034
+ context.get_cuda_stream,
2035
+ upstreams=(
2036
+ left.stream,
2037
+ right.stream,
2038
+ ),
2039
+ )
2040
+ left_casts, right_casts = _collect_decimal_binop_casts(
2041
+ predicate_wrapper.predicate
2042
+ )
2043
+ _, _, zlice, suffix, _, _ = options
2044
+
1650
2045
  lg, rg = plc.join.conditional_inner_join(
1651
- left.table,
1652
- right.table,
2046
+ _apply_casts(left, left_casts).table,
2047
+ _apply_casts(right, right_casts).table,
1653
2048
  predicate_wrapper.ast,
2049
+ stream=stream,
1654
2050
  )
1655
- left = DataFrame.from_table(
2051
+ left_result = DataFrame.from_table(
1656
2052
  plc.copying.gather(
1657
- left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
2053
+ left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK, stream=stream
1658
2054
  ),
1659
2055
  left.column_names,
1660
2056
  left.dtypes,
2057
+ stream=stream,
1661
2058
  )
1662
- right = DataFrame.from_table(
2059
+ right_result = DataFrame.from_table(
1663
2060
  plc.copying.gather(
1664
- right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
2061
+ right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK, stream=stream
1665
2062
  ),
1666
2063
  right.column_names,
1667
2064
  right.dtypes,
2065
+ stream=stream,
1668
2066
  )
1669
- right = right.rename_columns(
2067
+ right_result = right_result.rename_columns(
1670
2068
  {
1671
2069
  name: f"{name}{suffix}"
1672
2070
  for name in right.column_names
1673
2071
  if name in left.column_names_set
1674
2072
  }
1675
2073
  )
1676
- result = left.with_columns(right.columns)
2074
+ result = left_result.with_columns(right_result.columns, stream=stream)
2075
+
2076
+ # Join the original streams back into the result stream to ensure that the
2077
+ # deallocations (on the original streams) happen after the result is ready
2078
+ join_cuda_streams(
2079
+ downstreams=(left.stream, right.stream), upstreams=(result.stream,)
2080
+ )
2081
+
1677
2082
  return result.slice(zlice)
1678
2083
 
1679
2084
 
@@ -1704,6 +2109,19 @@ class Join(IR):
1704
2109
  - maintain_order: which DataFrame row order to preserve, if any
1705
2110
  """
1706
2111
 
2112
+ SWAPPED_ORDER: ClassVar[
2113
+ dict[
2114
+ Literal["none", "left", "right", "left_right", "right_left"],
2115
+ Literal["none", "left", "right", "left_right", "right_left"],
2116
+ ]
2117
+ ] = {
2118
+ "none": "none",
2119
+ "left": "right",
2120
+ "right": "left",
2121
+ "left_right": "right_left",
2122
+ "right_left": "left_right",
2123
+ }
2124
+
1707
2125
  def __init__(
1708
2126
  self,
1709
2127
  schema: Schema,
@@ -1719,9 +2137,6 @@ class Join(IR):
1719
2137
  self.options = options
1720
2138
  self.children = (left, right)
1721
2139
  self._non_child_args = (self.left_on, self.right_on, self.options)
1722
- # TODO: Implement maintain_order
1723
- if options[5] != "none":
1724
- raise NotImplementedError("maintain_order not implemented yet")
1725
2140
 
1726
2141
  @staticmethod
1727
2142
  @cache
@@ -1770,6 +2185,9 @@ class Join(IR):
1770
2185
  right_rows: int,
1771
2186
  rg: plc.Column,
1772
2187
  right_policy: plc.copying.OutOfBoundsPolicy,
2188
+ *,
2189
+ left_primary: bool = True,
2190
+ stream: Stream,
1773
2191
  ) -> list[plc.Column]:
1774
2192
  """
1775
2193
  Reorder gather maps to satisfy polars join order restrictions.
@@ -1788,30 +2206,70 @@ class Join(IR):
1788
2206
  Right gather map
1789
2207
  right_policy
1790
2208
  Nullify policy for right map
2209
+ left_primary
2210
+ Whether to preserve the left input row order first, and which
2211
+ input stream to use for the primary sort.
2212
+ Defaults to True.
2213
+ stream
2214
+ CUDA stream used for device memory operations and kernel launches.
1791
2215
 
1792
2216
  Returns
1793
2217
  -------
1794
- list of reordered left and right gather maps.
2218
+ list[plc.Column]
2219
+ Reordered left and right gather maps.
1795
2220
 
1796
2221
  Notes
1797
2222
  -----
1798
- For a left join, the polars result preserves the order of the
1799
- left keys, and is stable wrt the right keys. For all other
1800
- joins, there is no order obligation.
2223
+ When ``left_primary`` is True, the pair of gather maps is stably sorted by
2224
+ the original row order of the left side, breaking ties by the right side.
2225
+ And vice versa when ``left_primary`` is False.
1801
2226
  """
1802
- init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE)
1803
- step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE)
1804
- left_order = plc.copying.gather(
1805
- plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
1806
- )
1807
- right_order = plc.copying.gather(
1808
- plc.Table([plc.filling.sequence(right_rows, init, step)]), rg, right_policy
2227
+ init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE, stream=stream)
2228
+ step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE, stream=stream)
2229
+
2230
+ (left_order_col,) = plc.copying.gather(
2231
+ plc.Table(
2232
+ [
2233
+ plc.filling.sequence(
2234
+ left_rows,
2235
+ init,
2236
+ step,
2237
+ stream=stream,
2238
+ )
2239
+ ]
2240
+ ),
2241
+ lg,
2242
+ left_policy,
2243
+ stream=stream,
2244
+ ).columns()
2245
+ (right_order_col,) = plc.copying.gather(
2246
+ plc.Table(
2247
+ [
2248
+ plc.filling.sequence(
2249
+ right_rows,
2250
+ init,
2251
+ step,
2252
+ stream=stream,
2253
+ )
2254
+ ]
2255
+ ),
2256
+ rg,
2257
+ right_policy,
2258
+ stream=stream,
2259
+ ).columns()
2260
+
2261
+ keys = (
2262
+ plc.Table([left_order_col, right_order_col])
2263
+ if left_primary
2264
+ else plc.Table([right_order_col, left_order_col])
1809
2265
  )
2266
+
1810
2267
  return plc.sorting.stable_sort_by_key(
1811
2268
  plc.Table([lg, rg]),
1812
- plc.Table([*left_order.columns(), *right_order.columns()]),
2269
+ keys,
1813
2270
  [plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
1814
2271
  [plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
2272
+ stream=stream,
1815
2273
  ).columns()
1816
2274
 
1817
2275
  @staticmethod
@@ -1822,31 +2280,35 @@ class Join(IR):
1822
2280
  left: bool = True,
1823
2281
  empty: bool = False,
1824
2282
  rename: Callable[[str], str] = lambda name: name,
2283
+ stream: Stream,
1825
2284
  ) -> list[Column]:
1826
2285
  if empty:
1827
2286
  return [
1828
2287
  Column(
1829
- plc.column_factories.make_empty_column(col.dtype.plc),
2288
+ plc.column_factories.make_empty_column(
2289
+ col.dtype.plc_type, stream=stream
2290
+ ),
1830
2291
  col.dtype,
1831
2292
  name=rename(col.name),
1832
2293
  )
1833
2294
  for col in template
1834
2295
  ]
1835
2296
 
1836
- columns = [
2297
+ result = [
1837
2298
  Column(new, col.dtype, name=rename(col.name))
1838
2299
  for new, col in zip(columns, template, strict=True)
1839
2300
  ]
1840
2301
 
1841
2302
  if left:
1842
- columns = [
2303
+ result = [
1843
2304
  col.sorted_like(orig)
1844
- for col, orig in zip(columns, template, strict=True)
2305
+ for col, orig in zip(result, template, strict=True)
1845
2306
  ]
1846
2307
 
1847
- return columns
2308
+ return result
1848
2309
 
1849
2310
  @classmethod
2311
+ @log_do_evaluate
1850
2312
  @nvtx_annotate_cudf_polars(message="Join")
1851
2313
  def do_evaluate(
1852
2314
  cls,
@@ -1862,14 +2324,21 @@ class Join(IR):
1862
2324
  ],
1863
2325
  left: DataFrame,
1864
2326
  right: DataFrame,
2327
+ *,
2328
+ context: IRExecutionContext,
1865
2329
  ) -> DataFrame:
1866
2330
  """Evaluate and return a dataframe."""
1867
- how, nulls_equal, zlice, suffix, coalesce, _ = options
2331
+ stream = get_joined_cuda_stream(
2332
+ context.get_cuda_stream, upstreams=(left.stream, right.stream)
2333
+ )
2334
+ how, nulls_equal, zlice, suffix, coalesce, maintain_order = options
1868
2335
  if how == "Cross":
1869
2336
  # Separate implementation, since cross_join returns the
1870
2337
  # result, not the gather maps
1871
2338
  if right.num_rows == 0:
1872
- left_cols = Join._build_columns([], left.columns, empty=True)
2339
+ left_cols = Join._build_columns(
2340
+ [], left.columns, empty=True, stream=stream
2341
+ )
1873
2342
  right_cols = Join._build_columns(
1874
2343
  [],
1875
2344
  right.columns,
@@ -1878,96 +2347,145 @@ class Join(IR):
1878
2347
  rename=lambda name: name
1879
2348
  if name not in left.column_names_set
1880
2349
  else f"{name}{suffix}",
2350
+ stream=stream,
2351
+ )
2352
+ result = DataFrame([*left_cols, *right_cols], stream=stream)
2353
+ else:
2354
+ columns = plc.join.cross_join(
2355
+ left.table, right.table, stream=stream
2356
+ ).columns()
2357
+ left_cols = Join._build_columns(
2358
+ columns[: left.num_columns], left.columns, stream=stream
2359
+ )
2360
+ right_cols = Join._build_columns(
2361
+ columns[left.num_columns :],
2362
+ right.columns,
2363
+ rename=lambda name: name
2364
+ if name not in left.column_names_set
2365
+ else f"{name}{suffix}",
2366
+ left=False,
2367
+ stream=stream,
2368
+ )
2369
+ result = DataFrame([*left_cols, *right_cols], stream=stream).slice(
2370
+ zlice
1881
2371
  )
1882
- return DataFrame([*left_cols, *right_cols])
1883
2372
 
1884
- columns = plc.join.cross_join(left.table, right.table).columns()
1885
- left_cols = Join._build_columns(
1886
- columns[: left.num_columns],
1887
- left.columns,
1888
- )
1889
- right_cols = Join._build_columns(
1890
- columns[left.num_columns :],
1891
- right.columns,
1892
- rename=lambda name: name
1893
- if name not in left.column_names_set
1894
- else f"{name}{suffix}",
1895
- left=False,
1896
- )
1897
- return DataFrame([*left_cols, *right_cols]).slice(zlice)
1898
- # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
1899
- left_on = DataFrame(broadcast(*(e.evaluate(left) for e in left_on_exprs)))
1900
- right_on = DataFrame(broadcast(*(e.evaluate(right) for e in right_on_exprs)))
1901
- null_equality = (
1902
- plc.types.NullEquality.EQUAL
1903
- if nulls_equal
1904
- else plc.types.NullEquality.UNEQUAL
1905
- )
1906
- join_fn, left_policy, right_policy = cls._joiners(how)
1907
- if right_policy is None:
1908
- # Semi join
1909
- lg = join_fn(left_on.table, right_on.table, null_equality)
1910
- table = plc.copying.gather(left.table, lg, left_policy)
1911
- result = DataFrame.from_table(table, left.column_names, left.dtypes)
1912
2373
  else:
1913
- if how == "Right":
1914
- # Right join is a left join with the tables swapped
1915
- left, right = right, left
1916
- left_on, right_on = right_on, left_on
1917
- lg, rg = join_fn(left_on.table, right_on.table, null_equality)
1918
- if how == "Left" or how == "Right":
1919
- # Order of left table is preserved
1920
- lg, rg = cls._reorder_maps(
1921
- left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
1922
- )
1923
- if coalesce:
1924
- if how == "Full":
1925
- # In this case, keys must be column references,
1926
- # possibly with dtype casting. We should use them in
1927
- # preference to the columns from the original tables.
1928
- left = left.with_columns(left_on.columns, replace_only=True)
1929
- right = right.with_columns(right_on.columns, replace_only=True)
1930
- else:
1931
- right = right.discard_columns(right_on.column_names_set)
1932
- left = DataFrame.from_table(
1933
- plc.copying.gather(left.table, lg, left_policy),
1934
- left.column_names,
1935
- left.dtypes,
2374
+ # how != "Cross"
2375
+ # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
2376
+ left_on = DataFrame(
2377
+ broadcast(*(e.evaluate(left) for e in left_on_exprs), stream=stream),
2378
+ stream=stream,
1936
2379
  )
1937
- right = DataFrame.from_table(
1938
- plc.copying.gather(right.table, rg, right_policy),
1939
- right.column_names,
1940
- right.dtypes,
2380
+ right_on = DataFrame(
2381
+ broadcast(*(e.evaluate(right) for e in right_on_exprs), stream=stream),
2382
+ stream=stream,
1941
2383
  )
1942
- if coalesce and how == "Full":
1943
- left = left.with_columns(
1944
- (
1945
- Column(
1946
- plc.replace.replace_nulls(left_col.obj, right_col.obj),
1947
- name=left_col.name,
1948
- dtype=left_col.dtype,
2384
+ null_equality = (
2385
+ plc.types.NullEquality.EQUAL
2386
+ if nulls_equal
2387
+ else plc.types.NullEquality.UNEQUAL
2388
+ )
2389
+ join_fn, left_policy, right_policy = cls._joiners(how)
2390
+ if right_policy is None:
2391
+ # Semi join
2392
+ lg = join_fn(left_on.table, right_on.table, null_equality, stream)
2393
+ table = plc.copying.gather(left.table, lg, left_policy, stream=stream)
2394
+ result = DataFrame.from_table(
2395
+ table, left.column_names, left.dtypes, stream=stream
2396
+ )
2397
+ else:
2398
+ if how == "Right":
2399
+ # Right join is a left join with the tables swapped
2400
+ left, right = right, left
2401
+ left_on, right_on = right_on, left_on
2402
+ maintain_order = Join.SWAPPED_ORDER[maintain_order]
2403
+
2404
+ lg, rg = join_fn(
2405
+ left_on.table, right_on.table, null_equality, stream=stream
2406
+ )
2407
+ if (
2408
+ how in ("Inner", "Left", "Right", "Full")
2409
+ and maintain_order != "none"
2410
+ ):
2411
+ lg, rg = cls._reorder_maps(
2412
+ left.num_rows,
2413
+ lg,
2414
+ left_policy,
2415
+ right.num_rows,
2416
+ rg,
2417
+ right_policy,
2418
+ left_primary=maintain_order.startswith("left"),
2419
+ stream=stream,
2420
+ )
2421
+ if coalesce:
2422
+ if how == "Full":
2423
+ # In this case, keys must be column references,
2424
+ # possibly with dtype casting. We should use them in
2425
+ # preference to the columns from the original tables.
2426
+
2427
+ # We need to specify `stream` here. We know that `{left,right}_on`
2428
+ # is valid on `stream`, which is ordered after `{left,right}.stream`.
2429
+ left = left.with_columns(
2430
+ left_on.columns, replace_only=True, stream=stream
1949
2431
  )
1950
- for left_col, right_col in zip(
1951
- left.select_columns(left_on.column_names_set),
1952
- right.select_columns(right_on.column_names_set),
1953
- strict=True,
2432
+ right = right.with_columns(
2433
+ right_on.columns, replace_only=True, stream=stream
1954
2434
  )
1955
- ),
1956
- replace_only=True,
2435
+ else:
2436
+ right = right.discard_columns(right_on.column_names_set)
2437
+ left = DataFrame.from_table(
2438
+ plc.copying.gather(left.table, lg, left_policy, stream=stream),
2439
+ left.column_names,
2440
+ left.dtypes,
2441
+ stream=stream,
1957
2442
  )
1958
- right = right.discard_columns(right_on.column_names_set)
1959
- if how == "Right":
1960
- # Undo the swap for right join before gluing together.
1961
- left, right = right, left
1962
- right = right.rename_columns(
1963
- {
1964
- name: f"{name}{suffix}"
1965
- for name in right.column_names
1966
- if name in left.column_names_set
1967
- }
1968
- )
1969
- result = left.with_columns(right.columns)
1970
- return result.slice(zlice)
2443
+ right = DataFrame.from_table(
2444
+ plc.copying.gather(right.table, rg, right_policy, stream=stream),
2445
+ right.column_names,
2446
+ right.dtypes,
2447
+ stream=stream,
2448
+ )
2449
+ if coalesce and how == "Full":
2450
+ left = left.with_columns(
2451
+ (
2452
+ Column(
2453
+ plc.replace.replace_nulls(
2454
+ left_col.obj, right_col.obj, stream=stream
2455
+ ),
2456
+ name=left_col.name,
2457
+ dtype=left_col.dtype,
2458
+ )
2459
+ for left_col, right_col in zip(
2460
+ left.select_columns(left_on.column_names_set),
2461
+ right.select_columns(right_on.column_names_set),
2462
+ strict=True,
2463
+ )
2464
+ ),
2465
+ replace_only=True,
2466
+ stream=stream,
2467
+ )
2468
+ right = right.discard_columns(right_on.column_names_set)
2469
+ if how == "Right":
2470
+ # Undo the swap for right join before gluing together.
2471
+ left, right = right, left
2472
+ right = right.rename_columns(
2473
+ {
2474
+ name: f"{name}{suffix}"
2475
+ for name in right.column_names
2476
+ if name in left.column_names_set
2477
+ }
2478
+ )
2479
+ result = left.with_columns(right.columns, stream=stream)
2480
+ result = result.slice(zlice)
2481
+
2482
+ # Join the original streams back into the result stream to ensure that the
2483
+ # deallocations (on the original streams) happen after the result is ready
2484
+ join_cuda_streams(
2485
+ downstreams=(left.stream, right.stream), upstreams=(result.stream,)
2486
+ )
2487
+
2488
+ return result
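The stream comments above describe a fork/join pattern: the node's work runs on a stream ordered after both inputs' streams (obtained via `get_joined_cuda_stream`), and once the result exists the inputs' streams are made to wait on the result's stream (`join_cuda_streams`) so that deallocation of the inputs cannot overtake the join work. The package's own helpers are not shown in this section; a rough sketch of the underlying event-based ordering, written with CuPy purely for illustration:

```python
import cupy as cp

def join_streams(downstreams, upstreams):
    # Order every downstream stream after all work currently enqueued on the
    # upstream streams, without synchronizing the host.
    for up in upstreams:
        event = cp.cuda.Event()
        event.record(up)  # capture the point "all prior work on `up` is done"
        for down in downstreams:
            down.wait_event(event)  # later work on `down` waits for that point
```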
1971
2489
 
1972
2490
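The `coalesce` branch above (the `plc.replace.replace_nulls` call on the gathered key columns) implements Polars' key coalescing for full joins: a single key column is kept, taking the left key where it is non-null and the right key for right-only rows. A hedged, usage-level example with illustrative frames, assuming a Polars version that accepts `how="full"`:

```python
import polars as pl

left = pl.DataFrame({"key": [1, 2], "x": ["a", "b"]})
right = pl.DataFrame({"key": [2, 3], "y": ["B", "C"]})

# With coalesce=True the result has a single "key" column containing 1, 2 and 3:
# the row that only matches on the right takes its key from the right table
# rather than the null the left gather would produce.
out = left.join(right, on="key", how="full", coalesce=True)
```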
 
1973
2491
  class HStack(IR):
@@ -1992,18 +2510,23 @@ class HStack(IR):
1992
2510
  self.children = (df,)
1993
2511
 
1994
2512
  @classmethod
2513
+ @log_do_evaluate
1995
2514
  @nvtx_annotate_cudf_polars(message="HStack")
1996
2515
  def do_evaluate(
1997
2516
  cls,
1998
2517
  exprs: Sequence[expr.NamedExpr],
1999
2518
  should_broadcast: bool, # noqa: FBT001
2000
2519
  df: DataFrame,
2520
+ *,
2521
+ context: IRExecutionContext,
2001
2522
  ) -> DataFrame:
2002
2523
  """Evaluate and return a dataframe."""
2003
2524
  columns = [c.evaluate(df) for c in exprs]
2004
2525
  if should_broadcast:
2005
2526
  columns = broadcast(
2006
- *columns, target_length=df.num_rows if df.num_columns != 0 else None
2527
+ *columns,
2528
+ target_length=df.num_rows if df.num_columns != 0 else None,
2529
+ stream=df.stream,
2007
2530
  )
2008
2531
  else:
2009
2532
  # Polars ensures this is true, but let's make sure nothing
@@ -2014,7 +2537,7 @@ class HStack(IR):
2014
2537
  # never be turned into a pylibcudf Table with all columns
2015
2538
  # by the Select, which is why this is safe.
2016
2539
  assert all(e.name.startswith("__POLARS_CSER_0x") for e in exprs)
2017
- return df.with_columns(columns)
2540
+ return df.with_columns(columns, stream=df.stream)
2018
2541
 
2019
2542
 
2020
2543
  class Distinct(IR):
@@ -2057,6 +2580,7 @@ class Distinct(IR):
2057
2580
  }
2058
2581
 
2059
2582
  @classmethod
2583
+ @log_do_evaluate
2060
2584
  @nvtx_annotate_cudf_polars(message="Distinct")
2061
2585
  def do_evaluate(
2062
2586
  cls,
@@ -2065,6 +2589,8 @@ class Distinct(IR):
2065
2589
  zlice: Zlice | None,
2066
2590
  stable: bool, # noqa: FBT001
2067
2591
  df: DataFrame,
2592
+ *,
2593
+ context: IRExecutionContext,
2068
2594
  ) -> DataFrame:
2069
2595
  """Evaluate and return a dataframe."""
2070
2596
  if subset is None:
@@ -2079,6 +2605,7 @@ class Distinct(IR):
2079
2605
  indices,
2080
2606
  keep,
2081
2607
  plc.types.NullEquality.EQUAL,
2608
+ stream=df.stream,
2082
2609
  )
2083
2610
  else:
2084
2611
  distinct = (
@@ -2092,13 +2619,15 @@ class Distinct(IR):
2092
2619
  keep,
2093
2620
  plc.types.NullEquality.EQUAL,
2094
2621
  plc.types.NanEquality.ALL_EQUAL,
2622
+ df.stream,
2095
2623
  )
2096
2624
  # TODO: Is this sortedness setting correct
2097
2625
  result = DataFrame(
2098
2626
  [
2099
2627
  Column(new, name=old.name, dtype=old.dtype).sorted_like(old)
2100
2628
  for new, old in zip(table.columns(), df.columns, strict=True)
2101
- ]
2629
+ ],
2630
+ stream=df.stream,
2102
2631
  )
2103
2632
  if keys_sorted or stable:
2104
2633
  result = result.sorted_like(df)
@@ -2147,6 +2676,7 @@ class Sort(IR):
2147
2676
  self.children = (df,)
2148
2677
 
2149
2678
  @classmethod
2679
+ @log_do_evaluate
2150
2680
  @nvtx_annotate_cudf_polars(message="Sort")
2151
2681
  def do_evaluate(
2152
2682
  cls,
@@ -2156,17 +2686,24 @@ class Sort(IR):
2156
2686
  stable: bool, # noqa: FBT001
2157
2687
  zlice: Zlice | None,
2158
2688
  df: DataFrame,
2689
+ *,
2690
+ context: IRExecutionContext,
2159
2691
  ) -> DataFrame:
2160
2692
  """Evaluate and return a dataframe."""
2161
- sort_keys = broadcast(*(k.evaluate(df) for k in by), target_length=df.num_rows)
2693
+ sort_keys = broadcast(
2694
+ *(k.evaluate(df) for k in by), target_length=df.num_rows, stream=df.stream
2695
+ )
2162
2696
  do_sort = plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
2163
2697
  table = do_sort(
2164
2698
  df.table,
2165
2699
  plc.Table([k.obj for k in sort_keys]),
2166
2700
  list(order),
2167
2701
  list(null_order),
2702
+ stream=df.stream,
2703
+ )
2704
+ result = DataFrame.from_table(
2705
+ table, df.column_names, df.dtypes, stream=df.stream
2168
2706
  )
2169
- result = DataFrame.from_table(table, df.column_names, df.dtypes)
2170
2707
  first_key = sort_keys[0]
2171
2708
  name = by[0].name
2172
2709
  first_key_in_result = (
@@ -2197,8 +2734,11 @@ class Slice(IR):
2197
2734
  self.children = (df,)
2198
2735
 
2199
2736
  @classmethod
2737
+ @log_do_evaluate
2200
2738
  @nvtx_annotate_cudf_polars(message="Slice")
2201
- def do_evaluate(cls, offset: int, length: int, df: DataFrame) -> DataFrame:
2739
+ def do_evaluate(
2740
+ cls, offset: int, length: int, df: DataFrame, *, context: IRExecutionContext
2741
+ ) -> DataFrame:
2202
2742
  """Evaluate and return a dataframe."""
2203
2743
  return df.slice((offset, length))
2204
2744
 
@@ -2218,10 +2758,15 @@ class Filter(IR):
2218
2758
  self.children = (df,)
2219
2759
 
2220
2760
  @classmethod
2761
+ @log_do_evaluate
2221
2762
  @nvtx_annotate_cudf_polars(message="Filter")
2222
- def do_evaluate(cls, mask_expr: expr.NamedExpr, df: DataFrame) -> DataFrame:
2763
+ def do_evaluate(
2764
+ cls, mask_expr: expr.NamedExpr, df: DataFrame, *, context: IRExecutionContext
2765
+ ) -> DataFrame:
2223
2766
  """Evaluate and return a dataframe."""
2224
- (mask,) = broadcast(mask_expr.evaluate(df), target_length=df.num_rows)
2767
+ (mask,) = broadcast(
2768
+ mask_expr.evaluate(df), target_length=df.num_rows, stream=df.stream
2769
+ )
2225
2770
  return df.filter(mask)
2226
2771
 
2227
2772
 
@@ -2237,14 +2782,19 @@ class Projection(IR):
2237
2782
  self.children = (df,)
2238
2783
 
2239
2784
  @classmethod
2785
+ @log_do_evaluate
2240
2786
  @nvtx_annotate_cudf_polars(message="Projection")
2241
- def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame:
2787
+ def do_evaluate(
2788
+ cls, schema: Schema, df: DataFrame, *, context: IRExecutionContext
2789
+ ) -> DataFrame:
2242
2790
  """Evaluate and return a dataframe."""
2243
2791
  # This can reorder things.
2244
2792
  columns = broadcast(
2245
- *(df.column_map[name] for name in schema), target_length=df.num_rows
2793
+ *(df.column_map[name] for name in schema),
2794
+ target_length=df.num_rows,
2795
+ stream=df.stream,
2246
2796
  )
2247
- return DataFrame(columns)
2797
+ return DataFrame(columns, stream=df.stream)
2248
2798
 
2249
2799
 
2250
2800
  class MergeSorted(IR):
@@ -2270,24 +2820,40 @@ class MergeSorted(IR):
2270
2820
  self._non_child_args = (key,)
2271
2821
 
2272
2822
  @classmethod
2823
+ @log_do_evaluate
2273
2824
  @nvtx_annotate_cudf_polars(message="MergeSorted")
2274
- def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
2825
+ def do_evaluate(
2826
+ cls, key: str, *dfs: DataFrame, context: IRExecutionContext
2827
+ ) -> DataFrame:
2275
2828
  """Evaluate and return a dataframe."""
2829
+ stream = get_joined_cuda_stream(
2830
+ context.get_cuda_stream, upstreams=[df.stream for df in dfs]
2831
+ )
2276
2832
  left, right = dfs
2277
2833
  right = right.discard_columns(right.column_names_set - left.column_names_set)
2278
2834
  on_col_left = left.select_columns({key})[0]
2279
2835
  on_col_right = right.select_columns({key})[0]
2280
- return DataFrame.from_table(
2836
+ result = DataFrame.from_table(
2281
2837
  plc.merge.merge(
2282
2838
  [right.table, left.table],
2283
2839
  [left.column_names.index(key), right.column_names.index(key)],
2284
2840
  [on_col_left.order, on_col_right.order],
2285
2841
  [on_col_left.null_order, on_col_right.null_order],
2842
+ stream=stream,
2286
2843
  ),
2287
2844
  left.column_names,
2288
2845
  left.dtypes,
2846
+ stream=stream,
2847
+ )
2848
+
2849
+ # Join the original streams back into the result stream to ensure that the
2850
+ # deallocations (on the original streams) happen after the result is ready
2851
+ join_cuda_streams(
2852
+ downstreams=[df.stream for df in dfs], upstreams=(result.stream,)
2289
2853
  )
2290
2854
 
2855
+ return result
2856
+
2291
2857
 
2292
2858
  class MapFunction(IR):
2293
2859
  """Apply some function to a dataframe."""
@@ -2347,7 +2913,7 @@ class MapFunction(IR):
2347
2913
  index = frozenset(indices)
2348
2914
  pivotees = [name for name in df.schema if name not in index]
2349
2915
  if not all(
2350
- dtypes.can_cast(df.schema[p].plc, self.schema[value_name].plc)
2916
+ dtypes.can_cast(df.schema[p].plc_type, self.schema[value_name].plc_type)
2351
2917
  for p in pivotees
2352
2918
  ):
2353
2919
  raise NotImplementedError(
@@ -2390,9 +2956,16 @@ class MapFunction(IR):
2390
2956
  )
2391
2957
 
2392
2958
  @classmethod
2959
+ @log_do_evaluate
2393
2960
  @nvtx_annotate_cudf_polars(message="MapFunction")
2394
2961
  def do_evaluate(
2395
- cls, schema: Schema, name: str, options: Any, df: DataFrame
2962
+ cls,
2963
+ schema: Schema,
2964
+ name: str,
2965
+ options: Any,
2966
+ df: DataFrame,
2967
+ *,
2968
+ context: IRExecutionContext,
2396
2969
  ) -> DataFrame:
2397
2970
  """Evaluate and return a dataframe."""
2398
2971
  if name == "rechunk":
@@ -2409,7 +2982,10 @@ class MapFunction(IR):
2409
2982
  index = df.column_names.index(to_explode)
2410
2983
  subset = df.column_names_set - {to_explode}
2411
2984
  return DataFrame.from_table(
2412
- plc.lists.explode_outer(df.table, index), df.column_names, df.dtypes
2985
+ plc.lists.explode_outer(df.table, index, stream=df.stream),
2986
+ df.column_names,
2987
+ df.dtypes,
2988
+ stream=df.stream,
2413
2989
  ).sorted_like(df, subset=subset)
2414
2990
  elif name == "unpivot":
2415
2991
  (
@@ -2423,7 +2999,7 @@ class MapFunction(IR):
2423
2999
  index_columns = [
2424
3000
  Column(tiled, name=name, dtype=old.dtype)
2425
3001
  for tiled, name, old in zip(
2426
- plc.reshape.tile(selected.table, npiv).columns(),
3002
+ plc.reshape.tile(selected.table, npiv, stream=df.stream).columns(),
2427
3003
  indices,
2428
3004
  selected.columns,
2429
3005
  strict=True,
@@ -2434,18 +3010,23 @@ class MapFunction(IR):
2434
3010
  [
2435
3011
  plc.Column.from_arrow(
2436
3012
  pl.Series(
2437
- values=pivotees, dtype=schema[variable_name].polars
2438
- )
3013
+ values=pivotees, dtype=schema[variable_name].polars_type
3014
+ ),
3015
+ stream=df.stream,
2439
3016
  )
2440
3017
  ]
2441
3018
  ),
2442
3019
  df.num_rows,
3020
+ stream=df.stream,
2443
3021
  ).columns()
2444
3022
  value_column = plc.concatenate.concatenate(
2445
3023
  [
2446
- df.column_map[pivotee].astype(schema[value_name]).obj
3024
+ df.column_map[pivotee]
3025
+ .astype(schema[value_name], stream=df.stream)
3026
+ .obj
2447
3027
  for pivotee in pivotees
2448
- ]
3028
+ ],
3029
+ stream=df.stream,
2449
3030
  )
2450
3031
  return DataFrame(
2451
3032
  [
@@ -2454,22 +3035,23 @@ class MapFunction(IR):
2454
3035
  variable_column, name=variable_name, dtype=schema[variable_name]
2455
3036
  ),
2456
3037
  Column(value_column, name=value_name, dtype=schema[value_name]),
2457
- ]
3038
+ ],
3039
+ stream=df.stream,
2458
3040
  )
2459
3041
  elif name == "row_index":
2460
3042
  col_name, offset = options
2461
3043
  dtype = schema[col_name]
2462
- step = plc.Scalar.from_py(1, dtype.plc)
2463
- init = plc.Scalar.from_py(offset, dtype.plc)
3044
+ step = plc.Scalar.from_py(1, dtype.plc_type, stream=df.stream)
3045
+ init = plc.Scalar.from_py(offset, dtype.plc_type, stream=df.stream)
2464
3046
  index_col = Column(
2465
- plc.filling.sequence(df.num_rows, init, step),
3047
+ plc.filling.sequence(df.num_rows, init, step, stream=df.stream),
2466
3048
  is_sorted=plc.types.Sorted.YES,
2467
3049
  order=plc.types.Order.ASCENDING,
2468
3050
  null_order=plc.types.NullOrder.AFTER,
2469
3051
  name=col_name,
2470
3052
  dtype=dtype,
2471
3053
  )
2472
- return DataFrame([index_col, *df.columns])
3054
+ return DataFrame([index_col, *df.columns], stream=df.stream)
2473
3055
  else:
2474
3056
  raise AssertionError("Should never be reached") # pragma: no cover
2475
3057
 
@@ -2490,16 +3072,33 @@ class Union(IR):
2490
3072
  schema = self.children[0].schema
2491
3073
 
2492
3074
  @classmethod
3075
+ @log_do_evaluate
2493
3076
  @nvtx_annotate_cudf_polars(message="Union")
2494
- def do_evaluate(cls, zlice: Zlice | None, *dfs: DataFrame) -> DataFrame:
3077
+ def do_evaluate(
3078
+ cls, zlice: Zlice | None, *dfs: DataFrame, context: IRExecutionContext
3079
+ ) -> DataFrame:
2495
3080
  """Evaluate and return a dataframe."""
3081
+ stream = get_joined_cuda_stream(
3082
+ context.get_cuda_stream, upstreams=[df.stream for df in dfs]
3083
+ )
3084
+
2496
3085
  # TODO: only evaluate what we need if we have a slice?
2497
- return DataFrame.from_table(
2498
- plc.concatenate.concatenate([df.table for df in dfs]),
3086
+ result = DataFrame.from_table(
3087
+ plc.concatenate.concatenate([df.table for df in dfs], stream=stream),
2499
3088
  dfs[0].column_names,
2500
3089
  dfs[0].dtypes,
3090
+ stream=stream,
2501
3091
  ).slice(zlice)
2502
3092
 
3093
+ # Join the original streams back into the new result stream
3094
+ # to ensure that the deallocations (on the original streams)
3095
+ # happen after the result is ready
3096
+ join_cuda_streams(
3097
+ downstreams=[df.stream for df in dfs], upstreams=(result.stream,)
3098
+ )
3099
+
3100
+ return result
3101
+
2503
3102
 
2504
3103
  class HConcat(IR):
2505
3104
  """Concatenate dataframes horizontally."""
@@ -2519,7 +3118,9 @@ class HConcat(IR):
2519
3118
  self.children = children
2520
3119
 
2521
3120
  @staticmethod
2522
- def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table:
3121
+ def _extend_with_nulls(
3122
+ table: plc.Table, *, nrows: int, stream: Stream
3123
+ ) -> plc.Table:
2523
3124
  """
2524
3125
  Extend a table with nulls.
2525
3126
 
@@ -2529,6 +3130,8 @@ class HConcat(IR):
2529
3130
  Table to extend
2530
3131
  nrows
2531
3132
  Number of additional rows
3133
+ stream
3134
+ CUDA stream used for device memory operations and kernel launches
2532
3135
 
2533
3136
  Returns
2534
3137
  -------
@@ -2539,46 +3142,69 @@ class HConcat(IR):
2539
3142
  table,
2540
3143
  plc.Table(
2541
3144
  [
2542
- plc.Column.all_null_like(column, nrows)
3145
+ plc.Column.all_null_like(column, nrows, stream=stream)
2543
3146
  for column in table.columns()
2544
3147
  ]
2545
3148
  ),
2546
- ]
3149
+ ],
3150
+ stream=stream,
2547
3151
  )
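The null-extension here mirrors Polars' horizontal concatenation semantics used by `do_evaluate` below: shorter inputs are padded with nulls up to the height of the tallest one. A hedged, usage-level example with illustrative frames (assuming a Polars version whose `how="horizontal"` pads rather than erroring):

```python
import polars as pl

tall = pl.DataFrame({"a": [1, 2, 3]})
short = pl.DataFrame({"b": ["x"]})

# "b" comes back as ["x", None, None]: the shorter frame is extended with
# nulls, which is what _extend_with_nulls does at the pylibcudf level.
out = pl.concat([tall, short], how="horizontal")
```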
2548
3152
 
2549
3153
  @classmethod
3154
+ @log_do_evaluate
2550
3155
  @nvtx_annotate_cudf_polars(message="HConcat")
2551
3156
  def do_evaluate(
2552
3157
  cls,
2553
3158
  should_broadcast: bool, # noqa: FBT001
2554
3159
  *dfs: DataFrame,
3160
+ context: IRExecutionContext,
2555
3161
  ) -> DataFrame:
2556
3162
  """Evaluate and return a dataframe."""
3163
+ stream = get_joined_cuda_stream(
3164
+ context.get_cuda_stream, upstreams=[df.stream for df in dfs]
3165
+ )
3166
+
2557
3167
  # Special should_broadcast case.
2558
3168
  # Used to recombine decomposed expressions
2559
3169
  if should_broadcast:
2560
- return DataFrame(
2561
- broadcast(*itertools.chain.from_iterable(df.columns for df in dfs))
3170
+ result = DataFrame(
3171
+ broadcast(
3172
+ *itertools.chain.from_iterable(df.columns for df in dfs),
3173
+ stream=stream,
3174
+ ),
3175
+ stream=stream,
2562
3176
  )
2563
-
2564
- max_rows = max(df.num_rows for df in dfs)
2565
- # Horizontal concatenation extends shorter tables with nulls
2566
- return DataFrame(
2567
- itertools.chain.from_iterable(
2568
- df.columns
2569
- for df in (
2570
- df
2571
- if df.num_rows == max_rows
2572
- else DataFrame.from_table(
2573
- cls._extend_with_nulls(df.table, nrows=max_rows - df.num_rows),
2574
- df.column_names,
2575
- df.dtypes,
3177
+ else:
3178
+ max_rows = max(df.num_rows for df in dfs)
3179
+ # Horizontal concatenation extends shorter tables with nulls
3180
+ result = DataFrame(
3181
+ itertools.chain.from_iterable(
3182
+ df.columns
3183
+ for df in (
3184
+ df
3185
+ if df.num_rows == max_rows
3186
+ else DataFrame.from_table(
3187
+ cls._extend_with_nulls(
3188
+ df.table, nrows=max_rows - df.num_rows, stream=stream
3189
+ ),
3190
+ df.column_names,
3191
+ df.dtypes,
3192
+ stream=stream,
3193
+ )
3194
+ for df in dfs
2576
3195
  )
2577
- for df in dfs
2578
- )
3196
+ ),
3197
+ stream=stream,
2579
3198
  )
3199
+
3200
+ # Join the original streams back into the result stream to ensure that the
3201
+ # deallocations (on the original streams) happen after the result is ready
3202
+ join_cuda_streams(
3203
+ downstreams=[df.stream for df in dfs], upstreams=(result.stream,)
2580
3204
  )
2581
3205
 
3206
+ return result
3207
+
2582
3208
 
2583
3209
  class Empty(IR):
2584
3210
  """Represents an empty DataFrame with a known schema."""
@@ -2592,16 +3218,23 @@ class Empty(IR):
2592
3218
  self.children = ()
2593
3219
 
2594
3220
  @classmethod
3221
+ @log_do_evaluate
2595
3222
  @nvtx_annotate_cudf_polars(message="Empty")
2596
- def do_evaluate(cls, schema: Schema) -> DataFrame: # pragma: no cover
3223
+ def do_evaluate(
3224
+ cls, schema: Schema, *, context: IRExecutionContext
3225
+ ) -> DataFrame: # pragma: no cover
2597
3226
  """Evaluate and return a dataframe."""
3227
+ stream = context.get_cuda_stream()
2598
3228
  return DataFrame(
2599
3229
  [
2600
3230
  Column(
2601
- plc.column_factories.make_empty_column(dtype.plc),
3231
+ plc.column_factories.make_empty_column(
3232
+ dtype.plc_type, stream=stream
3233
+ ),
2602
3234
  dtype=dtype,
2603
3235
  name=name,
2604
3236
  )
2605
3237
  for name, dtype in schema.items()
2606
- ]
3238
+ ],
3239
+ stream=stream,
2607
3240
  )